refactor(renderer): remove legacy aiCore layer

Signed-off-by: suyao <sy20010504@gmail.com>
This commit is contained in:
suyao
2026-04-15 19:54:33 +08:00
parent 182c86139b
commit 188f254786
68 changed files with 5 additions and 17952 deletions

View File

@@ -82,6 +82,7 @@ export interface AiImageRequest extends AiBaseRequest {
numInferenceSteps?: number
guidanceScale?: number
promptEnhancement?: boolean
/** TODO(renderer/aiCore-cleanup): wire personGeneration through to the underlying image runtime once the main image contract formally supports it end-to-end. */
personGeneration?: string
}

View File

@@ -1,223 +0,0 @@
# Cherry Studio AI Provider 技术架构文档 (新方案)
## 1. 核心设计理念与目标
本架构旨在重构 Cherry Studio 的 AI Provider现称为 `aiCore`)层,以实现以下目标:
- **职责清晰**:明确划分各组件的职责,降低耦合度。
- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。
- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。
- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。
- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。
核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。
## 2. 核心组件详解
### 2.1. `aiCore` (原 `AiProvider` 文件夹)
这是整个 AI 功能的核心模块。
#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`)
- **职责**:作为特定 AI Provider SDK 的纯粹适配层。
- **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。
- **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。
- 例如SDK 的 `delta.content` -> `TextDeltaChunk`SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。
- **关键**`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `<think>``<tool_use>` 标签。
- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。
#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口)
- 定义了所有 `XxxApiClient` 必须实现的接口,如:
- `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
- `getRequestTransformer(): RequestTransformer<TSdkParams>`
- `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
- 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。
#### 2.1.3. `ApiClientFactory.ts`
- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。
#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
- **职责**:作为所有 AI 相关业务功能的统一入口。
- 提供面向应用的高层接口,例如:
- `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
- `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
- `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
- 未来可能的 `generateImage(prompt: string): Promise<ImageResult>` 等。
- **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`
- **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据实现流式UI更新。
- **封装特定任务的提示工程 (Prompt Engineering)**
- 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`
- **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。
- 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。
- **将 `Promise``resolve``reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`
- **优势**
- 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。
- **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。
- **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。
#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`)
- 定义核心的、Provider 无关的内部请求结构,例如:
- `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。
- `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。
### 2.2. `middleware`
中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。
**核心组件包括:**
- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类用于动态构建中间件链。它支持从基础链开始根据条件添加、插入、替换或移除中间件。
- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。
- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。
- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。
#### 2.2.1. `middlewareTypes.ts`
- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance``_coreRequest`)、`MiddlewareAPI``CompletionsMiddleware` 等。
#### 2.2.2. 核心中间件 (`middleware/core/`)
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()``CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream<RawSdkChunk>`
- **`RawSdkChunk`**指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`Gemini 的 `GenerateContentResponse` 中的部分等)。
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream<RawSdkChunk>`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream<Chunk>`
#### 2.2.3. 特性中间件 (`middleware/feat/`)
这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。
- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<think>...</think>` 文本内嵌标签,生成 `ThinkingDeltaChunk``ThinkingCompleteChunk`
- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<tool_use>...</tool_use>` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。
#### 2.2.4. 核心处理中间件 (`middleware/core/`)
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()``CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。
- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。
- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。
- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。
- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。
- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。
#### 2.2.5. 通用中间件 (`middleware/common/`)
- **`LoggingMiddleware.ts`**: 请求和响应日志。
- **`AbortHandlerMiddleware.ts`**: 处理请求中止。
- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。
- **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。
- **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。
- 在流结束时,发送包含最终累加信息的完成信号。
### 2.3. `types/chunk.ts`
- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。
## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例)
```markdown
**应用层 (例如 UI 组件)**
||
\\/
**`AiProvider.completions` (`aiCore/index.ts`)**
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
||
\\/
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
(接收构建好的链、ApiClient实例、原始SDK方法开始按序执行中间件)
||
\\/
**[ 预处理阶段中间件 ]**
(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|| (Context 中准备好 SDK 请求参数)
\\/
**[ 处理阶段中间件 ]**
(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|| (处理各种特性和Chunk类型)
\\/
**[ SDK调用阶段中间件 ]**
(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|| (输出: 标准化的应用层Chunk流)
\\/
**`FinalChunkConsumerMiddleware` (核心)**
(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理)
||
\\/
**`AiProvider.completions` 返回 `Promise<CompletionsResult>`**
```
## 4. 建议的文件/目录结构
```
src/renderer/src/
└── aiCore/
├── clients/
│ ├── openai/
│ ├── gemini/
│ ├── anthropic/
│ ├── BaseApiClient.ts
│ ├── ApiClientFactory.ts
│ ├── AihubmixAPIClient.ts
│ ├── index.ts
│ └── types.ts
├── middleware/
│ ├── common/
│ ├── core/
│ ├── feat/
│ ├── builder.ts
│ ├── composer.ts
│ ├── index.ts
│ ├── register.ts
│ ├── schemas.ts
│ ├── types.ts
│ └── utils.ts
├── types/
│ ├── chunk.ts
│ └── ...
└── index.ts
```
## 5. 迁移和实施建议
- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。
- **优先定义核心类型**`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。
- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。
- **强化中间件**:让中间件承担起更多解析和特性处理的责任。
- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。
此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。
## 6. 迁移策略与实施建议
本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。
**目标架构核心组件回顾:**
与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。
**迁移步骤:**
**Phase 0: 准备工作和类型定义**
1. **定义核心数据结构 (TypeScript 类型)**
- `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。
- `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`)定义所有可能的通用Chunk类型。
- 为其他API翻译、总结定义类似的 `CoreXxxRequest` (Type)。
2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。
3. **调整 `AiProviderMiddlewareContext`**
- 确保包含 `_apiClientInstance: ApiClient<any,any,any>`
- 确保包含 `_coreRequest: CoreRequestType`
- 考虑添加 `resolvePromise: (value: AggregatedResultType) => void``rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。
**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)**
1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。
2. **迁移SDK实例和配置。**
3. **实现 `getRequestTransformer()`**`CoreCompletionsRequest` 转换为 OpenAI SDK 参数。
4. **实现 `getResponseChunkTransformer()`**`OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 `

View File

@@ -1,515 +0,0 @@
import { createExecutor } from '@cherrystudio/ai-core'
import type { generateImageResult } from '@cherrystudio/ai-core/core/runtime/types'
import { cacheService } from '@data/CacheService'
import { preferenceService } from '@data/PreferenceService'
import { loggerService } from '@logger'
import { normalizeGatewayModels } from '@renderer/services/models/ModelAdapter'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import {
type Assistant,
type EditImageParams,
type GenerateImageParams,
type Model,
type Provider,
SystemProviderIds
} from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
import { gateway } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import { buildPlugins } from './plugins/PluginBuilder'
import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
import { listModels } from './services/listModels'
import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
const logger = loggerService.withContext('AiProvider')
export type AiProviderConfig = AiSdkMiddlewareConfig & {
assistant: Assistant
// topicId for tracing
topicId?: string
callType: string
}
export default class AiProvider {
private config?: ProviderConfig
private actualProvider: Provider
private model?: Model
/**
* Constructor for AiProvider
*
* @param modelOrProvider - Model or Provider object
* @param provider - Optional Provider object (only used when first param is Model)
*
* @remarks
* **Important behavior notes**:
*
* 1. When called with `(model)`:
* - Calls `getActualProvider(model)` to retrieve and format the provider
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
*
* 2. When called with `(model, provider)`:
* - The provided provider will be adapted via `adaptProvider`
* - URL formatting behavior depends on the adapted result
*
* 3. When called with `(provider)`:
* - The provider will be adapted via `adaptProvider`
* - Used for operations that don't need a model (e.g., fetchModels)
*
* @example
* ```typescript
* // Recommended: Auto-format URL
* const ai = new AiProvider(model)
*
* // Provider will be adapted
* const ai = new AiProvider(model, customProvider)
*
* // For operations that don't need a model
* const ai = new AiProvider(provider)
* ```
*/
constructor(model: Model, provider?: Provider)
constructor(provider: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
if (this.isModel(modelOrProvider)) {
// 传入的是 Model
this.model = modelOrProvider
this.actualProvider = provider
? adaptProvider({ provider, model: modelOrProvider })
: getActualProvider(modelOrProvider)
// 注意config 可能是同步值或 Promise在 completions() 中会统一处理
const configOrPromise = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
this.config = configOrPromise instanceof Promise ? undefined : configOrPromise
} else {
// 传入的是 Provider
this.actualProvider = adaptProvider({ provider: modelOrProvider })
// model为可选某些操作如fetchModels不需要model
}
}
/**
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
*/
private isModel(obj: Model | Provider): obj is Model {
return 'provider' in obj && typeof obj.provider === 'string'
}
public getActualProvider() {
return this.actualProvider
}
public async completions(modelId: string, params: StreamTextParams, middlewareConfig: AiProviderConfig) {
// 检查model是否存在
if (!this.model) {
throw new Error('Model is required for completions. Please use constructor with model parameter.')
}
// Config is now set in constructor, ApiService handles key rotation before passing provider
if (!this.config) {
// If config wasn't set in constructor (when provider only), generate it now
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, this.model))
}
logger.debug('Using provider config for completions', this.config)
// 注意:模型对象将由 createExecutor 内部处理,不再需要预先创建
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
// 类型守卫:确保 system 是 string、Array 或 undefined
const system = params.system
let systemParam: string | Array<any> | undefined
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
systemParam = system
} else {
// SystemModelMessage 类型,转换为 string
systemParam = undefined
}
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
params.system = undefined // 清除原有system避免重复
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
}
if (middlewareConfig.topicId && (await preferenceService.get('app.developer_mode.enabled'))) {
// TypeScript类型窄化确保topicId是string类型
const traceConfig = {
...middlewareConfig,
topicId: middlewareConfig.topicId
}
return await this._completionsForTrace(modelId, params, traceConfig, this.config)
} else {
return await this.modernCompletions(modelId, params, middlewareConfig, this.config)
}
}
/**
* 带trace支持的completions方法
* 类似于legacy的completionsForTrace确保AI SDK spans在正确的trace上下文中
*/
private async _completionsForTrace(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiProviderConfig & { topicId: string },
providerConfig: ProviderConfig
): Promise<CompletionsResult> {
const traceName = `${this.actualProvider.name}.${modelId}.${middlewareConfig.callType}`
const traceParams: StartSpanParams = {
name: traceName,
tag: 'LLM',
topicId: middlewareConfig.topicId,
modelName: middlewareConfig.assistant.model?.name, // 使用modelId而不是provider名称
inputs: params
}
logger.info('Starting AI SDK trace span', {
traceName,
topicId: middlewareConfig.topicId,
modelId,
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolNames: params.tools ? Object.keys(params.tools) : []
})
const span = await addSpan(traceParams)
if (!span) {
logger.warn('Failed to create span, falling back to regular completions', {
topicId: middlewareConfig.topicId,
modelId,
traceName
})
return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
}
try {
logger.info('Created parent span, now calling completions', {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: middlewareConfig.topicId,
modelId,
parentSpanCreated: true
})
const result = await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
logger.info('Completions finished, ending parent span', {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: middlewareConfig.topicId,
modelId,
resultLength: result.getText().length
})
// 标记span完成
endSpan({
topicId: middlewareConfig.topicId,
outputs: result,
span,
modelName: modelId // 使用modelId保持一致性
})
return result
} catch (error) {
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: middlewareConfig.topicId,
modelId
})
// 标记span出错
endSpan({
topicId: middlewareConfig.topicId,
error: error as Error,
span,
modelName: modelId // 使用modelId保持一致性
})
throw error
}
}
/**
* 使用现代化AI SDK的completions实现
*/
/**
* Note: This implementation always uses `executor.streamText` and never
* calls `generateText`, even when `onChunk` is not provided.
*/
private async modernCompletions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiProviderConfig,
providerConfig: ProviderConfig
): Promise<CompletionsResult> {
const plugins = await buildPlugins({
provider: this.actualProvider,
model: this.model!,
config: middlewareConfig
})
// 用构建好的插件数组创建executor
const executor = await createExecutor<AppProviderSettingsMap>(
providerConfig.providerId,
providerConfig.providerSettings,
plugins
)
// 创建带有中间件的执行器
if (middlewareConfig.onChunk) {
const accumulate = this.model!.supported_text_delta !== false // true and undefined
const adapter = new AiSdkToChunkAdapter(
middlewareConfig.onChunk,
middlewareConfig.mcpTools,
accumulate,
middlewareConfig.enableWebSearch,
undefined,
undefined,
providerConfig.providerId
)
const streamResult = await executor.streamText({
...params,
model: modelId,
experimental_context: { onChunk: middlewareConfig.onChunk }
})
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else {
// Since no onChunk is provided, the external consumer would not handle error chunk.
// So we need to capture the actual stream error so we can throw it instead of the
// generic NoTextGeneratedError ("No output generated. Check the stream
// for errors.") that AI SDK raises when streamResult.text is accessed
// after a failed stream.
let streamError: unknown = undefined
const streamResult = await executor.streamText({
...params,
model: modelId,
onError({ error }) {
streamError = error
}
})
// 强制消费流,不然await streamResult.text会阻塞
await streamResult?.consumeStream({
onError(error) {
if (!streamError) {
streamError = error
}
}
})
try {
const finalText = await streamResult.text
const usage = await streamResult.totalUsage
return {
getText: () => finalText,
usage
}
} catch (error) {
// If we captured the real stream error, throw that instead of the
// generic NoTextGeneratedError so callers get actionable diagnostics.
if (streamError) {
throw streamError
}
throw error
}
}
}
/**
* 获取模型列表
* 使用 ModelListService 统一处理各 Provider 的模型列表获取
*/
public async models(): Promise<Model[]> {
// Gateway provider 使用 AI SDK 的 gateway API
if (this.actualProvider.id === SystemProviderIds.gateway) {
const gatewayModels = (await gateway.getAvailableModels()).models
return normalizeGatewayModels(this.actualProvider, gatewayModels)
}
// 使用新的 ModelListService
return await listModels(this.actualProvider)
}
/**
* 获取嵌入模型的维度
* 使用 AI SDK embedMany 测试获取维度
*/
public async getEmbeddingDimensions(model: Model): Promise<number> {
// 确保 config 已定义
if (!this.config) {
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
}
const executor = await createExecutor<AppProviderSettingsMap>(
this.config.providerId,
this.config.providerSettings,
[]
)
// 使用 AI SDK embedMany 测试获取维度
const result = await executor.embedMany({
model: model.id,
values: ['test']
})
return result.embeddings[0].length
}
/**
* 懒加载初始化 config
* 当 constructor 只传入 provider 时config 不会被初始化
* 此方法根据 modelId 从 provider 的 models 中查找真实 Model 并生成 config
*/
private async ensureConfig(modelId: string): Promise<void> {
if (this.config) {
return
}
// 从 provider 的 models 中查找真实的 model
const model = this.actualProvider.models.find((m) => getLowerBaseModelName(m.id) === getLowerBaseModelName(modelId))
if (!model) {
throw new Error(`Model "${modelId}" not found in provider "${this.actualProvider.id}"`)
}
this.actualProvider = adaptProvider({ provider: this.actualProvider, model })
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
}
/**
* 生成图像
* 使用现代化 AI SDK 实现,不再 fallback 到 legacy
*/
public async generateImage(params: GenerateImageParams): Promise<string[]> {
await this.ensureConfig(params.model)
return await this.modernGenerateImage(params, this.config!)
}
/**
* 编辑图像 - 基于输入图像和文本提示生成新图像
* 内部使用 AI SDK 的 generateImage通过 prompt.images 参数实现编辑功能
*/
public async editImage(params: EditImageParams): Promise<string[]> {
await this.ensureConfig(params.model)
return await this.modernEditImage(params, this.config!)
}
/**
* 使用现代化 AI SDK 的图像生成实现
*/
private async modernGenerateImage(params: GenerateImageParams, providerConfig: ProviderConfig): Promise<string[]> {
const { model, prompt, imageSize, batchSize, signal } = params
// 转换参数格式
const aiSdkParams = {
prompt,
size: (imageSize || '1024x1024') as `${number}x${number}`,
n: batchSize || 1,
...(signal && { abortSignal: signal })
}
const executor = await createExecutor<AppProviderSettingsMap>(
providerConfig.providerId,
providerConfig.providerSettings,
[]
)
const result = await executor.generateImage({
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
...aiSdkParams
})
return this.convertImageResult(result)
}
/**
* 使用现代化 AI SDK 的图像编辑实现
* 通过 AI SDK 的 generateImage 并传入 prompt.images 参数实现编辑功能
*/
private async modernEditImage(params: EditImageParams, providerConfig: ProviderConfig): Promise<string[]> {
const { model, prompt, inputImages, mask, imageSize, signal } = params
const executor = await createExecutor<AppProviderSettingsMap>(
providerConfig.providerId,
providerConfig.providerSettings,
[]
)
// 使用 AI SDK 的 generateImage通过 prompt.images 实现编辑
const result = await executor.generateImage({
model: model,
prompt: {
text: prompt,
images: inputImages, // 输入图像(必需)
...(mask && { mask }) // 可选的 mask用于 inpainting
},
size: (imageSize || '1024x1024') as `${number}x${number}`,
...(signal && { abortSignal: signal })
})
return this.convertImageResult(result)
}
/**
* 转换图像生成结果格式
*/
private convertImageResult(result: generateImageResult): string[] {
const images: string[] = []
if (result.images) {
for (const image of result.images) {
if (image.base64) {
images.push(`data:${image.mediaType || 'image/png'};base64,${image.base64}`)
}
}
}
return images
}
public getBaseURL(): string {
return this.actualProvider.apiHost || ''
}
public getApiKey(): string {
const apiKey = this.actualProvider.apiKey
if (!apiKey || apiKey.trim() === '') {
return ''
}
const keys = apiKey
.split(',')
.map((key) => key.trim())
.filter(Boolean)
if (keys.length === 0) {
return ''
}
if (keys.length === 1) {
return keys[0]
}
// Multi-key rotation
const keyName = `provider:${this.actualProvider.id}:last_used_key`
const lastUsedKey = cacheService.getCasual<string>(keyName)
if (!lastUsedKey) {
cacheService.setCasual(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
cacheService.setCasual(keyName, nextKey)
return nextKey
}
}

View File

@@ -1,475 +0,0 @@
/**
* AI SDK 到 Cherry Studio Chunk 适配器
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
*/
import { loggerService } from '@logger'
import type { AISDKWebSearchResult, MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { WEB_SEARCH_SOURCE } from '@renderer/types'
import type { Chunk, ProviderMetadata } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { ProviderSpecificError } from '@renderer/types/provider-specific-error'
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
import type { IdleTimeoutHandle } from '@renderer/utils/IdleTimeoutController'
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
import type { ClaudeCodeRawValue } from '@shared/agents/claudecode/types'
import { AISDKError, type TextStreamPart, type ToolSet } from 'ai'
import { ToolCallChunkHandler } from './handleToolCallChunk'
const logger = loggerService.withContext('AiSdkToChunkAdapter')
/**
* AI SDK 到 Cherry Studio Chunk 适配器类
* 处理 fullStream 到 Cherry Studio chunk 的转换
*/
export class AiSdkToChunkAdapter {
toolCallHandler: ToolCallChunkHandler
private accumulate: boolean | undefined
private isFirstChunk = true
private enableWebSearch: boolean = false
private onSessionUpdate?: (sessionId: string) => void
private responseStartTimestamp: number | null = null
private firstTokenTimestamp: number | null = null
private hasTextContent = false
private getSessionWasCleared?: () => boolean
private providerId?: string
private idleTimeout?: IdleTimeoutHandle
constructor(
private onChunk: (chunk: Chunk) => void,
mcpTools: MCPTool[] = [],
accumulate?: boolean,
enableWebSearch?: boolean,
onSessionUpdate?: (sessionId: string) => void,
getSessionWasCleared?: () => boolean,
providerId?: string,
idleTimeout?: IdleTimeoutHandle
) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
this.accumulate = accumulate
this.enableWebSearch = enableWebSearch || false
this.onSessionUpdate = onSessionUpdate
this.getSessionWasCleared = getSessionWasCleared
this.providerId = providerId
this.idleTimeout = idleTimeout
}
private markFirstTokenIfNeeded() {
if (this.firstTokenTimestamp === null && this.responseStartTimestamp !== null) {
this.firstTokenTimestamp = Date.now()
}
}
private resetTimingState() {
this.responseStartTimestamp = null
this.firstTokenTimestamp = null
}
/**
* 处理 AI SDK 流结果
* @param aiSdkResult AI SDK 的流结果对象
* @returns 最终的文本内容
*/
async processStream(aiSdkResult: any): Promise<string> {
// The stream is the single source of truth for abort handling.
// Both AI SDK (resilient stream) and the agent pipeline (withAbortStreamPart)
// guarantee: abort → enqueue { type: 'abort' } → close gracefully.
// convertAndEmitChunk processes the abort part and emits ChunkType.ERROR → onError.
if (aiSdkResult.fullStream) {
await this.readFullStream(aiSdkResult.fullStream)
}
try {
return await aiSdkResult.text
} catch (error: any) {
// The text promise rejects when no steps completed (e.g. abort during thinking).
// The abort was already handled via the 'abort' stream part above.
if (isAbortError(error)) {
return ''
}
throw error
}
}
/**
* 读取 fullStream 并转换为 Cherry Studio chunks
* @param fullStream AI SDK 的 fullStream (ReadableStream)
*/
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
const reader = fullStream.getReader()
const final = {
text: '',
reasoningContent: '',
webSearchResults: [],
reasoningId: '',
providerMetadata: undefined as ProviderMetadata | undefined
}
this.resetTimingState()
this.responseStartTimestamp = Date.now()
// Reset state at the start of stream
this.isFirstChunk = true
this.hasTextContent = false
try {
while (true) {
const { done, value } = await reader.read()
// Reset idle timeout on every chunk received from the stream
this.idleTimeout?.reset()
if (done) {
// Flush any remaining content from link converter buffer if web search is enabled
if (this.enableWebSearch) {
const remainingText = flushLinkConverterBuffer()
if (remainingText) {
this.markFirstTokenIfNeeded()
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: remainingText
})
}
}
break
}
// 转换并发送 chunk
this.convertAndEmitChunk(value, final)
}
} finally {
reader.releaseLock()
this.resetTimingState()
// Clean up the idle timeout timer when the stream ends
this.idleTimeout?.cleanup()
}
}
/**
* 如果有累积的思考内容,发送 THINKING_COMPLETE chunk 并清空
* @param final 包含 reasoningContent 的状态对象
* @returns 是否发送了 THINKING_COMPLETE chunk
*/
private emitThinkingCompleteIfNeeded(final: { reasoningContent: string; [key: string]: any }) {
if (final.reasoningContent) {
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: final.reasoningContent
})
final.reasoningContent = ''
}
}
/**
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
* @param chunk AI SDK 的 chunk 数据
*/
private convertAndEmitChunk(
chunk: TextStreamPart<any>,
final: {
text: string
reasoningContent: string
webSearchResults: AISDKWebSearchResult[]
reasoningId: string
providerMetadata: ProviderMetadata | undefined
}
) {
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
switch (chunk.type) {
case 'raw': {
const agentRawMessage = chunk.rawValue as ClaudeCodeRawValue
if (agentRawMessage.type === 'init' && agentRawMessage.session_id) {
this.onSessionUpdate?.(agentRawMessage.session_id)
} else if (agentRawMessage.type === 'compact' && agentRawMessage.session_id) {
this.onSessionUpdate?.(agentRawMessage.session_id)
}
this.onChunk({
type: ChunkType.RAW,
content: agentRawMessage
})
break
}
// === 文本相关事件 ===
case 'text-start':
// 如果有未完成的思考内容,先生成 THINKING_COMPLETE
// 这处理了某些提供商不发送 reasoning-end 事件的情况
this.emitThinkingCompleteIfNeeded(final)
this.onChunk({
type: ChunkType.TEXT_START
})
break
case 'text-delta': {
this.hasTextContent = true
const processedText = chunk.text || ''
let finalText: string
// Only apply link conversion if web search is enabled
if (this.enableWebSearch) {
const result = convertLinks(processedText, this.isFirstChunk)
if (this.isFirstChunk) {
this.isFirstChunk = false
}
// Handle buffered content
if (result.hasBufferedContent) {
finalText = result.text
} else {
finalText = result.text || processedText
}
} else {
// Without web search, just use the original text
finalText = processedText
}
if (this.accumulate) {
final.text += finalText
} else {
final.text = finalText
}
// Extract thoughtSignature from providerMetadata.google and preserve it
const newSignature = chunk.providerMetadata?.google?.thoughtSignature as string | undefined
if (newSignature) {
final.providerMetadata = {
...final.providerMetadata,
google: {
...final.providerMetadata?.google,
thoughtSignature: newSignature
}
}
}
// Only emit chunk if there's text to send
if (finalText) {
this.markFirstTokenIfNeeded()
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: this.accumulate ? final.text : finalText,
providerMetadata: final.providerMetadata
})
}
break
}
case 'text-end':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? '',
providerMetadata: final.providerMetadata
})
final.text = ''
// Clear providerMetadata for next text block
final.providerMetadata = undefined
break
case 'reasoning-start':
// if (final.reasoningId !== chunk.id) {
final.reasoningId = chunk.id
this.onChunk({
type: ChunkType.THINKING_START
})
// }
break
case 'reasoning-delta':
final.reasoningContent += chunk.text || ''
if (chunk.text) {
this.markFirstTokenIfNeeded()
}
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: final.reasoningContent || ''
})
break
case 'reasoning-end':
this.emitThinkingCompleteIfNeeded(final)
break
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
case 'tool-input-start':
this.toolCallHandler.handleToolInputStart(chunk)
break
case 'tool-input-delta':
this.toolCallHandler.handleToolInputDelta(chunk)
break
case 'tool-input-end':
this.toolCallHandler.handleToolInputEnd(chunk)
break
case 'tool-call':
this.toolCallHandler.handleToolCall(chunk)
break
case 'tool-error':
this.toolCallHandler.handleToolError(chunk)
break
case 'tool-result':
this.toolCallHandler.handleToolResult(chunk)
break
case 'finish-step': {
const { providerMetadata, finishReason } = chunk
// googel web search
if (providerMetadata?.google?.groundingMetadata) {
this.onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: providerMetadata.google?.groundingMetadata as WebSearchResults,
source: WEB_SEARCH_SOURCE.GEMINI
}
})
} else if (final.webSearchResults.length) {
const providerName: string | undefined = Object.keys(providerMetadata || {})[0] || this.providerId
const sourceMap: Record<string, WebSearchSource> = {
[WEB_SEARCH_SOURCE.OPENAI]: WEB_SEARCH_SOURCE.OPENAI_RESPONSE,
[WEB_SEARCH_SOURCE.ANTHROPIC]: WEB_SEARCH_SOURCE.ANTHROPIC,
[WEB_SEARCH_SOURCE.OPENROUTER]: WEB_SEARCH_SOURCE.OPENROUTER,
[WEB_SEARCH_SOURCE.GEMINI]: WEB_SEARCH_SOURCE.GEMINI,
// [WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
[WEB_SEARCH_SOURCE.QWEN]: WEB_SEARCH_SOURCE.QWEN,
[WEB_SEARCH_SOURCE.HUNYUAN]: WEB_SEARCH_SOURCE.HUNYUAN,
[WEB_SEARCH_SOURCE.ZHIPU]: WEB_SEARCH_SOURCE.ZHIPU,
[WEB_SEARCH_SOURCE.GROK]: WEB_SEARCH_SOURCE.GROK,
xai: WEB_SEARCH_SOURCE.GROK,
[WEB_SEARCH_SOURCE.WEBSEARCH]: WEB_SEARCH_SOURCE.WEBSEARCH
}
const source = (providerName && sourceMap[providerName]) || WEB_SEARCH_SOURCE.AISDK
this.onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: final.webSearchResults,
source
}
})
}
if (finishReason === 'tool-calls') {
this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
}
final.webSearchResults = []
// final.reasoningId = ''
break
}
case 'finish': {
// Check if session was cleared (e.g., /clear command) and no text was output
const sessionCleared = this.getSessionWasCleared?.() ?? false
if (sessionCleared && !this.hasTextContent) {
// Inject a "context cleared" message for the user
const clearMessage = '✨ Context cleared. Starting fresh conversation.'
this.onChunk({
type: ChunkType.TEXT_START
})
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: clearMessage
})
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: clearMessage
})
final.text = clearMessage
}
const usage = {
completion_tokens: chunk.totalUsage?.outputTokens || 0,
prompt_tokens: chunk.totalUsage?.inputTokens || 0,
total_tokens: chunk.totalUsage?.totalTokens || 0
}
const metrics = this.buildMetrics(chunk.totalUsage)
const baseResponse = {
text: final.text || '',
reasoning_content: final.reasoningContent || ''
}
this.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
...baseResponse,
usage: { ...usage },
metrics: metrics ? { ...metrics } : undefined
}
})
this.onChunk({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
...baseResponse,
usage: { ...usage },
metrics: metrics ? { ...metrics } : undefined
}
})
this.resetTimingState()
break
}
// === 源和文件相关事件 ===
case 'source':
if (chunk.sourceType === 'url') {
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
const { sourceType: _, ...rest } = chunk
final.webSearchResults.push(rest)
}
break
case 'file':
// 文件相关事件,可能是图片生成
this.onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
}
})
break
case 'abort':
this.onChunk({
type: ChunkType.ERROR,
error: new DOMException('Request was aborted', 'AbortError')
})
break
case 'error':
this.onChunk({
type: ChunkType.ERROR,
error: AISDKError.isInstance(chunk.error)
? chunk.error
: new ProviderSpecificError({
message: formatErrorMessage(chunk.error),
provider: 'unknown',
cause: chunk.error
})
})
break
default:
}
}
private buildMetrics(totalUsage?: {
inputTokens?: number | null
outputTokens?: number | null
totalTokens?: number | null
}) {
if (!totalUsage) {
return undefined
}
const completionTokens = totalUsage.outputTokens ?? 0
const now = Date.now()
const start = this.responseStartTimestamp ?? now
const firstToken = this.firstTokenTimestamp
const timeFirstToken = Math.max(firstToken != null ? firstToken - start : 0, 0)
const baseForCompletion = firstToken ?? start
let timeCompletion = Math.max(now - baseForCompletion, 0)
if (timeCompletion === 0 && completionTokens > 0) {
timeCompletion = 1
}
return {
completion_tokens: completionTokens,
time_first_token_millsec: timeFirstToken,
time_completion_millsec: timeCompletion
}
}
}
export default AiSdkToChunkAdapter

View File

@@ -1,435 +0,0 @@
/**
* 工具调用 Chunk 处理模块
* TODO: Tool包含了providerTool和普通的Tool还有MCPTool,后面需要重构
* 提供工具调用相关的处理API每个交互使用一个新的实例
*/
import { loggerService } from '@logger'
import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
import type { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
const logger = loggerService.withContext('ToolCallChunkHandler')
export type ToolcallsMap = {
toolCallId: string
toolName: string
args: any
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
tool: BaseTool
// Streaming arguments buffer
streamingArgs?: string
}
/**
* 工具调用处理器类
*/
export class ToolCallChunkHandler {
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[]
) {}
/**
* 内部静态方法:添加活跃工具调用的核心逻辑
*/
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
return true
}
return false
}
/**
* 实例方法:添加活跃工具调用
*/
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
/**
* 获取全局活跃的工具调用
*/
public static getActiveToolCalls() {
return ToolCallChunkHandler.globalActiveToolCalls
}
/**
* 静态方法:添加活跃工具调用(外部访问)
*/
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
/**
* 根据工具名称确定工具类型
*/
private determineToolType(toolName: string, toolCallId: string): BaseTool {
let mcpTool: MCPTool | undefined
if (toolName.startsWith('builtin_')) {
return {
id: toolCallId,
name: toolName,
description: toolName,
type: 'builtin'
} as BaseTool
} else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
return mcpTool
} else {
return {
id: toolCallId,
name: toolName,
description: toolName,
type: 'provider'
}
}
}
/**
* 处理工具输入开始事件 - 流式参数开始
*/
public handleToolInputStart(chunk: {
type: 'tool-input-start'
id: string
toolName: string
providerMetadata?: ProviderMetadata
providerExecuted?: boolean
}): void {
const { id: toolCallId, toolName, providerExecuted } = chunk
if (!toolCallId || !toolName) {
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool-input-start chunk: missing id or toolName`)
return
}
// 如果已存在,跳过
if (this.activeToolCalls.has(toolCallId)) {
return
}
let tool: BaseTool
if (providerExecuted) {
tool = {
id: toolCallId,
name: toolName,
description: toolName,
type: 'provider'
} as BaseTool
} else {
tool = this.determineToolType(toolName, toolCallId)
}
// 初始化流式工具调用
this.addActiveToolCall(toolCallId, {
toolCallId,
toolName,
args: undefined,
tool,
streamingArgs: ''
})
logger.info(`🔧 [ToolCallChunkHandler] Tool input streaming started: ${toolName} (${toolCallId})`)
// 发送初始 streaming chunk
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: tool,
arguments: undefined,
status: 'streaming',
toolCallId: toolCallId,
partialArguments: ''
}
this.onChunk({
type: ChunkType.MCP_TOOL_STREAMING,
responses: [toolResponse]
})
}
/**
* 处理工具输入增量事件 - 流式参数片段
*/
public handleToolInputDelta(chunk: {
type: 'tool-input-delta'
id: string
delta: string
providerMetadata?: ProviderMetadata
}): void {
const { id: toolCallId, delta } = chunk
const toolCall = this.activeToolCalls.get(toolCallId)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found for delta: ${toolCallId}`)
return
}
// 累积流式参数
toolCall.streamingArgs = (toolCall.streamingArgs || '') + delta
// 发送 streaming chunk 更新
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: toolCall.tool,
arguments: undefined,
status: 'streaming',
toolCallId: toolCallId,
partialArguments: toolCall.streamingArgs
}
this.onChunk({
type: ChunkType.MCP_TOOL_STREAMING,
responses: [toolResponse]
})
}
/**
* 处理工具输入结束事件 - 流式参数完成
*/
public handleToolInputEnd(chunk: { type: 'tool-input-end'; id: string; providerMetadata?: ProviderMetadata }): void {
const { id: toolCallId } = chunk
const toolCall = this.activeToolCalls.get(toolCallId)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found for end: ${toolCallId}`)
return
}
// 尝试解析完整的 JSON 参数
let parsedArgs: any = undefined
if (toolCall.streamingArgs) {
try {
parsedArgs = JSON.parse(toolCall.streamingArgs)
toolCall.args = parsedArgs
} catch (e) {
logger.warn(`🔧 [ToolCallChunkHandler] Failed to parse streaming args for ${toolCallId}:`, e as Error)
// 保留原始字符串
toolCall.args = toolCall.streamingArgs
}
}
logger.info(`🔧 [ToolCallChunkHandler] Tool input streaming completed: ${toolCall.toolName} (${toolCallId})`)
// 发送 streaming 完成 chunk
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: toolCall.tool,
arguments: parsedArgs,
status: 'pending',
toolCallId: toolCallId,
partialArguments: toolCall.streamingArgs
}
this.onChunk({
type: ChunkType.MCP_TOOL_STREAMING,
responses: [toolResponse]
})
}
/**
* 处理工具调用事件
*/
public handleToolCall(
chunk: {
type: 'tool-call'
} & TypedToolCall<ToolSet>
): void {
const { toolCallId, toolName, input: args, providerExecuted } = chunk
if (!toolCallId || !toolName) {
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
return
}
// Check if this tool call was already processed via streaming events
const existingToolCall = this.activeToolCalls.get(toolCallId)
if (existingToolCall?.streamingArgs !== undefined) {
// Tool call was already processed via streaming events (tool-input-start/delta/end)
// Update args if needed, but don't emit duplicate pending chunk
existingToolCall.args = args
return
}
let tool: BaseTool
let mcpTool: MCPTool | undefined
// 根据 providerExecuted 标志区分处理逻辑
if (providerExecuted) {
// 如果是 Provider 执行的工具(如 web_search
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
tool = {
id: toolCallId,
name: toolName,
description: toolName,
type: 'provider'
} as BaseTool
} else if (toolName.startsWith('builtin_')) {
// 如果是内置工具,沿用现有逻辑
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
tool = {
id: toolCallId,
name: toolName,
description: toolName,
type: 'builtin'
} as BaseTool
} else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
// toolName is mcpTool.id (registered with id as key in convertMcpToolsToAiSdkTools)
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
tool = mcpTool
} else {
tool = {
id: toolCallId,
name: toolName,
description: toolName,
type: 'provider'
}
}
this.addActiveToolCall(toolCallId, {
toolCallId,
toolName,
args,
tool
})
// 创建 MCPToolResponse 格式
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: tool,
arguments: args,
status: 'pending', // 统一使用 pending 状态
toolCallId: toolCallId
}
// 调用 onChunk
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING, // 统一发送 pending 状态
responses: [toolResponse]
})
}
}
/**
* 处理工具调用结果事件
*/
public handleToolResult(
chunk: {
type: 'tool-result'
} & TypedToolResult<ToolSet>
): void {
// TODO: 基于AI SDK为供应商内置工具做更好的展示和类型安全处理
const { toolCallId, output, input } = chunk
if (!toolCallId) {
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
return
}
// 查找对应的工具调用信息
const toolCallInfo = this.activeToolCalls.get(toolCallId)
if (!toolCallInfo) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
return
}
// 创建工具调用结果的 MCPToolResponse 格式
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallInfo.toolCallId,
tool: toolCallInfo.tool,
arguments: input,
status: 'done',
response: output,
toolCallId: toolCallId
}
// 工具特定的后处理
switch (toolResponse.tool.name) {
case 'builtin_knowledge_search': {
processKnowledgeReferences(toolResponse.response, this.onChunk)
break
}
// 未来可以在这里添加其他工具的后处理逻辑
default:
break
}
// 从活跃调用中移除(交互结束后整个实例会被丢弃)
this.activeToolCalls.delete(toolCallId)
// 调用 onChunk
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [toolResponse]
})
const images = extractImagesFromToolOutput(toolResponse.response)
if (images.length) {
this.onChunk({
type: ChunkType.IMAGE_CREATED
})
this.onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: images
}
})
}
}
}
handleToolError(
chunk: {
type: 'tool-error'
} & TypedToolError<ToolSet>
): void {
const { toolCallId, error, input } = chunk
const toolCallInfo = this.activeToolCalls.get(toolCallId)
if (!toolCallInfo) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
return
}
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: toolCallInfo.tool,
arguments: input,
status: 'error',
response: error,
toolCallId: toolCallId
}
this.activeToolCalls.delete(toolCallId)
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [toolResponse]
})
}
}
}
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
/**
* 从工具输出中提取图片(使用 MCP SDK 类型安全验证)
*/
function extractImagesFromToolOutput(output: unknown): string[] {
if (!output) {
return []
}
const result = CallToolResultSchema.safeParse(output)
if (result.success) {
return result.data.content
.filter((c) => c.type === 'image')
.map((content) => `data:${content.mimeType ?? 'image/png'};base64,${content.data}`)
}
return []
}

View File

@@ -1 +0,0 @@
export { default as AiProvider, type AiProviderConfig } from './AiProvider'

View File

@@ -1,140 +0,0 @@
import type { AiPlugin } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, providerToolPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { preferenceService } from '@data/PreferenceService'
import { loggerService } from '@logger'
import { isGemini3Model, isQwen35to39Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
import type { Assistant, Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { AiSdkMiddlewareConfig } from '../types/middlewareConfig'
import { getReasoningTagName } from '../utils/reasoning'
import { createAnthropicCachePlugin } from './anthropicCachePlugin'
import { createNoThinkPlugin } from './noThinkPlugin'
import { createOpenrouterReasoningPlugin } from './openrouterReasoningPlugin'
import { createPdfCompatibilityPlugin } from './pdfCompatibilityPlugin'
import { createQwenThinkingPlugin } from './qwenThinkingPlugin'
import { createReasoningExtractionPlugin } from './reasoningExtractionPlugin'
import { searchOrchestrationPlugin } from './searchOrchestrationPlugin'
import { createSimulateStreamingPlugin } from './simulateStreamingPlugin'
import { createSkipGeminiThoughtSignaturePlugin } from './skipGeminiThoughtSignaturePlugin'
import { createTelemetryPlugin } from './telemetryPlugin'
const logger = loggerService.withContext('PluginBuilder')
/**
* 构建插件的上下文参数
*
* provider 和 model 是必选的 — 由 AiProvider 内部注入,
* 不再依赖调用方手动传入,从根本上避免遗漏。
*/
export interface BuildPluginsContext {
provider: Provider
model: Model
config: AiSdkMiddlewareConfig & { assistant: Assistant; topicId?: string }
}
/**
* 根据条件构建插件数组
*/
export async function buildPlugins({ provider, model, config }: BuildPluginsContext): Promise<AiPlugin[]> {
const plugins: AiPlugin<any, any>[] = []
if (config.topicId && (await preferenceService.get('app.developer_mode.enabled'))) {
// 0. 添加 telemetry 插件
plugins.push(
createTelemetryPlugin({
enabled: true,
topicId: config.topicId,
assistant: config.assistant
})
)
}
// === PDF Compatibility ===
// Must run before other plugins (e.g., Anthropic cache token estimation)
// so that PDF FileParts are converted to TextParts for unsupported providers.
plugins.push(createPdfCompatibilityPlugin(provider, model))
// === AI SDK Middleware Plugins ===
// 注意wrapLanguageModel 会 .reverse() middleware 数组,
// 数组中靠前的 middleware 反转后变成最外层包装。
// extractReasoning 必须在 simulateStreaming 之前推入,
// 这样反转后 extractReasoning 在外层,其 wrapStream状态机
// 能处理 simulateStreaming 生成的模拟流中的未闭合 <think> 标签。
// 0.1 Reasoning extraction for OpenAI/Azure providers
const providerType = provider.type
if (providerType === 'openai' || providerType === 'azure-openai') {
const tagName = getReasoningTagName(model.id.toLowerCase())
plugins.push(createReasoningExtractionPlugin({ tagName }))
}
// 0.2 Simulate streaming for non-streaming requests (must be AFTER reasoning extraction in array)
if (!config.streamOutput) {
plugins.push(createSimulateStreamingPlugin())
}
if (provider.anthropicCacheControl?.tokenThreshold) {
plugins.push(createAnthropicCachePlugin(provider))
}
// 0.3 OpenRouter reasoning redaction
if (provider.id === SystemProviderIds.openrouter) {
plugins.push(createOpenrouterReasoningPlugin())
}
// 0.4 OVMS no-think for MCP tools
if (provider.id === 'ovms' && config.mcpTools && config.mcpTools.length > 0) {
plugins.push(createNoThinkPlugin())
}
// 0.5 Qwen thinking control for providers without enable_thinking support
if (
!isOllamaProvider(provider) &&
isSupportedThinkingTokenQwenModel(model) &&
!isQwen35to39Model(model) &&
!isSupportEnableThinkingProvider(provider)
) {
const enableThinking = config.assistant?.settings?.reasoning_effort !== undefined
plugins.push(createQwenThinkingPlugin(enableThinking))
}
// 0.6 Skip Gemini3 thought signature for OpenAI-compatible API
if (isGemini3Model(model)) {
plugins.push(createSkipGeminiThoughtSignaturePlugin())
}
// 1. Provider 工具注入 — providerToolPlugin 自动按 provider 分发工具
if (config.enableWebSearch && config.webSearchPluginConfig) {
plugins.push(providerToolPlugin('webSearch', config.webSearchPluginConfig))
}
if (config.enableUrlContext) {
plugins.push(providerToolPlugin('urlContext', config.urlContextConfig))
}
// 2. 支持工具调用时添加搜索插件
if (config.isSupportedToolUse || config.isPromptToolUse) {
plugins.push(searchOrchestrationPlugin(config.assistant, config.topicId || ''))
}
// 3. 推理模型时添加推理插件
// if (config.enableReasoning) {
// plugins.push(reasoningTimePlugin)
// }
// 4. 启用Prompt工具调用时添加工具插件
if (config.isPromptToolUse) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
mcpMode: config.mcpMode
})
)
}
logger.debug(
'Final plugin list:',
plugins.map((p) => p.name)
)
return plugins
}

View File

@@ -1,243 +0,0 @@
import type { LanguageModelV3CallOptions } from '@ai-sdk/provider'
import type { Model, Provider, ProviderType } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('i18next', () => ({
default: { t: (key: string, opts?: Record<string, unknown>) => `${key}${opts ? JSON.stringify(opts) : ''}` }
}))
vi.mock('@renderer/config/models', () => ({
isAnthropicModel: vi.fn(() => false),
isGeminiModel: vi.fn(() => false)
}))
vi.mock('@renderer/config/models/openai', () => ({
isOpenAILLMModel: vi.fn(() => false)
}))
const mockExtractPdfText = vi.fn()
vi.mock('@shared/utils/pdf', () => ({
extractPdfText: (...args: unknown[]) => mockExtractPdfText(...args)
}))
vi.stubGlobal('window', {
...globalThis.window,
api: {
pdf: {
extractText: mockExtractPdfText
}
},
toast: {
warning: vi.fn(),
error: vi.fn()
}
})
import { isAnthropicModel, isGeminiModel } from '@renderer/config/models'
import { isOpenAILLMModel } from '@renderer/config/models/openai'
import { createPdfCompatibilityPlugin } from '../pdfCompatibilityPlugin'
function makeProvider(id: string, type: ProviderType): Provider {
return { id, name: id, type, apiKey: 'test', apiHost: 'https://test.com', isSystem: false, models: [] } as Provider
}
function makeModel(): Model {
return { id: 'test-model', provider: 'test', name: 'Test', group: 'test' } as Model
}
function makePdfFilePart(filename = 'test.pdf') {
return {
type: 'file' as const,
data: 'base64pdfdata',
mediaType: 'application/pdf',
filename
}
}
function makeImageFilePart() {
return {
type: 'file' as const,
data: 'base64imgdata',
mediaType: 'image/png',
filename: 'test.png'
}
}
function makeTextPart(text: string) {
return { type: 'text' as const, text }
}
async function runMiddleware(provider: Provider, params: LanguageModelV3CallOptions, model: Model = makeModel()) {
const plugin = createPdfCompatibilityPlugin(provider, model)
const context: {
middlewares: Array<{ transformParams: (opts: Record<string, unknown>) => Promise<LanguageModelV3CallOptions> }>
} = { middlewares: [] }
void plugin.configureContext!(context as never)
const middleware = context.middlewares[0]
return middleware.transformParams({ params, type: 'generate', model: {} })
}
describe('pdfCompatibilityPlugin', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(isOpenAILLMModel).mockReturnValue(false)
vi.mocked(isAnthropicModel).mockReturnValue(false)
vi.mocked(isGeminiModel).mockReturnValue(false)
})
it('should pass through for OpenAI model on any provider type', async () => {
vi.mocked(isOpenAILLMModel).mockReturnValue(true)
const provider = makeProvider('moonshot', 'openai')
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result).toEqual(params)
expect(mockExtractPdfText).not.toHaveBeenCalled()
})
it('should pass through for Claude model on any provider type', async () => {
vi.mocked(isAnthropicModel).mockReturnValue(true)
const provider = makeProvider('my-aggregator', 'new-api')
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result).toEqual(params)
expect(mockExtractPdfText).not.toHaveBeenCalled()
})
it('should pass through for Gemini model on any provider type', async () => {
vi.mocked(isGeminiModel).mockReturnValue(true)
const provider = makeProvider('my-aggregator', 'new-api')
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result).toEqual(params)
expect(mockExtractPdfText).not.toHaveBeenCalled()
})
it('should pass through unchanged when provider type supports native PDF (openai-response)', async () => {
const provider = makeProvider('openai', 'openai-response')
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result).toEqual(params)
expect(mockExtractPdfText).not.toHaveBeenCalled()
})
it('should convert PDF for non-native provider types (new-api, gateway, openai)', async () => {
const provider = makeProvider('moonshot', 'openai')
mockExtractPdfText.mockResolvedValue('Extracted PDF content')
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart('report.pdf')] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(mockExtractPdfText).toHaveBeenCalledWith('base64pdfdata')
expect(result.prompt[0]).toMatchObject({
role: 'user',
content: [
{ type: 'text', text: 'Hello' },
{ type: 'text', text: 'report.pdf\nExtracted PDF content' }
]
})
})
it('should convert PDF FilePart to TextPart for ollama provider', async () => {
const provider = makeProvider('ollama', 'ollama')
mockExtractPdfText.mockResolvedValue('Extracted PDF content')
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart('report.pdf')] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(mockExtractPdfText).toHaveBeenCalledWith('base64pdfdata')
expect(result.prompt[0]).toMatchObject({
role: 'user',
content: [
{ type: 'text', text: 'Hello' },
{ type: 'text', text: 'report.pdf\nExtracted PDF content' }
]
})
})
it('should drop PDF part and warn when text extraction fails', async () => {
const provider = makeProvider('ollama', 'ollama')
mockExtractPdfText.mockRejectedValue(new Error('parse failed'))
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart('broken.pdf')] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result.prompt[0]).toMatchObject({
role: 'user',
content: [{ type: 'text', text: 'Hello' }]
})
expect(window.toast.warning).toHaveBeenCalled()
})
it('should not convert non-PDF FileParts', async () => {
const provider = makeProvider('ollama', 'ollama')
const imagePart = makeImageFilePart()
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), imagePart] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result.prompt[0]).toMatchObject({
role: 'user',
content: [{ type: 'text', text: 'Hello' }, imagePart]
})
expect(mockExtractPdfText).not.toHaveBeenCalled()
})
it('should handle mixed content: text + PDF + image — only PDF converted', async () => {
const provider = makeProvider('ollama', 'ollama')
mockExtractPdfText.mockResolvedValue('PDF text content')
const imagePart = makeImageFilePart()
const params = {
prompt: [{ role: 'user' as const, content: [makeTextPart('Analyze'), makePdfFilePart('doc.pdf'), imagePart] }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result.prompt[0]).toMatchObject({
role: 'user',
content: [{ type: 'text', text: 'Analyze' }, { type: 'text', text: 'doc.pdf\nPDF text content' }, imagePart]
})
})
it('should pass through when prompt is empty', async () => {
const provider = makeProvider('ollama', 'ollama')
const params = { prompt: [] } as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result).toEqual(params)
})
it('should pass through messages with string content (system messages)', async () => {
const provider = makeProvider('ollama', 'ollama')
const params = {
prompt: [{ role: 'system' as const, content: 'You are a helpful assistant' }]
} as unknown as LanguageModelV3CallOptions
const result = await runMiddleware(provider, params)
expect(result.prompt[0]).toMatchObject({ role: 'system', content: 'You are a helpful assistant' })
})
})

View File

@@ -1,96 +0,0 @@
/**
* Anthropic Prompt Caching Middleware
* @see https://ai-sdk.dev/providers/ai-sdk-providers/anthropic#cache-control
*/
import type { LanguageModelV3Message } from '@ai-sdk/provider'
import { definePlugin } from '@cherrystudio/ai-core/core/plugins'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type { Provider } from '@renderer/types'
import type { LanguageModelMiddleware } from 'ai'
const cacheProviderOptions = {
anthropic: { cacheControl: { type: 'ephemeral' } }
}
function estimateContentTokens(content: LanguageModelV3Message['content']): number {
if (typeof content === 'string') return estimateTextTokens(content)
if (Array.isArray(content)) {
return content.reduce((acc, part) => {
if (part.type === 'text') {
return acc + estimateTextTokens(part.text)
}
return acc
}, 0)
}
return 0
}
function anthropicCacheMiddleware(provider: Provider): LanguageModelMiddleware {
return {
specificationVersion: 'v3',
transformParams: async ({ params }) => {
const settings = provider.anthropicCacheControl
if (!settings?.tokenThreshold || !Array.isArray(params.prompt) || params.prompt.length === 0) {
return params
}
const { tokenThreshold, cacheSystemMessage, cacheLastNMessages } = settings
const messages = [...params.prompt]
let cachedCount = 0
// Cache system message (providerOptions on message object)
if (cacheSystemMessage) {
for (let i = 0; i < messages.length; i++) {
const msg = messages[i]
if (msg.role === 'system' && estimateContentTokens(msg.content) >= tokenThreshold) {
messages[i] = { ...msg, providerOptions: cacheProviderOptions }
break
}
}
}
// Cache last N non-system messages (providerOptions on content parts)
if (cacheLastNMessages > 0) {
const cumsumTokens = [] as Array<number>
let tokenSum = 0 as number
for (let i = 0; i < messages.length; i++) {
const msg = messages[i]
tokenSum += estimateContentTokens(msg.content)
cumsumTokens.push(tokenSum)
}
for (let i = messages.length - 1; i >= 0 && cachedCount < cacheLastNMessages; i--) {
const msg = messages[i]
if (msg.role === 'system' || cumsumTokens[i] < tokenThreshold || msg.content.length === 0) {
continue
}
const newContent = [...msg.content]
const lastIndex = newContent.length - 1
newContent[lastIndex] = {
...newContent[lastIndex],
providerOptions: cacheProviderOptions
}
messages[i] = {
...msg,
content: newContent
} as LanguageModelV3Message
cachedCount++
}
}
return { ...params, prompt: messages }
}
}
}
export const createAnthropicCachePlugin = (provider: Provider) =>
definePlugin({
name: 'anthropicCache',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(anthropicCacheMiddleware(provider))
}
})

View File

@@ -1,64 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import type { LanguageModelMiddleware } from 'ai'
const logger = loggerService.withContext('noThinkPlugin')
/**
* No Think Middleware
* Automatically appends ' /no_think' string to the end of user messages for the provider
* This prevents the model from generating unnecessary thinking process and returns results directly
* @returns LanguageModelMiddleware
*/
function createNoThinkMiddleware(): LanguageModelMiddleware {
return {
specificationVersion: 'v3',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
// Only process user messages
if (message.role === 'user') {
// Process content array
if (Array.isArray(message.content)) {
const lastContent = message.content[message.content.length - 1]
// If the last content is text type, append ' /no_think'
if (lastContent && lastContent.type === 'text' && typeof lastContent.text === 'string') {
// Avoid duplicate additions
if (!lastContent.text.endsWith('/no_think')) {
logger.debug('Adding /no_think to user message')
return {
...message,
content: [
...message.content.slice(0, -1),
{
...lastContent,
text: lastContent.text + ' /no_think'
}
]
}
}
}
}
}
return message
})
}
return transformedParams
}
}
}
export const createNoThinkPlugin = () =>
definePlugin({
name: 'noThink',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(createNoThinkMiddleware())
}
})

View File

@@ -1,62 +0,0 @@
import type { LanguageModelV3StreamPart } from '@ai-sdk/provider'
import { definePlugin } from '@cherrystudio/ai-core'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
function createOpenrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
specificationVersion: 'v3',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV3StreamPart, LanguageModelV3StreamPart>({
transform(
chunk: LanguageModelV3StreamPart,
controller: TransformStreamDefaultController<LanguageModelV3StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
export const createOpenrouterReasoningPlugin = () =>
definePlugin({
name: 'openrouterReasoning',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(createOpenrouterReasoningMiddleware())
}
})

View File

@@ -1,114 +0,0 @@
/**
* PDF Compatibility Plugin
*
* Converts PDF FileParts to TextParts for providers that don't support native PDF input.
* Extracts text directly from the FilePart's base64 data using pdf-parse.
*/
import type { LanguageModelV3FilePart, LanguageModelV3Message } from '@ai-sdk/provider'
import { definePlugin } from '@cherrystudio/ai-core/core/plugins'
import { loggerService } from '@logger'
import { isAnthropicModel, isGeminiModel } from '@renderer/config/models'
import { isOpenAILLMModel } from '@renderer/config/models/openai'
import type { Model, Provider, ProviderType } from '@renderer/types'
import { extractPdfText } from '@shared/utils/pdf'
import type { LanguageModelMiddleware } from 'ai'
import i18n from 'i18next'
const logger = loggerService.withContext('pdfCompatibilityPlugin')
type ContentPart = Exclude<LanguageModelV3Message['content'], string>[number]
/**
* Provider types whose API natively supports PDF file input.
* Only first-party provider protocols (OpenAI, Anthropic, Google) are included.
* Aggregators (new-api, gateway) and generic 'openai' type are excluded
* because they may route to backends that don't support the 'file' part type.
*/
const PDF_NATIVE_PROVIDER_TYPES = new Set<ProviderType>([
'openai-response', // OpenAI Responses API
'anthropic', // Anthropic API
'gemini', // Google Gemini API
'azure-openai', // Azure OpenAI
'vertexai', // Google Vertex AI
'aws-bedrock', // AWS Bedrock
'vertex-anthropic' // Vertex AI with Anthropic models
])
function isPdfFilePart(part: ContentPart): part is LanguageModelV3FilePart & { mediaType: 'application/pdf' } {
return part.type === 'file' && part.mediaType === 'application/pdf'
}
function supportsNativePdf(provider: Provider, model: Model): boolean {
// OpenAI, Claude, and Gemini models always support native PDF regardless of provider
if (isOpenAILLMModel(model) || isAnthropicModel(model) || isGeminiModel(model)) {
return true
}
if (PDF_NATIVE_PROVIDER_TYPES.has(provider.type)) {
return true
}
// TODO: allow user to configure native pdf compatibility for provider/model
return false
}
function pdfCompatibilityMiddleware(provider: Provider, model: Model): LanguageModelMiddleware {
return {
specificationVersion: 'v3',
transformParams: async ({ params }) => {
if (supportsNativePdf(provider, model)) {
return params
}
if (!Array.isArray(params.prompt) || params.prompt.length === 0) {
return params
}
const messages: LanguageModelV3Message[] = []
for (const message of params.prompt) {
if (!Array.isArray(message.content)) {
messages.push(message)
continue
}
const hasPdf = message.content.some((part: (typeof message.content)[number]) => isPdfFilePart(part))
if (!hasPdf) {
messages.push(message)
continue
}
const newContent: ContentPart[] = []
for (const part of message.content) {
if (!isPdfFilePart(part)) {
newContent.push(part)
continue
}
const fileName = part.filename || 'PDF'
try {
const textContent =
part.data instanceof URL ? await extractPdfText(part.data) : await window.api.pdf.extractText(part.data)
logger.debug(`Converting PDF FilePart to TextPart for provider ${provider.id} (type: ${provider.type})`)
newContent.push({ type: 'text', text: `${fileName}\n${textContent.trim()}` })
} catch (error) {
logger.warn(`Failed to extract text from PDF ${fileName}:`, error instanceof Error ? error : undefined)
window.toast.warning(i18n.t('message.warning.file.pdf_text_extraction_failed', { name: fileName }))
}
}
messages.push(Object.assign({}, message, { content: newContent }))
}
return { ...params, prompt: messages }
}
}
}
export const createPdfCompatibilityPlugin = (provider: Provider, model: Model) =>
definePlugin({
name: 'pdfCompatibility',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(pdfCompatibilityMiddleware(provider, model))
}
})

View File

@@ -1,54 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
import type { LanguageModelMiddleware } from 'ai'
/**
* Qwen Thinking Middleware
* Controls thinking mode for Qwen models on providers that don't support enable_thinking parameter (like Ollama)
* Appends '/think' or '/no_think' suffix to user messages based on reasoning_effort setting
*
* NOTE: Qwen3.5 does not officially support the soft switch of Qwen3, i.e., /think and /nothink.
*
* @param enableThinking - Whether thinking mode is enabled (based on reasoning_effort !== undefined)
* @returns LanguageModelMiddleware
*/
function createQwenThinkingMiddleware(enableThinking: boolean): LanguageModelMiddleware {
const suffix = enableThinking ? ' /think' : ' /no_think'
return {
specificationVersion: 'v3',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
// Only process user messages
if (message.role === 'user') {
// Process content array
if (Array.isArray(message.content)) {
for (const part of message.content) {
if (part.type === 'text' && !part.text.endsWith('/think') && !part.text.endsWith('/no_think')) {
part.text += suffix
}
}
}
}
return message
})
}
return transformedParams
}
}
}
export const createQwenThinkingPlugin = (enableThinking: boolean) =>
definePlugin({
name: 'qwenThinking',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(createQwenThinkingMiddleware(enableThinking))
}
})

View File

@@ -1,22 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
import { extractReasoningMiddleware } from 'ai'
/**
* Reasoning Extraction Plugin
* Extracts reasoning/thinking tags from OpenAI/Azure responses
* Uses AI SDK's built-in extractReasoningMiddleware
*/
export const createReasoningExtractionPlugin = (options: { tagName?: string } = {}) =>
definePlugin({
name: 'reasoningExtraction',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(
extractReasoningMiddleware({
tagName: options.tagName || 'thinking'
})
)
}
})

View File

@@ -1,37 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
import type { TextStreamPart, ToolSet } from 'ai'
export default definePlugin({
name: 'reasoningTimePlugin',
transformStream: () => () => {
// === 时间跟踪状态 ===
let thinkingStartTime = 0
let accumulatedThinkingContent = ''
return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
// === 处理 reasoning 类型 ===
if (chunk.type === 'reasoning-start') {
controller.enqueue(chunk)
thinkingStartTime = performance.now()
} else if (chunk.type === 'reasoning-delta') {
accumulatedThinkingContent += chunk.text
controller.enqueue({
...chunk,
providerMetadata: {
...chunk.providerMetadata,
metadata: {
...chunk.providerMetadata?.metadata,
thinking_millsec: performance.now() - thinkingStartTime,
thinking_content: accumulatedThinkingContent
}
}
})
} else {
controller.enqueue(chunk)
}
}
})
}
})

View File

@@ -1,407 +0,0 @@
/**
* 搜索编排插件
*
* 功能:
* 1. onRequestStart: 智能意图识别 - 分析是否需要网络搜索、知识库搜索、记忆搜索
* 2. transformParams: 根据意图分析结果动态添加对应的工具
* 3. onRequestEnd: 自动记忆存储
*/
import {
type AiPlugin,
type AiRequestContext,
definePlugin,
type StreamTextParams,
type StreamTextResult
} from '@cherrystudio/ai-core'
import { preferenceService } from '@data/PreferenceService'
import { loggerService } from '@logger'
import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types'
import type { ExtractResults } from '@renderer/utils/extract'
import { extractInfoFromXML } from '@renderer/utils/extract'
// import { generateObject } from '@cherrystudio/ai-core'
import {
SEARCH_SUMMARY_PROMPT,
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@shared/config/prompts'
import type { LanguageModel, ModelMessage } from 'ai'
import { generateText } from 'ai'
import { isEmpty } from 'lodash'
import { MemoryProcessor } from '../../services/MemoryProcessor'
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
import { memorySearchTool } from '../tools/MemorySearchTool'
import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
const logger = loggerService.withContext('SearchOrchestrationPlugin')
export const getMessageContent = (message: ModelMessage) => {
if (typeof message.content === 'string') return message.content
return message.content.reduce((acc, part) => {
if (part.type === 'text') {
return acc + part.text + '\n'
}
return acc
}, '')
}
// === Schema Definitions ===
// const WebSearchSchema = z.object({
// question: z
// .array(z.string())
// .describe('Search queries for web search. Use "not_needed" if no web search is required.'),
// links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
// })
// const KnowledgeSearchSchema = z.object({
// question: z
// .array(z.string())
// .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
// rewrite: z
// .string()
// .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
// })
// const SearchIntentAnalysisSchema = z.object({
// websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
// knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
// })
// type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
// let isAnalyzing = false
/**
* 🧠 意图分析函数 - 使用 XML 解析
*/
async function analyzeSearchIntent(
lastUserMessage: ModelMessage,
assistant: Assistant,
options: {
shouldWebSearch?: boolean
shouldKnowledgeSearch?: boolean
shouldMemorySearch?: boolean
lastAnswer?: ModelMessage
context: AiRequestContext
topicId: string
}
): Promise<ExtractResults | undefined> {
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
if (!lastUserMessage) return undefined
// 根据配置决定是否需要提取
const needWebExtract = shouldWebSearch
const needKnowledgeExtract = shouldKnowledgeSearch
if (!needWebExtract && !needKnowledgeExtract) return undefined
// 选择合适的提示词
let prompt: string
// let schema: z.Schema
if (needWebExtract && !needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
// schema = z.object({ websearch: WebSearchSchema })
} else if (!needWebExtract && needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
// schema = z.object({ knowledge: KnowledgeSearchSchema })
} else {
prompt = SEARCH_SUMMARY_PROMPT
// schema = SearchIntentAnalysisSchema
}
// 构建消息上下文 - 简化逻辑
const chatHistory = lastAnswer ? `assistant: ${getMessageContent(lastAnswer)}` : ''
const question = getMessageContent(lastUserMessage) || ''
// 使用模板替换变量
const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question)
// 获取模型和provider信息
const model = assistant.model || getDefaultModel()
const provider = getProviderByModel(model)
if (!provider || isEmpty(provider.apiKey)) {
logger.error('Provider not found or missing API key')
return getFallbackResult()
}
try {
logger.info('Starting intent analysis generateText call', {
modelId: model.id,
topicId: options.topicId,
requestId: context.requestId,
hasWebSearch: needWebExtract,
hasKnowledgeSearch: needKnowledgeExtract
})
const { text: result } = await generateText({
model: context.model as LanguageModel,
prompt: formattedPrompt
}).finally(() => {
logger.info('Intent analysis generateText call completed', {
modelId: model.id,
topicId: options.topicId,
requestId: context.requestId
})
})
const parsedResult = extractInfoFromXML(result)
logger.debug('Intent analysis result', { parsedResult })
// 根据需求过滤结果
return {
websearch: needWebExtract ? parsedResult?.websearch : undefined,
knowledge: needKnowledgeExtract ? parsedResult?.knowledge : undefined
}
} catch (e: any) {
logger.error('Intent analysis failed', e as Error)
return getFallbackResult()
}
function getFallbackResult(): ExtractResults {
const fallbackContent = getMessageContent(lastUserMessage)
return {
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
knowledge: shouldKnowledgeSearch
? {
question: [fallbackContent || 'search'],
rewrite: fallbackContent || 'search'
}
: undefined
}
}
}
/**
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
*/
async function storeConversationMemory(
messages: ModelMessage[],
assistant: Assistant,
context: AiRequestContext
): Promise<void> {
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
if (!globalMemoryEnabled || !assistant.enableMemory) {
return
}
try {
const memoryConfig = selectMemoryConfig(store.getState())
// 转换消息为记忆处理器期望的格式
const conversationMessages = messages
.filter((msg) => msg.role === 'user' || msg.role === 'assistant')
.map((msg) => ({
role: msg.role,
content: getMessageContent(msg) || ''
}))
.filter((msg) => msg.content.trim().length > 0)
logger.debug('conversationMessages', conversationMessages)
if (conversationMessages.length < 2) {
logger.info('Need at least a user message and assistant response for memory processing')
return
}
const currentUserId = await preferenceService.get('feature.memory.current_user_id')
// const lastUserMessage = messages.findLast((m) => m.role === 'user')
const processorConfig = MemoryProcessor.getProcessorConfig(
memoryConfig,
assistant.id,
currentUserId,
context.requestId
)
logger.info('Processing conversation memory...', { messageCount: conversationMessages.length })
// 后台处理对话记忆(不阻塞 UI
const memoryProcessor = new MemoryProcessor()
memoryProcessor
.processConversation(conversationMessages, processorConfig)
.then((result) => {
logger.info('Memory processing completed:', result)
if (result.facts?.length > 0) {
logger.info('Extracted facts from conversation:', result.facts)
logger.info('Memory operations performed:', result.operations)
} else {
logger.info('No facts extracted from conversation')
}
})
.catch((error) => {
logger.error('Background memory processing failed:', error as Error)
})
} catch (error) {
logger.error('Error in conversation memory processing:', error as Error)
// 不抛出错误,避免影响主流程
}
}
/**
* 🎯 搜索编排插件
*/
export const searchOrchestrationPlugin = (
assistant: Assistant,
topicId: string
): AiPlugin<StreamTextParams, StreamTextResult> => {
// 存储意图分析结果
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
const userMessages: { [requestId: string]: ModelMessage } = {}
return definePlugin<StreamTextParams, StreamTextResult>({
name: 'search-orchestration',
enforce: 'pre', // 确保在其他插件之前执行
/**
* 🔍 Step 1: 意图识别阶段
*/
onRequestStart: async (context) => {
// 没开启任何搜索则不进行意图分析
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
try {
const messages = context.originalParams.messages
if (!messages || messages.length === 0) {
return
}
const lastUserMessage = messages[messages.length - 1]
const lastAssistantMessage = messages.length >= 2 ? messages[messages.length - 2] : undefined
// 存储用户消息用于后续记忆存储
userMessages[context.requestId] = lastUserMessage
// 判断是否需要各种搜索
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
const shouldWebSearch = !!assistant.webSearchProviderId
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
// 执行意图分析
if (shouldWebSearch || shouldKnowledgeSearch) {
const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
shouldWebSearch,
shouldKnowledgeSearch,
shouldMemorySearch,
lastAnswer: lastAssistantMessage,
context,
topicId
})
if (analysisResult) {
intentAnalysisResults[context.requestId] = analysisResult
// logger.info('🧠 Intent analysis completed:', analysisResult)
}
}
} catch (error) {
logger.error('🧠 Intent analysis failed:', error as Error)
// 不抛出错误,让流程继续
}
},
/**
* 🔧 Step 2: 工具配置阶段
*/
transformParams: async (params, context) => {
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
try {
const analysisResult = intentAnalysisResults[context.requestId]
// if (!analysisResult || !assistant) {
// logger.info('🔧 No analysis result or assistant, skipping tool configuration')
// return params
// }
// 确保 tools 对象存在
if (!params.tools) {
params.tools = {}
}
// 🌐 网络搜索工具配置
if (analysisResult?.websearch && assistant.webSearchProviderId) {
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
if (needsSearch) {
// onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
// logger.info('🌐 Adding web search tool with pre-extracted keywords')
params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords(
assistant.webSearchProviderId,
analysisResult.websearch,
context.requestId
)
}
}
// 📚 知识库搜索工具配置
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
if (shouldKnowledgeSearch) {
// on 模式:根据意图识别结果决定是否添加工具
const needsKnowledgeSearch =
analysisResult?.knowledge &&
analysisResult.knowledge.question &&
analysisResult.knowledge.question[0] !== 'not_needed'
if (needsKnowledgeSearch && analysisResult.knowledge) {
// logger.info('📚 Adding knowledge search tool (intent-based)')
const userMessage = userMessages[context.requestId]
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
assistant,
analysisResult.knowledge,
topicId,
getMessageContent(userMessage)
)
}
}
// 🧠 记忆搜索工具配置
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
if (globalMemoryEnabled && assistant.enableMemory) {
// logger.info('🧠 Adding memory search tool')
params.tools['builtin_memory_search'] = memorySearchTool(assistant.id)
}
// logger.info('🔧 Tools configured:', Object.keys(params.tools))
return params
} catch (error) {
logger.error('🔧 Tool configuration failed:', error as Error)
return params
}
},
/**
* 💾 Step 3: 记忆存储阶段
*/
onRequestEnd: async (context) => {
// context.isAnalyzing = false
// logger.info('context.isAnalyzing', context, result)
// logger.info('💾 Starting memory storage...', context.requestId)
try {
// ✅ 类型安全访问context.originalParams 已通过泛型正确类型化
const messages = context.originalParams.messages
if (messages && assistant) {
await storeConversationMemory(messages, assistant, context)
}
// 清理缓存
delete intentAnalysisResults[context.requestId]
delete userMessages[context.requestId]
} catch (error) {
logger.error('💾 Memory storage failed:', error as Error)
// 不抛出错误,避免影响主流程
}
}
})
}
export default searchOrchestrationPlugin

View File

@@ -1,18 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
import { simulateStreamingMiddleware } from 'ai'
/**
* Simulate Streaming Plugin
* Converts non-streaming responses to streaming format
* Uses AI SDK's built-in simulateStreamingMiddleware
*/
export const createSimulateStreamingPlugin = () =>
definePlugin({
name: 'simulateStreaming',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(simulateStreamingMiddleware())
}
})

View File

@@ -1,73 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import type { LanguageModelMiddleware } from 'ai'
const logger = loggerService.withContext('skipGeminiThoughtSignaturePlugin')
/**
* skip Gemini Thought Signature Middleware
*
* Handles:
* - Tool-call parts need thought_signature for OpenAI-compatible API
* -> Add providerOptions.openaiCompatible.extra_content.google.thought_signature
*
* Note: Thought signature for text/reasoning parts is now handled in messageConverter.
*
* @returns LanguageModelMiddleware
*/
function createSkipGeminiThoughtSignatureMiddleware(): LanguageModelMiddleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
specificationVersion: 'v3',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
logger.debug('transformedParams', transformedParams)
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const isToolCallPart = part.type === 'tool-call'
// Note: text part and reasoning part do not require thought signature validation
// They are handled by messageConverter now
// Case: OpenAI-compatible path - add extra_content for tool-call parts
// All tool-calls need the signature for Gemini OpenAI-compatible API
if (isToolCallPart) {
if (!part.providerOptions) {
part.providerOptions = {}
}
if (!part.providerOptions.openaiCompatible) {
part.providerOptions.openaiCompatible = {}
}
// Google OpenAI-compatible API expects extra_content.google.thought_signature
// See: https://ai.google.dev/gemini-api/docs/thought-signatures#openai
part.providerOptions.openaiCompatible.extra_content = {
google: {
thought_signature: MAGIC_STRING
}
}
}
}
}
return message
})
}
return transformedParams
}
}
}
export const createSkipGeminiThoughtSignaturePlugin = () =>
definePlugin({
name: 'skipGeminiThoughtSignature',
enforce: 'pre',
configureContext: (context) => {
context.middlewares = context.middlewares || []
context.middlewares.push(createSkipGeminiThoughtSignatureMiddleware())
}
})

View File

@@ -1,442 +0,0 @@
/**
* Telemetry Plugin for AI SDK Integration
*
* 在 transformParams 钩子中注入 experimental_telemetry 参数,
* 实现 AI SDK trace 与现有手动 trace 系统的统一
* 集成 AiSdkSpanAdapter 将 AI SDK trace 数据转换为现有格式
*/
import type { AiPlugin } from '@cherrystudio/ai-core'
import { definePlugin, type StreamTextParams, type StreamTextResult } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import type { Context, Span, SpanContext, Tracer } from '@opentelemetry/api'
import { context as otelContext, trace } from '@opentelemetry/api'
import { currentSpan } from '@renderer/services/SpanManagerService'
import { webTraceService } from '@renderer/services/WebTraceService'
import type { Assistant } from '@renderer/types'
import type { TelemetrySettings } from 'ai'
import { AiSdkSpanAdapter } from '../trace/AiSdkSpanAdapter'
const logger = loggerService.withContext('TelemetryPlugin')
export interface TelemetryPluginConfig {
enabled?: boolean
recordInputs?: boolean
recordOutputs?: boolean
topicId: string
assistant: Assistant
}
/**
* 自定义 Tracer集成适配器转换逻辑
*/
class AdapterTracer {
private originalTracer: Tracer
private topicId?: string
private modelName?: string
private parentSpanContext?: SpanContext
private cachedParentContext?: Context
constructor(originalTracer: Tracer, topicId?: string, modelName?: string, parentSpanContext?: SpanContext) {
this.originalTracer = originalTracer
this.topicId = topicId
this.modelName = modelName
this.parentSpanContext = parentSpanContext
// 预构建一个包含父 SpanContext 的 Context便于复用
try {
this.cachedParentContext = this.parentSpanContext
? trace.setSpanContext(otelContext.active(), this.parentSpanContext)
: undefined
} catch {
this.cachedParentContext = undefined
}
logger.debug('AdapterTracer created with parent context info', {
topicId,
modelName,
parentTraceId: this.parentSpanContext?.traceId,
parentSpanId: this.parentSpanContext?.spanId,
hasOriginalTracer: !!originalTracer
})
}
startSpan(name: string, options?: any, context?: any): Span {
logger.debug('AdapterTracer.startSpan called', {
spanName: name,
topicId: this.topicId,
modelName: this.modelName
})
// 创建包含父 SpanContext 的上下文(如果有的话)
const createContextWithParent = () => {
if (this.cachedParentContext) {
return this.cachedParentContext
}
if (this.parentSpanContext) {
try {
const ctx = trace.setSpanContext(otelContext.active(), this.parentSpanContext)
logger.debug('Created active context with parent SpanContext for startSpan', {
spanName: name,
parentTraceId: this.parentSpanContext.traceId,
parentSpanId: this.parentSpanContext.spanId,
topicId: this.topicId
})
return ctx
} catch (error) {
logger.warn('Failed to create context with parent SpanContext in startSpan', error as Error)
}
}
return otelContext.active()
}
const ctx = context ?? createContextWithParent()
const span = this.originalTracer.startSpan(name, options, ctx)
// 注入父子关系属性(兜底重建层级用)
try {
if (this.parentSpanContext) {
span.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
span.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
}
if (this.topicId) {
span.setAttribute('trace.topicId', this.topicId)
}
} catch (e) {
logger.debug('Failed to set trace parent attributes in startSpan', e as Error)
}
// 包装span的end方法
const originalEnd = span.end.bind(span)
span.end = (endTime?: any) => {
logger.debug('AI SDK span.end() called in startSpan - about to convert span', {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName
})
// 调用原始 end 方法
originalEnd(endTime)
// 转换并保存 span 数据
try {
logger.debug('Converting AI SDK span to SpanEntity (from startSpan)', {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName
})
logger.silly('span', span)
const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
span,
topicId: this.topicId,
modelName: this.modelName
})
// 保存转换后的数据
void window.api.trace.saveEntity(spanEntity)
logger.debug('AI SDK span converted and saved successfully (from startSpan)', {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName,
hasUsage: !!spanEntity.usage,
usage: spanEntity.usage
})
} catch (error) {
logger.error('Failed to convert AI SDK span (from startSpan)', error as Error, {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName
})
}
}
return span
}
startActiveSpan<F extends (span: Span) => any>(name: string, fn: F): ReturnType<F>
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, fn: F): ReturnType<F>
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, context: any, fn: F): ReturnType<F>
startActiveSpan<F extends (span: Span) => any>(name: string, arg2?: any, arg3?: any, arg4?: any): ReturnType<F> {
logger.debug('AdapterTracer.startActiveSpan called', {
spanName: name,
topicId: this.topicId,
modelName: this.modelName,
// oxlint-disable-next-line no-undef False alarm. see https://github.com/oxc-project/oxc/issues/4232
argCount: arguments.length
})
// 包装函数来添加span转换逻辑
const wrapFunction = (originalFn: F, span: Span): F => {
const wrappedFn = ((passedSpan: Span) => {
// 注入父子关系属性(兜底重建层级用)
try {
if (this.parentSpanContext) {
passedSpan.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
passedSpan.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
}
if (this.topicId) {
passedSpan.setAttribute('trace.topicId', this.topicId)
}
} catch (e) {
logger.debug('Failed to set trace parent attributes in startActiveSpan', e as Error)
}
// 包装span的end方法
const originalEnd = span.end.bind(span)
span.end = (endTime?: any) => {
logger.debug('AI SDK span.end() called in startActiveSpan - about to convert span', {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName
})
// 调用原始 end 方法
originalEnd(endTime)
// 转换并保存 span 数据
try {
logger.debug('Converting AI SDK span to SpanEntity (from startActiveSpan)', {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName
})
logger.silly('span', span)
const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
span,
topicId: this.topicId,
modelName: this.modelName
})
// 保存转换后的数据
void window.api.trace.saveEntity(spanEntity)
logger.debug('AI SDK span converted and saved successfully (from startActiveSpan)', {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName,
hasUsage: !!spanEntity.usage,
usage: spanEntity.usage
})
} catch (error) {
logger.error('Failed to convert AI SDK span (from startActiveSpan)', error as Error, {
spanName: name,
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: this.topicId,
modelName: this.modelName
})
}
}
return originalFn(passedSpan)
}) as F
return wrappedFn
}
// 创建包含父 SpanContext 的上下文(如果有的话)
const createContextWithParent = () => {
if (this.cachedParentContext) {
return this.cachedParentContext
}
if (this.parentSpanContext) {
try {
const ctx = trace.setSpanContext(otelContext.active(), this.parentSpanContext)
logger.debug('Created active context with parent SpanContext for startActiveSpan', {
spanName: name,
parentTraceId: this.parentSpanContext.traceId,
parentSpanId: this.parentSpanContext.spanId,
topicId: this.topicId
})
return ctx
} catch (error) {
logger.warn('Failed to create context with parent SpanContext in startActiveSpan', error as Error)
}
}
return otelContext.active()
}
// 根据参数数量确定调用方式注入包含mainTraceId的上下文
if (typeof arg2 === 'function') {
return this.originalTracer.startActiveSpan(name, {}, createContextWithParent(), (span: Span) => {
return wrapFunction(arg2, span)(span)
})
} else if (typeof arg3 === 'function') {
return this.originalTracer.startActiveSpan(name, arg2, createContextWithParent(), (span: Span) => {
return wrapFunction(arg3, span)(span)
})
} else if (typeof arg4 === 'function') {
// 如果调用方提供了 context则保留以维护嵌套关系否则回退到父上下文
const ctx = arg3 ?? createContextWithParent()
return this.originalTracer.startActiveSpan(name, arg2, ctx, (span: Span) => {
return wrapFunction(arg4, span)(span)
})
} else {
throw new Error('Invalid arguments for startActiveSpan')
}
}
}
export function createTelemetryPlugin(config: TelemetryPluginConfig): AiPlugin<StreamTextParams, StreamTextResult> {
const { enabled = true, recordInputs = true, recordOutputs = true, topicId } = config
return definePlugin<StreamTextParams, StreamTextResult>({
name: 'telemetryPlugin',
enforce: 'pre', // 在其他插件之前执行,确保 telemetry 配置被正确注入
transformParams: (params, context) => {
if (!enabled) {
return params
}
// 获取共享的 tracer
const originalTracer = webTraceService.getTracer()
if (!originalTracer) {
logger.warn('No tracer available from WebTraceService')
return params
}
// 获取topicId和modelName
const effectiveTopicId = context.topicId || topicId
// 使用与父span创建时一致的modelName - 应该是完整的modelId
const modelName = config.assistant.model?.name || context.modelId
// 获取当前活跃的 span确保 AI SDK spans 与手动 spans 在同一个 trace 中
let parentSpan: Span | undefined = undefined
let parentSpanContext: SpanContext | undefined = undefined
// 只有在有topicId时才尝试查找父span
if (effectiveTopicId) {
try {
// 从 SpanManagerService 获取当前的 span
logger.debug('Attempting to find parent span', {
topicId: effectiveTopicId,
requestId: context.requestId,
modelName: modelName,
contextModelId: context.modelId,
providerId: context.providerId
})
parentSpan = currentSpan(effectiveTopicId, modelName)
if (parentSpan) {
// 直接使用父 span 的 SpanContext避免手动拼装字段遗漏
parentSpanContext = parentSpan.spanContext()
logger.debug('Found active parent span for AI SDK', {
parentSpanId: parentSpanContext.spanId,
parentTraceId: parentSpanContext.traceId,
topicId: effectiveTopicId,
requestId: context.requestId,
modelName: modelName
})
} else {
logger.warn('No active parent span found in SpanManagerService', {
topicId: effectiveTopicId,
requestId: context.requestId,
modelId: context.modelId,
modelName: modelName,
providerId: context.providerId,
// 更详细的调试信息
searchedModelName: modelName,
contextModelId: context.modelId,
isAnalyzing: context.isAnalyzing
})
}
} catch (error) {
logger.error('Error getting current span from SpanManagerService', error as Error, {
topicId: effectiveTopicId,
requestId: context.requestId,
modelName: modelName
})
}
} else {
logger.debug('No topicId provided, skipping parent span lookup', {
requestId: context.requestId,
contextTopicId: context.topicId,
configTopicId: topicId,
modelName: modelName
})
}
// 创建适配器包装的 tracer传入获取到的父 SpanContext
const adapterTracer = new AdapterTracer(originalTracer, effectiveTopicId, modelName, parentSpanContext)
// 注入 AI SDK telemetry 配置
const telemetryConfig = {
isEnabled: true,
recordInputs,
recordOutputs,
tracer: adapterTracer,
functionId: `ai-request-${context.requestId}`,
metadata: {
providerId: context.providerId,
modelId: context.modelId,
topicId: effectiveTopicId,
requestId: context.requestId,
modelName: modelName,
// 确保topicId也作为标准属性传递
'trace.topicId': effectiveTopicId,
'trace.modelName': modelName,
// 添加父span信息用于调试只在有值时添加
...(parentSpanContext?.spanId && { parentSpanId: parentSpanContext.spanId }),
...(parentSpanContext?.traceId && { parentTraceId: parentSpanContext.traceId })
}
} satisfies TelemetrySettings
// 如果有父span尝试在telemetry配置中设置父上下文
if (parentSpan) {
try {
// 设置活跃上下文,确保 AI SDK spans 在正确的 trace 上下文中创建
const activeContext = trace.setSpan(otelContext.active(), parentSpan)
// 更新全局上下文
otelContext.with(activeContext, () => {
logger.debug('Updated active context with parent span')
})
logger.debug('Set parent context for AI SDK spans', {
parentSpanId: parentSpanContext?.spanId,
parentTraceId: parentSpanContext?.traceId,
hasActiveContext: !!activeContext,
hasParentSpan: !!parentSpan
})
} catch (error) {
logger.warn('Failed to set parent context in telemetry config', error as Error)
}
}
logger.debug('Injecting AI SDK telemetry config with adapter', {
requestId: context.requestId,
topicId: effectiveTopicId,
modelId: context.modelId,
modelName: modelName,
hasParentSpan: !!parentSpan,
parentSpanId: parentSpanContext?.spanId,
parentTraceId: parentSpanContext?.traceId,
functionId: telemetryConfig.functionId,
hasTracer: !!telemetryConfig.tracer,
tracerType: telemetryConfig.tracer?.constructor?.name || 'unknown'
})
return {
...params,
experimental_telemetry: telemetryConfig
}
}
})
}
// 默认导出便于使用
export default createTelemetryPlugin

View File

@@ -1,838 +0,0 @@
import type { Message, Model } from '@renderer/types'
import type { FileMetadata } from '@renderer/types/file'
import { FILE_TYPE } from '@renderer/types/file'
import {
AssistantMessageStatus,
type FileMessageBlock,
type ImageMessageBlock,
type MainTextMessageBlock,
MessageBlockStatus,
MessageBlockType,
type ThinkingMessageBlock,
UserMessageStatus
} from '@renderer/types/newMessage'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
convertFileBlockToFilePartMock: vi.fn(),
convertFileBlockToTextPartMock: vi.fn()
}))
vi.mock('../fileProcessor', () => ({
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
convertFileBlockToTextPart: convertFileBlockToTextPartMock
}))
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
vi.mock('@renderer/config/models', () => ({
isVisionModel: (model: Model) => visionModelIds.has(model.id),
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
}))
type MockableMessage = Message & {
__mockContent?: string
__mockFileBlocks?: FileMessageBlock[]
__mockImageBlocks?: ImageMessageBlock[]
__mockThinkingBlocks?: ThinkingMessageBlock[]
__mockMainTextBlocks?: MainTextMessageBlock[]
}
vi.mock('@renderer/utils/messageUtils/find', () => ({
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? [],
findMainTextBlocks: (message: Message) => (message as MockableMessage).__mockMainTextBlocks ?? []
}))
import { convertMessagesToSdkMessages, convertMessageToSdkParam, stripMarkdownBase64Images } from '../messageConverter'
let messageCounter = 0
let blockCounter = 0
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o-mini',
name: 'GPT-4o mini',
provider: 'openai',
group: 'openai',
...overrides
})
const createMessage = (role: Message['role']): MockableMessage =>
({
id: `message-${++messageCounter}`,
role,
assistantId: 'assistant-1',
topicId: 'topic-1',
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
blocks: []
}) as MockableMessage
const createFileBlock = (
messageId: string,
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
): FileMessageBlock => {
const { file, ...blockOverrides } = overrides
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
return {
id: blockOverrides.id ?? `file-block-${blockCounter}`,
messageId,
type: MessageBlockType.FILE,
createdAt: blockOverrides.createdAt ?? timestamp,
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
file: {
id: file?.id ?? `file-${blockCounter}`,
name: file?.name ?? 'document.txt',
origin_name: file?.origin_name ?? 'document.txt',
path: file?.path ?? '/tmp/document.txt',
size: file?.size ?? 1024,
ext: file?.ext ?? '.txt',
type: file?.type ?? FILE_TYPE.TEXT,
created_at: file?.created_at ?? timestamp,
count: file?.count ?? 1,
...file
},
...blockOverrides
}
}
const createImageBlock = (
messageId: string,
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
): ImageMessageBlock => ({
id: overrides.id ?? `image-block-${++blockCounter}`,
messageId,
type: MessageBlockType.IMAGE,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
url: overrides.url ?? 'https://example.com/image.png',
...overrides
})
const createThinkingBlock = (
messageId: string,
overrides: Partial<Omit<ThinkingMessageBlock, 'type' | 'messageId'>> = {}
): ThinkingMessageBlock => ({
id: overrides.id ?? `thinking-block-${++blockCounter}`,
messageId,
type: MessageBlockType.THINKING,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
content: overrides.content ?? 'Let me think...',
thinking_millsec: overrides.thinking_millsec ?? 1000,
...overrides
})
const createMainTextBlock = (
messageId: string,
overrides: Partial<Omit<MainTextMessageBlock, 'type' | 'messageId'>> = {}
): MainTextMessageBlock => ({
id: overrides.id ?? `main-text-block-${++blockCounter}`,
messageId,
type: MessageBlockType.MAIN_TEXT,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
content: overrides.content ?? '',
...overrides
})
describe('messageConverter', () => {
beforeEach(() => {
convertFileBlockToFilePartMock.mockReset()
convertFileBlockToTextPartMock.mockReset()
convertFileBlockToFilePartMock.mockResolvedValue(null)
convertFileBlockToTextPartMock.mockResolvedValue(null)
messageCounter = 0
blockCounter = 0
})
describe('convertMessageToSdkParam', () => {
it('includes text and image parts for user messages on vision models', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Describe this picture'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Describe this picture' },
{ type: 'image', image: 'https://example.com/cat.png' }
]
})
})
it('extracts base64 data from data URLs and preserves mediaType', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Check this image'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:image/png;base64,iVBORw0KGgoAAAANS' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Check this image' },
{ type: 'image', image: 'iVBORw0KGgoAAAANS', mediaType: 'image/png' }
]
})
})
it('handles data URLs without mediaType gracefully', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Check this'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:;base64,AAABBBCCC' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Check this' },
{ type: 'image', image: 'AAABBBCCC' }
]
})
})
it('skips malformed data URLs without comma separator', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Malformed data url'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:image/pngAAABBB' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Malformed data url' }
// Malformed data URL is excluded from the content
]
})
})
it('handles multiple large base64 images without stack overflow', async () => {
const model = createModel()
const message = createMessage('user')
// Create large base64 strings (~500KB each) to simulate real-world large images
const largeBase64 = 'A'.repeat(500_000)
message.__mockContent = 'Check these images'
message.__mockImageBlocks = [
createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }),
createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }),
createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` })
]
// Should not throw RangeError: Maximum call stack size exceeded
await expect(convertMessageToSdkParam(message, true, model)).resolves.toBeDefined()
})
it('returns file instructions as a system message when native uploads succeed', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Summarize the PDF'
message.__mockFileBlocks = [createFileBlock(message.id)]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'document.pdf',
mediaType: 'application/pdf',
data: 'fileid://remote-file'
})
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual([
{
role: 'system',
content: 'fileid://remote-file'
},
{
role: 'user',
content: [{ type: 'text', text: 'Summarize the PDF' }]
}
])
})
it('includes reasoning parts for assistant messages with thinking blocks', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = 'Here is my answer'
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Let me think...' })]
const result = await convertMessageToSdkParam(message, false, model)
// Reasoning blocks must come before text blocks (required by AWS Bedrock for Claude extended thinking)
expect(result).toEqual({
role: 'assistant',
content: [
{ type: 'reasoning', text: 'Let me think...' },
{ type: 'text', text: 'Here is my answer' }
]
})
})
it('excludes empty content from assistant messages', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = ''
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Thinking only' })]
const result = await convertMessageToSdkParam(message, false, model)
// Empty content should not create a text block
expect(result).toEqual({
role: 'assistant',
content: [{ type: 'reasoning', text: 'Thinking only' }]
})
})
it('excludes whitespace-only content from assistant messages', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = ' \n\t '
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Thinking only' })]
const result = await convertMessageToSdkParam(message, false, model)
// Whitespace-only content should not create a text block
expect(result).toEqual({
role: 'assistant',
content: [{ type: 'reasoning', text: 'Thinking only' }]
})
})
it('trims content in assistant messages', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = ' Trimmed answer \n'
message.__mockThinkingBlocks = []
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual({
role: 'assistant',
content: [{ type: 'text', text: 'Trimmed answer' }]
})
})
it('includes thoughtSignature in providerOptions for Gemini thought signature persistence', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = 'Here is my answer'
message.__mockMainTextBlocks = [
createMainTextBlock(message.id, {
content: 'Here is my answer',
metadata: { thoughtSignature: 'test-thought-signature-token' }
})
]
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual({
role: 'assistant',
content: [
{
type: 'text',
text: 'Here is my answer',
providerOptions: {
google: {
thoughtSignature: 'test-thought-signature-token'
}
}
}
]
})
})
it('does not include providerOptions when no thoughtSignature is present', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = 'Plain answer'
message.__mockMainTextBlocks = [createMainTextBlock(message.id, { content: 'Plain answer' })]
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual({
role: 'assistant',
content: [{ type: 'text', text: 'Plain answer' }]
})
})
it('uses thoughtSignature from the first matching MainTextBlock when multiple exist', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = 'Answer text'
message.__mockMainTextBlocks = [
createMainTextBlock(message.id, { content: 'Answer text', metadata: { thoughtSignature: 'first-signature' } }),
createMainTextBlock(message.id, {
content: 'Another block',
metadata: { thoughtSignature: 'second-signature' }
})
]
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual({
role: 'assistant',
content: [
{
type: 'text',
text: 'Answer text',
providerOptions: {
google: {
thoughtSignature: 'first-signature'
}
}
}
]
})
})
it('combines reasoning blocks with thoughtSignature text part', async () => {
const model = createModel()
const message = createMessage('assistant')
message.__mockContent = 'Final answer'
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Thinking step' })]
message.__mockMainTextBlocks = [
createMainTextBlock(message.id, {
content: 'Final answer',
metadata: { thoughtSignature: 'sig-with-reasoning' }
})
]
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual({
role: 'assistant',
content: [
{ type: 'reasoning', text: 'Thinking step' },
{
type: 'text',
text: 'Final answer',
providerOptions: {
google: {
thoughtSignature: 'sig-with-reasoning'
}
}
}
]
})
})
})
describe('convertMessagesToSdkMessages', () => {
it('preserves conversation history and merges images for image enhancement models', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const initialUser = createMessage('user')
initialUser.__mockContent = 'Start editing'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Here is the current preview'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Increase the brightness'
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
// Preserves all conversation history, only merges images into the last user message
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Start editing' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Here is the current preview' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Increase the brightness' },
{ type: 'image', image: 'https://example.com/preview.png' }
]
}
])
})
it('preserves system messages and conversation history for enhancement payloads', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const fileUser = createMessage('user')
fileUser.__mockContent = 'Use this document as inspiration'
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FILE_TYPE.DOCUMENT } })]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'reference.pdf',
mediaType: 'application/pdf',
data: 'fileid://reference'
})
const assistant = createMessage('assistant')
assistant.__mockContent = 'Generated previews ready'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Apply the edits'
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
// Preserves system message, conversation history, and merges images into the last user message
expect(result).toEqual([
{ role: 'system', content: 'fileid://reference' },
{
role: 'user',
content: [{ type: 'text', text: 'Use this document as inspiration' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Generated previews ready' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Apply the edits' },
{ type: 'image', image: 'https://example.com/reference.png' }
]
}
])
})
it('returns messages as-is when no previous assistant message with images', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const user1 = createMessage('user')
user1.__mockContent = 'Start'
const user2 = createMessage('user')
user2.__mockContent = 'Continue without images'
const result = await convertMessagesToSdkMessages([user1, user2], model)
// No images to merge, returns all messages as-is
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Start' }]
},
{
role: 'user',
content: [{ type: 'text', text: 'Continue without images' }]
}
])
})
it('returns messages as-is when assistant message has no images', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const user1 = createMessage('user')
user1.__mockContent = 'Start'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Text only response'
assistant.__mockImageBlocks = []
const user2 = createMessage('user')
user2.__mockContent = 'Follow up'
const result = await convertMessagesToSdkMessages([user1, assistant, user2], model)
// No images to merge, returns all messages as-is
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Start' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Text only response' }]
},
{
role: 'user',
content: [{ type: 'text', text: 'Follow up' }]
}
])
})
it('merges images from the most recent assistant message', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const user1 = createMessage('user')
user1.__mockContent = 'Start'
const assistant1 = createMessage('assistant')
assistant1.__mockContent = 'First response'
assistant1.__mockImageBlocks = [createImageBlock(assistant1.id, { url: 'https://example.com/old.png' })]
const user2 = createMessage('user')
user2.__mockContent = 'Continue'
const assistant2 = createMessage('assistant')
assistant2.__mockContent = 'Second response'
assistant2.__mockImageBlocks = [createImageBlock(assistant2.id, { url: 'https://example.com/new.png' })]
const user3 = createMessage('user')
user3.__mockContent = 'Final request'
const result = await convertMessagesToSdkMessages([user1, assistant1, user2, assistant2, user3], model)
// Preserves all history, merges only the most recent assistant's images
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Start' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'First response' }]
},
{
role: 'user',
content: [{ type: 'text', text: 'Continue' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Second response' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Final request' },
{ type: 'image', image: 'https://example.com/new.png' }
]
}
])
})
it('returns messages as-is when conversation ends with assistant message', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const user = createMessage('user')
user.__mockContent = 'Start'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Response with image'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/image.png' })]
const result = await convertMessagesToSdkMessages([user, assistant], model)
// The user message is the last user message, but since the assistant comes after,
// there's no "previous" assistant message (search starts from messages.length-2 backwards)
// So no images to merge, returns all messages as-is
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Start' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Response with image' }]
}
])
})
it('merges images even when last user message has empty content', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const user1 = createMessage('user')
user1.__mockContent = 'Start'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Here is the preview'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
const user2 = createMessage('user')
user2.__mockContent = ''
const result = await convertMessagesToSdkMessages([user1, assistant, user2], model)
// Preserves history, merges images into last user message (even if empty)
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Start' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Here is the preview' }]
},
{
role: 'user',
content: [{ type: 'image', image: 'https://example.com/preview.png' }]
}
])
})
it('strips inline base64 data URIs from assistant text to prevent HTTP 413 (#12602)', async () => {
const model = createModel({ id: 'gpt-4o-mini' })
const user1 = createMessage('user')
user1.__mockContent = 'Generate an image of a cat'
const assistant = createMessage('assistant')
assistant.__mockContent =
'Here is the image you requested:\n![cat](data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAA...VeryLongBase64String)\nHope you like it!'
const user2 = createMessage('user')
user2.__mockContent = 'Now describe what you see'
const result = await convertMessagesToSdkMessages([user1, assistant, user2], model)
const assistantMsg = result.find((m) => m.role === 'assistant')!
const textPart = (assistantMsg.content as Array<{ type: string; text: string }>).find((p) => p.type === 'text')!
// The base64 data URI should be replaced with a placeholder
expect(textPart.text).not.toContain('data:image/')
expect(textPart.text).toContain('![cat](image)')
expect(textPart.text).toContain('Hope you like it!')
})
it('strips multiple inline base64 images from assistant text', async () => {
const model = createModel({ id: 'gpt-4o-mini' })
const user = createMessage('user')
user.__mockContent = 'Generate two images'
const assistant = createMessage('assistant')
assistant.__mockContent = '![first](data:image/png;base64,AAABBB) and ![second](data:image/jpeg;base64,CCCDDD)'
const result = await convertMessageToSdkParam(assistant, false, model)
const textPart = ((result as any).content as Array<{ type: string; text: string }>).find(
(p) => p.type === 'text'
)!
expect(textPart.text).toBe('![first](image) and ![second](image)')
})
it('preserves regular markdown images (non-base64) in assistant text', async () => {
const model = createModel({ id: 'gpt-4o-mini' })
const assistant = createMessage('assistant')
assistant.__mockContent = 'Check this out: ![photo](https://example.com/photo.png)'
const result = await convertMessageToSdkParam(assistant, false, model)
const textPart = ((result as any).content as Array<{ type: string; text: string }>).find(
(p) => p.type === 'text'
)!
expect(textPart.text).toBe('Check this out: ![photo](https://example.com/photo.png)')
})
it('allows using LLM conversation context for image generation', async () => {
// This test verifies the key use case: switching from LLM to image enhancement model
// and using the previous conversation as context for image generation
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
// Simulate a conversation that started with a regular LLM
const user1 = createMessage('user')
user1.__mockContent = 'Help me design a futuristic robot with blue lights'
const assistant1 = createMessage('assistant')
assistant1.__mockContent =
'Great idea! The robot could have a sleek metallic body with glowing blue LED strips...'
assistant1.__mockImageBlocks = [] // LLM response, no images
const user2 = createMessage('user')
user2.__mockContent = 'Yes, and add some chrome accents'
const assistant2 = createMessage('assistant')
assistant2.__mockContent = 'Perfect! Chrome accents would complement the blue lights beautifully...'
assistant2.__mockImageBlocks = [] // Still LLM response, no images
// User switches to image enhancement model and asks for image generation
const user3 = createMessage('user')
user3.__mockContent = 'Now generate an image based on our discussion'
const result = await convertMessagesToSdkMessages([user1, assistant1, user2, assistant2, user3], model)
// All conversation history should be preserved for context
// No images to merge since previous assistant had no images
expect(result).toEqual([
{
role: 'user',
content: [{ type: 'text', text: 'Help me design a futuristic robot with blue lights' }]
},
{
role: 'assistant',
content: [
{
type: 'text',
text: 'Great idea! The robot could have a sleek metallic body with glowing blue LED strips...'
}
]
},
{
role: 'user',
content: [{ type: 'text', text: 'Yes, and add some chrome accents' }]
},
{
role: 'assistant',
content: [{ type: 'text', text: 'Perfect! Chrome accents would complement the blue lights beautifully...' }]
},
{
role: 'user',
content: [{ type: 'text', text: 'Now generate an image based on our discussion' }]
}
])
})
})
describe('stripMarkdownBase64Images', () => {
it('replaces a single base64 image with placeholder', () => {
const input = 'Here is the image:\n![cat](data:image/jpeg;base64,/9j/4AAQ)\nDone.'
expect(stripMarkdownBase64Images(input)).toBe('Here is the image:\n![cat](image)\nDone.')
})
it('replaces multiple base64 images', () => {
const input = '![a](data:image/png;base64,AAA) text ![b](data:image/jpeg;base64,BBB)'
expect(stripMarkdownBase64Images(input)).toBe('![a](image) text ![b](image)')
})
it('preserves regular markdown images with http URLs', () => {
const input = '![photo](https://example.com/photo.png)'
expect(stripMarkdownBase64Images(input)).toBe(input)
})
it('preserves file:// URLs in markdown images', () => {
const input = '![saved](file:///tmp/image.png)'
expect(stripMarkdownBase64Images(input)).toBe(input)
})
it('handles empty alt text', () => {
const input = '![](data:image/png;base64,AAABBB)'
expect(stripMarkdownBase64Images(input)).toBe('![](image)')
})
it('handles text with no markdown images', () => {
expect(stripMarkdownBase64Images('Just plain text.')).toBe('Just plain text.')
})
it('returns empty string for empty input', () => {
expect(stripMarkdownBase64Images('')).toBe('')
})
it('handles mixed base64 and regular images', () => {
const input =
'![a](https://example.com/a.png) then ![b](data:image/png;base64,XXX) then ![c](https://example.com/c.png)'
expect(stripMarkdownBase64Images(input)).toBe(
'![a](https://example.com/a.png) then ![b](image) then ![c](https://example.com/c.png)'
)
})
it('handles data URI without base64 encoding', () => {
const input = '![svg](data:image/svg+xml,%3Csvg%3E%3C/svg%3E)'
expect(stripMarkdownBase64Images(input)).toBe('![svg](image)')
})
it('does not treat bare ](data: without ![ as markdown image', () => {
const input = 'some text ](data:image/png;base64,AAA) more text'
expect(stripMarkdownBase64Images(input)).toBe(input)
})
it('handles large base64 payload without OOM', () => {
const largeBase64 = 'A'.repeat(5_000_000)
const input = `![big](data:image/png;base64,${largeBase64})`
expect(stripMarkdownBase64Images(input)).toBe('![big](image)')
})
it('handles unclosed parenthesis gracefully', () => {
const input = '![broken](data:image/png;base64,AAA'
expect(stripMarkdownBase64Images(input)).toBe(input)
})
})
})

View File

@@ -1,317 +0,0 @@
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
import { TopicType } from '@renderer/types'
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
import { describe, expect, it, vi } from 'vitest'
import { getMaxTokens, getTemperature, getTimeout, getTopP } from '../modelParameters'
vi.mock('@renderer/services/AssistantService', () => ({
DEFAULT_ASSISTANT_SETTINGS: {
maxTokens: 4096,
enableMaxTokens: false,
temperature: 0.7,
enableTemperature: true,
topP: 1,
enableTopP: false,
contextCount: 4096,
streamOutput: true,
defaultModel: undefined,
customParameters: [],
reasoning_effort: 'default',
qwenThinkMode: undefined,
toolUseMode: 'function',
maxToolCalls: 20,
enableMaxToolCalls: true
},
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
contextCount: assistant.settings?.contextCount ?? 4096,
temperature: assistant.settings?.temperature ?? 0.7,
enableTemperature: assistant.settings?.enableTemperature ?? true,
topP: assistant.settings?.topP ?? 1,
enableTopP: assistant.settings?.enableTopP ?? false,
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
maxTokens: assistant.settings?.maxTokens,
streamOutput: assistant.settings?.streamOutput ?? true,
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
defaultModel: assistant.defaultModel,
customParameters: assistant.settings?.customParameters ?? [],
reasoning_effort: assistant.settings?.reasoning_effort ?? 'default',
qwenThinkMode: assistant.settings?.qwenThinkMode
}),
getProviderByModel: (model: Model) => ({ id: model.provider, type: model.provider, models: [] })
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(),
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
}))
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/assistants', () => ({
default: (state = { assistants: [] }) => state
}))
const createTopic = (assistantId: string): Topic => ({
id: `topic-${assistantId}`,
assistantId,
name: 'topic',
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messages: [],
type: TopicType.Chat
})
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
const assistantId = 'assistant-1'
return {
id: assistantId,
name: 'Test Assistant',
prompt: 'prompt',
topics: [createTopic(assistantId)],
type: 'assistant',
settings
}
}
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
provider: 'openai',
name: 'GPT-4o',
group: 'openai',
...overrides
})
describe('modelParameters', () => {
describe('getTemperature', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'medium', enableTemperature: true })
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns temperature when reasoning effort is default for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'default', enableTemperature: true, temperature: 0.7 })
const model = createModel({ id: 'claude-sonnet-4.5', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBe(0.7)
})
it('returns temperature when reasoning effort is none for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'none', enableTemperature: true, temperature: 0.5 })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBe(0.5)
})
it('returns undefined for models without temperature/topP support', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
const model = createModel({
id: 'claude-sonnet-4.5',
name: 'Claude Sonnet 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns configured temperature when enabled', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(0.42)
})
it('returns undefined when temperature is disabled', () => {
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('clamps temperature to max 1.0 for Zhipu models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Anthropic models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
const model = createModel({
id: 'claude-sonnet-3.5',
name: 'Claude 3.5 Sonnet',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Moonshot models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({
id: 'moonshot-v1-8k',
name: 'Moonshot v1 8k',
provider: 'moonshot',
group: 'moonshot'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('does not clamp temperature for OpenAI models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(2.0)
})
it('does not clamp temperature when it is already within limits', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(0.8)
})
})
describe('getTopP', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'high' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for models without TopP support', () => {
const assistant = createAssistant({ enableTopP: true })
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({
id: 'claude-opus-4.5',
name: 'Claude Opus 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns configured TopP when enabled', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBe(0.73)
})
it('returns undefined when TopP is disabled', () => {
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('clamps topP to [0.95, 1] for Claude reasoning models with reasoning effort', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.5, reasoning_effort: 'high' })
const model = createModel({ id: 'claude-sonnet-4.5', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBe(0.95)
})
it('does not clamp topP when reasoning effort is default for Claude models', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.5, reasoning_effort: 'default' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBe(0.5)
})
it('does not clamp topP when reasoning effort is none for Claude models', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.5, reasoning_effort: 'none' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBe(0.5)
})
it('keeps topP unchanged when already in [0.95, 1] range for Claude reasoning models', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.97, reasoning_effort: 'medium' })
const model = createModel({ id: 'claude-sonnet-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBe(0.97)
})
})
describe('getTimeout', () => {
it('uses an extended timeout for flex service tier models', () => {
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(15 * 1000 * 60)
})
it('falls back to the default timeout otherwise', () => {
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(DEFAULT_TIMEOUT)
})
})
describe('getMaxTokens', () => {
it('returns undefined when maxTokens is not enabled', () => {
const assistant = createAssistant({ enableMaxTokens: false, maxTokens: 128000 })
const model = createModel({ id: 'claude-opus-4-6', provider: 'anthropic', group: 'claude' })
expect(getMaxTokens(assistant, model)).toBeUndefined()
})
it('returns user-configured maxTokens for Claude 4.6 without subtraction', () => {
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 128000 })
const model = createModel({ id: 'claude-opus-4-6', provider: 'anthropic', group: 'claude' })
expect(getMaxTokens(assistant, model)).toBe(128000)
})
it('returns user-configured maxTokens for Claude Sonnet 4.6 without subtraction', () => {
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 64000 })
const model = createModel({ id: 'claude-sonnet-4-6', provider: 'anthropic', group: 'claude' })
expect(getMaxTokens(assistant, model)).toBe(64000)
})
it('subtracts thinking budget for non-4.6 Claude models with anthropic provider', () => {
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 16384 })
const model = createModel({ id: 'claude-sonnet-4', provider: 'anthropic', group: 'claude' })
const result = getMaxTokens(assistant, model)
// Non-4.6 Claude thinking models should have budget subtracted
expect(result).toBeDefined()
expect(result!).toBeLessThan(16384)
})
it('returns maxTokens as-is for non-Claude models', () => {
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 4096 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getMaxTokens(assistant, model)).toBe(4096)
})
})
})

View File

@@ -1,225 +0,0 @@
/**
* Tests for parameterBuilder maxToolCalls functionality
* These tests verify the maxToolCalls calculation and validation logic in isolation
*/
import { describe, expect, it } from 'vitest'
// Mirror the constants from parameterBuilder.ts
const MIN_TOOL_CALLS = 1
const MAX_TOOL_CALLS = 100
const DEFAULT_MAX_TOOL_CALLS = 20
const DEFAULT_ENABLE_MAX_TOOL_CALLS = true
/**
* Validates and clamps maxToolCalls to valid range
* Mirrors the logic in parameterBuilder.ts
*/
function validateMaxToolCalls(value: number | undefined): number {
if (value === undefined || value < MIN_TOOL_CALLS || value > MAX_TOOL_CALLS) {
return DEFAULT_MAX_TOOL_CALLS
}
return value
}
/**
* Calculate the effective max tool calls based on assistant settings
* Mirrors the logic in parameterBuilder.ts
*/
function calculateEffectiveMaxToolCalls(settings?: { maxToolCalls?: number; enableMaxToolCalls?: boolean }): {
stopWhen: number | null
maxToolCalls: number
} {
const enableMaxToolCalls = settings?.enableMaxToolCalls ?? DEFAULT_ENABLE_MAX_TOOL_CALLS
if (!enableMaxToolCalls) {
// When disabled, don't pass stopWhen (return null to indicate no stopWhen)
return { stopWhen: null, maxToolCalls: DEFAULT_MAX_TOOL_CALLS }
}
// When enabled, validate and use user-defined value
const maxToolCalls = validateMaxToolCalls(settings?.maxToolCalls)
return { stopWhen: maxToolCalls, maxToolCalls }
}
describe('validateMaxToolCalls', () => {
it('returns valid values as-is', () => {
expect(validateMaxToolCalls(1)).toBe(1)
expect(validateMaxToolCalls(50)).toBe(50)
expect(validateMaxToolCalls(100)).toBe(100)
})
it('clamps values above 100 to default', () => {
const result = validateMaxToolCalls(999)
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('clamps zero to default', () => {
const result = validateMaxToolCalls(0)
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('clamps negative values to default', () => {
const result = validateMaxToolCalls(-5)
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('returns default when value is undefined', () => {
const result = validateMaxToolCalls(undefined)
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
})
})
describe('maxToolCalls calculation logic', () => {
describe('default behavior', () => {
it('uses default value 20 when settings are undefined', () => {
const result = calculateEffectiveMaxToolCalls(undefined)
expect(result.maxToolCalls).toBe(20)
expect(result.stopWhen).toBe(20)
})
it('uses default value 20 when settings is empty object', () => {
const result = calculateEffectiveMaxToolCalls({})
expect(result.maxToolCalls).toBe(20)
expect(result.stopWhen).toBe(20)
})
it('uses default value 20 when maxToolCalls is undefined', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true
// maxToolCalls is undefined
})
expect(result.maxToolCalls).toBe(20)
expect(result.stopWhen).toBe(20)
})
it('uses custom value when maxToolCalls is set and enabled', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 50
})
expect(result.maxToolCalls).toBe(50)
expect(result.stopWhen).toBe(50)
})
})
describe('custom values when enabled', () => {
it('uses custom value when enableMaxToolCalls is true and maxToolCalls is set', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 50
})
expect(result.stopWhen).toBe(50)
})
it('uses custom value at minimum boundary (1)', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 1
})
expect(result.stopWhen).toBe(1)
})
it('uses custom value at maximum boundary (100)', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 100
})
expect(result.stopWhen).toBe(100)
})
it('clamps values above 100 to default when enabled', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 999
})
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('clamps zero to default when enabled', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 0
})
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('clamps negative values to default when enabled', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: -5
})
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
})
})
describe('disabled behavior', () => {
it('does not pass stopWhen when enableMaxToolCalls is false', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: false,
maxToolCalls: 50
})
// When disabled, stopWhen should be null (indicating no stopWhen passed)
expect(result.stopWhen).toBeNull()
})
it('does not pass stopWhen when both enableMaxToolCalls is false and maxToolCalls is undefined', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: false
})
expect(result.stopWhen).toBeNull()
})
it('falls back to default when disabled with invalid maxToolCalls', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: false,
maxToolCalls: 999
})
// When disabled, maxToolCalls should still be default (for reference)
expect(result.maxToolCalls).toBe(DEFAULT_MAX_TOOL_CALLS)
expect(result.stopWhen).toBeNull()
})
})
describe('backward compatibility', () => {
it('maintains backward compatibility - existing assistants without new fields use default', () => {
// Simulate an old assistant without the new fields
const oldSettings = {
// Old assistants don't have enableMaxToolCalls or maxToolCalls
temperature: 0.7,
contextCount: 10
}
const result = calculateEffectiveMaxToolCalls(
oldSettings as { maxToolCalls?: number; enableMaxToolCalls?: boolean }
)
// Should default to enabled with 20 for backward compatibility
expect(result.maxToolCalls).toBe(20)
expect(result.stopWhen).toBe(20)
})
})
describe('security - invalid values from imported/migrated settings', () => {
it('validates extremely large values from imported settings', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 999999
})
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('validates negative values from imported settings', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: -100
})
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
})
it('validates zero from imported settings', () => {
const result = calculateEffectiveMaxToolCalls({
enableMaxToolCalls: true,
maxToolCalls: 0
})
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
})
})
})

View File

@@ -1,283 +0,0 @@
/**
* 文件处理模块
* 处理文件内容提取、文件格式转换、文件上传等逻辑
*/
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { FileMetadata, Message, Model } from '@renderer/types'
import { FILE_TYPE } from '@renderer/types'
import type { FileMessageBlock } from '@renderer/types/newMessage'
import { findFileBlocks } from '@renderer/utils/messageUtils/find'
import type { FilePart, TextPart } from 'ai'
import i18n from 'i18next'
import { getAiSdkProviderId } from '../provider/factory'
import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload } from './modelCapabilities'
const logger = loggerService.withContext('fileProcessor')
/**
* 提取文件内容
*/
export async function extractFileContent(message: Message): Promise<string> {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FILE_TYPE.TEXT, FILE_TYPE.DOCUMENT].some((type) => fb.file.type === type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
/**
* 将文件块转换为文本部分
*/
export async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise<TextPart | null> {
const file = fileBlock.file
// 处理文本文件
if (file.type === FILE_TYPE.TEXT) {
try {
const fileContent = await window.api.file.read(file.id + file.ext)
return {
type: 'text',
text: `${file.origin_name}\n${fileContent.trim()}`
}
} catch (error) {
logger.warn('Failed to read text file:', error as Error)
}
}
// 处理文档文件PDF、Word、Excel等- 提取为文本内容
if (file.type === FILE_TYPE.DOCUMENT) {
try {
const fileContent = await window.api.file.read(file.id + file.ext, true) // true表示强制文本提取
return {
type: 'text',
text: `${file.origin_name}\n${fileContent.trim()}`
}
} catch (error) {
logger.warn(`Failed to extract text from document ${file.origin_name}:`, error as Error)
window.toast.error(i18n.t('message.error.file.text_extraction_failed', { name: file.origin_name }))
}
}
return null
}
/**
* 处理Gemini大文件上传
*/
export async function handleGeminiFileUpload(file: FileMetadata, model: Model): Promise<FilePart | null> {
try {
const provider = getProviderByModel(model)
// 检查文件是否已经上传过
const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
const remoteFile = fileMetadata.originalFile.file as any // 临时类型断言因为File类型定义可能不完整
// 注意AI SDK的FilePart格式和Gemini原生格式不同这里需要适配
// 暂时返回null让它回退到文本处理或者需要扩展FilePart支持uri
logger.info(`File ${file.origin_name} already uploaded to Gemini with URI: ${remoteFile.uri || 'unknown'}`)
return null
}
// 如果文件未上传,执行上传
const uploadResult = await window.api.fileService.upload(provider, file)
if (uploadResult.originalFile?.file) {
const remoteFile = uploadResult.originalFile.file as any // 临时类型断言
logger.info(`File ${file.origin_name} uploaded to Gemini with URI: ${remoteFile.uri || 'unknown'}`)
// 同样这里需要处理URI格式的文件引用
return null
}
} catch (error) {
logger.error(`Failed to upload file ${file.origin_name} to Gemini:`, error as Error)
}
return null
}
/**
* 处理OpenAI兼容大文件上传
*/
export async function handleOpenAILargeFileUpload(
file: FileMetadata,
model: Model
): Promise<(FilePart & { id?: string }) | null> {
const provider = getProviderByModel(model)
// 如果模型为qwen-long系列文档中要求purpose需要为'file-extract'
if (['qwen-long', 'qwen-doc'].some((modelName) => model.name.includes(modelName))) {
file = {
...file,
// 该类型并不在OpenAI定义中但符合sdk规范强制断言
purpose: 'file-extract' as OpenAI.FilePurpose
}
}
try {
// 检查文件是否已经上传过
const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
// 断言OpenAIFile对象
const remoteFile = fileMetadata.originalFile.file as OpenAI.Files.FileObject
// 判断用途是否一致
if (remoteFile.purpose !== file.purpose) {
logger.warn(`File ${file.origin_name} purpose mismatch: ${remoteFile.purpose} vs ${file.purpose}`)
throw new Error('File purpose mismatch')
}
return {
type: 'file',
filename: file.origin_name,
mediaType: '',
data: `fileid://${remoteFile.id}`
}
}
} catch (error) {
logger.error(`Failed to retrieve file ${file.origin_name}:`, error as Error)
return null
}
try {
// 如果文件未上传,执行上传
const uploadResult = await window.api.fileService.upload(provider, file)
if (uploadResult.originalFile?.file) {
// 断言OpenAIFile对象
const remoteFile = uploadResult.originalFile.file as OpenAI.Files.FileObject
logger.info(`File ${file.origin_name} uploaded.`)
return {
type: 'file',
filename: remoteFile.filename,
mediaType: '',
data: `fileid://${remoteFile.id}`
}
}
} catch (error) {
logger.error(`Failed to upload file ${file.origin_name}:`, error as Error)
}
return null
}
/**
* 大文件上传路由函数
*/
export async function handleLargeFileUpload(
file: FileMetadata,
model: Model
): Promise<(FilePart & { id?: string }) | null> {
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
if (['google', 'google-vertex'].includes(aiSdkId)) {
return await handleGeminiFileUpload(file, model)
}
if (provider.type === 'openai') {
return await handleOpenAILargeFileUpload(file, model)
}
return null
}
/**
* 将文件块转换为FilePart用于原生文件支持
*/
export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model: Model): Promise<FilePart | null> {
const file = fileBlock.file
const fileSizeLimit = getFileSizeLimit(model, file.type)
try {
// 处理PDF文档始终生成 FilePart由下游插件处理兼容性
if (file.type === FILE_TYPE.DOCUMENT && file.ext === '.pdf') {
// 检查文件大小限制
if (file.size > fileSizeLimit) {
// 如果支持大文件上传如Gemini File API尝试上传
if (supportsLargeFileUpload(model)) {
logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
const uploadResult = await handleLargeFileUpload(file, model)
if (uploadResult) {
return uploadResult
}
// 如果上传失败,回退到文本处理
logger.warn(`Failed to upload large PDF ${file.origin_name}, falling back to text extraction`)
window.toast.warning(i18n.t('message.warning.file.pdf_upload_failed', { name: file.origin_name }))
return null
} else {
logger.warn(`PDF file ${file.origin_name} exceeds size limit (${file.size} > ${fileSizeLimit})`)
window.toast.warning(
i18n.t('message.warning.file.pdf_exceeds_limit', {
name: file.origin_name,
limit: `${Math.round(fileSizeLimit / 1024 / 1024)}MB`
})
)
return null // 文件过大,回退到文本处理
}
}
const base64Data = await window.api.file.base64File(file.id + file.ext)
return {
type: 'file',
data: base64Data.data,
mediaType: base64Data.mime,
filename: file.origin_name
}
}
// 处理图片文件
if (file.type === FILE_TYPE.IMAGE && supportsImageInput(model)) {
// 检查文件大小
if (file.size > fileSizeLimit) {
logger.warn(`Image file ${file.origin_name} exceeds size limit (${file.size} > ${fileSizeLimit})`)
return null
}
const base64Data = await window.api.file.base64Image(file.id + file.ext)
// 处理MIME类型特别是jpg->jpeg的转换Anthropic要求
let mediaType = base64Data.mime
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
if (aiSdkId === 'anthropic' && mediaType === 'image/jpg') {
mediaType = 'image/jpeg'
}
return {
type: 'file',
data: base64Data.base64,
mediaType: mediaType,
filename: file.origin_name
}
}
// 处理其他文档类型Word、Excel等
if (file.type === FILE_TYPE.DOCUMENT && file.ext !== '.pdf') {
// 目前大多数提供商不支持Word等格式的原生处理
// 返回null会触发上层调用convertFileBlockToTextPart进行文本提取
// 这与Legacy架构中的处理方式一致
logger.debug(`Document file ${file.origin_name} with extension ${file.ext} will use text extraction fallback`)
return null
}
} catch (error) {
logger.warn(`Failed to process file ${file.origin_name}:`, error as Error)
}
return null
}

View File

@@ -1,33 +0,0 @@
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
// const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
const anthropicHeaders: string[] = []
const provider = getProviderByModel(model)
if (
isClaude45ReasoningModel(model) &&
isToolUseModeFunction(assistant) &&
!(isVertexProvider(provider) || isAwsBedrockProvider(provider))
) {
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
}
if (isClaude4SeriesModel(model)) {
if (isVertexProvider(provider) && assistant.enableWebSearch) {
anthropicHeaders.push(WEBSEARCH_HEADER)
}
// We may add it by user preference in assistant.settings instead of always adding it.
// See #11540, #11397
// anthropicHeaders.push(CONTEXT_100M_HEADER)
}
return anthropicHeaders
}

View File

@@ -1,22 +0,0 @@
/**
* AI SDK 参数转换模块 - 统一入口
*
* 此模块已重构,功能分拆到以下子模块:
* - modelParameters.ts: 基础参数处理 (温度、TopP、超时)
* - modelCapabilities.ts: 模型能力检查 (PDF、图片、文件支持)
* - fileProcessor.ts: 文件处理逻辑 (转换、上传)
* - messageConverter.ts: 消息转换核心 (单个消息转换)
* - parameterBuilder.ts: 参数构建器 (最终参数组装)
*/
// 基础参数处理
export { getTimeout } from './modelParameters'
// 文件处理
export { extractFileContent } from './fileProcessor'
// 消息转换
export { convertMessagesToSdkMessages, convertMessageToSdkParam } from './messageConverter'
// 参数构建 (主要API)
export { buildGenerateTextParams, buildStreamTextParams } from './parameterBuilder'

View File

@@ -1,387 +0,0 @@
/**
* 消息转换模块
* 将 Cherry Studio 消息格式转换为 AI SDK 消息格式
*/
import type { ReasoningPart } from '@ai-sdk/provider-utils'
import { loggerService } from '@logger'
import { isVisionModel } from '@renderer/config/models'
import type { Message, Model } from '@renderer/types'
import type {
FileMessageBlock,
ImageMessageBlock,
MainTextMessageBlock,
ThinkingMessageBlock
} from '@renderer/types/newMessage'
import {
findFileBlocks,
findImageBlocks,
findMainTextBlocks,
findThinkingBlocks,
getMainTextContent
} from '@renderer/utils/messageUtils/find'
import { parseDataUrl } from '@shared/utils'
import type {
AssistantModelMessage,
FilePart,
ImagePart,
ModelMessage,
SystemModelMessage,
TextPart,
UserModelMessage
} from 'ai'
import i18n from 'i18next'
import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor'
const logger = loggerService.withContext('messageConverter')
/**
* 转换消息为 AI SDK 参数格式
* 基于 OpenAI 格式的通用转换,支持文本、图片和文件
*/
export async function convertMessageToSdkParam(
message: Message,
isVisionModel = false,
model?: Model
): Promise<ModelMessage | ModelMessage[]> {
const content = getMainTextContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
const reasoningBlocks = findThinkingBlocks(message)
const mainTextBlocks = findMainTextBlocks(message)
if (message.role === 'user' || message.role === 'system') {
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model)
} else {
return convertMessageToAssistantModelMessage(
content,
fileBlocks,
imageBlocks,
reasoningBlocks,
mainTextBlocks,
model
)
}
}
async function convertImageBlockToImagePart(imageBlocks: ImageMessageBlock[]): Promise<Array<ImagePart>> {
const parts: Array<ImagePart> = []
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
try {
const ext = imageBlock.file.ext.startsWith('.') ? imageBlock.file.ext : `.${imageBlock.file.ext}`
const image = await window.api.file.base64Image(imageBlock.file.id + ext)
parts.push({
type: 'image',
image: image.base64,
mediaType: image.mime
})
} catch (error) {
logger.error('Failed to load image file, image will be excluded from message:', {
fileId: imageBlock.file.id,
fileName: imageBlock.file.origin_name,
error: error as Error
})
}
} else if (imageBlock.url) {
const url = imageBlock.url
const parseResult = parseDataUrl(url)
if (parseResult?.isBase64) {
const { mediaType, data } = parseResult
parts.push({ type: 'image', image: data, ...(mediaType ? { mediaType } : {}) })
} else if (url.startsWith('data:')) {
// Malformed data URL or non-base64 data URL
logger.error('Malformed or non-base64 data URL detected, image will be excluded:', {
urlPrefix: url.slice(0, 50) + '...'
})
continue
} else {
// For remote URLs we keep payload minimal to match existing expectations.
parts.push({ type: 'image', image: url })
}
}
}
return parts
}
/**
* 转换为用户模型消息
*/
async function convertMessageToUserModelMessage(
content: string,
fileBlocks: FileMessageBlock[],
imageBlocks: ImageMessageBlock[],
isVisionModel = false,
model?: Model
): Promise<UserModelMessage | (UserModelMessage | SystemModelMessage)[]> {
const parts: Array<TextPart | FilePart | ImagePart> = []
if (content) {
parts.push({ type: 'text', text: content })
}
// 处理图片(仅在支持视觉的模型中)
if (isVisionModel) {
parts.push(...(await convertImageBlockToImagePart(imageBlocks)))
}
// 处理文件
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
let processed = false
// 优先尝试原生文件支持PDF、图片等
if (model) {
const filePart = await convertFileBlockToFilePart(fileBlock, model)
if (filePart) {
// 判断filePart是否为string
if (typeof filePart.data === 'string' && filePart.data.startsWith('fileid://')) {
return [
{
role: 'system',
content: filePart.data
},
{
role: 'user',
content: parts.length > 0 ? parts : ''
}
]
}
parts.push(filePart)
logger.debug(`File ${file.origin_name} processed as native file format`)
processed = true
}
}
// 如果原生处理失败,回退到文本提取
if (!processed) {
const textPart = await convertFileBlockToTextPart(fileBlock)
if (textPart) {
parts.push(textPart)
logger.debug(`File ${file.origin_name} processed as text content`)
} else {
logger.warn(`File ${file.origin_name} could not be processed in any format`)
window.toast.error(i18n.t('message.error.file.process_failed', { name: file.origin_name }))
}
}
}
return {
role: 'user',
content: parts
}
}
/**
* Replaces markdown images with data URI sources (e.g. `![alt](data:image/...;base64,...)`)
* with a placeholder `![alt](image)` to avoid sending huge base64 payloads to the API.
*
* Uses string scanning (indexOf) instead of regex to avoid OOM on multi-MB base64 strings.
*/
export function stripMarkdownBase64Images(text: string): string {
const marker = '](data:'
let result = ''
let searchFrom = 0
while (searchFrom < text.length) {
const markerIdx = text.indexOf(marker, searchFrom)
if (markerIdx === -1) {
result += text.slice(searchFrom)
break
}
// Find the `![` that starts this markdown image — walk backwards from `](`
const bangIdx = text.lastIndexOf('![', markerIdx)
if (bangIdx === -1 || text.indexOf(']', bangIdx + 2) !== markerIdx) {
// Not a valid markdown image — skip past this marker
result += text.slice(searchFrom, markerIdx + marker.length)
searchFrom = markerIdx + marker.length
continue
}
// Find the closing `)` — the URL part starts after `](`
const urlStart = markerIdx + 2 // position right after `](`
const closeIdx = text.indexOf(')', urlStart)
if (closeIdx === -1) {
result += text.slice(searchFrom)
break
}
// Extract alt text between `![` and `]`
const altText = text.slice(bangIdx + 2, markerIdx)
// Append everything before `![` plus the replacement
result += text.slice(searchFrom, bangIdx) + `![${altText}](image)`
searchFrom = closeIdx + 1
}
return result
}
/**
* 转换为助手模型消息
* 注意:当助手消息只包含图片(如图片生成模型的响应)而没有文本时,
* 需要添加占位文本,因为某些 API如 Gemini不接受空的 assistant 消息
*/
async function convertMessageToAssistantModelMessage(
content: string,
fileBlocks: FileMessageBlock[],
imageBlocks: ImageMessageBlock[],
thinkingBlocks: ThinkingMessageBlock[],
mainTextBlocks: MainTextMessageBlock[],
model?: Model
): Promise<AssistantModelMessage> {
const parts: Array<TextPart | ReasoningPart | FilePart> = []
// Add reasoning blocks first (required by AWS Bedrock for Claude extended thinking)
for (const thinkingBlock of thinkingBlocks) {
parts.push({ type: 'reasoning', text: thinkingBlock.content })
}
// Add text content after reasoning blocks, only if non-empty after trimming
// Also add thoughtSignature from MainTextBlock metadata for Gemini thought signature persistence
// Strip inline base64 data URIs from markdown images to prevent HTTP 413 errors (#12602)
// Uses string scanning instead of regex to avoid OOM on large base64 payloads
const trimmedContent = stripMarkdownBase64Images(content?.trim() ?? '')
if (trimmedContent) {
// Find the first MainTextBlock with thoughtSignature
const thoughtSignature = mainTextBlocks.find((block) => block.metadata?.thoughtSignature)?.metadata
?.thoughtSignature
const textPart: TextPart = { type: 'text', text: trimmedContent }
// Add providerOptions with thoughtSignature if available (for Gemini)
if (thoughtSignature) {
textPart.providerOptions = {
google: {
thoughtSignature
}
}
}
parts.push(textPart)
}
for (const fileBlock of fileBlocks) {
// 优先尝试原生文件支持PDF等
if (model) {
const filePart = await convertFileBlockToFilePart(fileBlock, model)
if (filePart) {
parts.push(filePart)
continue
}
}
// 回退到文本处理
const textPart = await convertFileBlockToTextPart(fileBlock)
if (textPart) {
parts.push(textPart)
}
}
// 当 parts 为空但有图片时,添加占位文本
// 这对于图片生成模型的继续对话很重要,因为助手消息可能只包含生成的图片
if (parts.length === 0 && imageBlocks.length > 0) {
parts.push({ type: 'text', text: '[Image]' })
}
return {
role: 'assistant',
content: parts
}
}
/**
* Converts an array of messages to SDK-compatible model messages.
*
* This function processes messages and transforms them into the format required by the SDK.
* It handles special cases for vision models and image enhancement models.
*
* @param messages - Array of messages to convert.
* @param model - The model configuration that determines conversion behavior
*
* @returns A promise that resolves to an array of SDK-compatible model messages
*
* @remarks
* For image enhancement models:
* - Collapses the conversation into [system?, user(image)] format
* - Searches backwards through all messages to find the most recent assistant message with images
* - Preserves all system messages (including ones generated from file uploads like 'fileid://...')
* - Extracts the last user message content and merges images from the previous assistant message
* - Returns only the collapsed messages: system messages (if any) followed by a single user message
* - If no user message is found, returns only system messages
* - Typical pattern: [system?, user, assistant(image), user] -> [system?, user(image)]
*
* For other models:
* - Returns all converted messages in order without special image handling
*
* The function automatically detects vision model capabilities and adjusts conversion accordingly.
*/
export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise<ModelMessage[]> {
const sdkMessages: ModelMessage[] = []
const isVision = isVisionModel(model)
for (const message of messages) {
const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
}
// Special handling for vison models
// These models support multi-turn conversations but need images from previous assistant messages
// to be merged into the current user message for editing/enhancement operations.
//
// Key behaviors:
// 1. Preserve all conversation history for context
// 2. Find images from the previous assistant message and merge them into the last user message
// 3. This allows users to switch from LLM conversations and use that context for image generation
if (isVision) {
// Find the last user SDK message index
const lastUserSdkIndex = (() => {
for (let i = sdkMessages.length - 1; i >= 0; i--) {
if (sdkMessages[i].role === 'user') return i
}
return -1
})()
// If no user message found, return messages as-is
if (lastUserSdkIndex < 0) {
return sdkMessages
}
// Find the nearest preceding assistant message in original messages
let prevAssistant: Message | null = null
for (let i = messages.length - 2; i >= 0; i--) {
if (messages[i].role === 'assistant') {
prevAssistant = messages[i]
break
}
}
// Check if there are images from the previous assistant message
const imageBlocks = prevAssistant ? findImageBlocks(prevAssistant) : []
const imageParts = await convertImageBlockToImagePart(imageBlocks)
// If no images to merge, return messages as-is
if (imageParts.length === 0) {
return sdkMessages
}
// Build the new last user message with merged images
const lastUserSdk = sdkMessages[lastUserSdkIndex] as UserModelMessage
let finalUserParts: Array<TextPart | FilePart | ImagePart> = []
if (typeof lastUserSdk.content === 'string') {
finalUserParts.push({ type: 'text', text: lastUserSdk.content })
} else if (Array.isArray(lastUserSdk.content)) {
finalUserParts = [...lastUserSdk.content]
}
// Append images from the previous assistant message
finalUserParts.push(...imageParts)
// Replace the last user message with the merged version
const result = [...sdkMessages]
result[lastUserSdkIndex] = { role: 'user', content: finalUserParts }
return result
}
return sdkMessages
}

View File

@@ -1,93 +0,0 @@
/**
* 模型能力检查模块
* 检查不同模型支持的功能PDF输入、图片输入、大文件上传等
*/
import { isVisionModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { FileType, Model } from '@renderer/types'
import { FILE_TYPE } from '@renderer/types'
import { getAiSdkProviderId } from '../provider/factory'
// 工具函数:基于模型名和提供商判断是否支持某特性
function modelSupportValidator(
model: Model,
{
supportedModels = [],
unsupportedModels = [],
supportedProviders = [],
unsupportedProviders = []
}: {
supportedModels?: string[]
unsupportedModels?: string[]
supportedProviders?: string[]
unsupportedProviders?: string[]
}
): boolean {
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
// 黑名单:命中不支持的模型直接拒绝
if (unsupportedModels.some((name) => model.name.includes(name))) {
return false
}
// 黑名单:命中不支持的提供商直接拒绝,常用于某些提供商的同名模型并不具备原模型的某些特性
if (unsupportedProviders.includes(aiSdkId)) {
return false
}
// 白名单:命中支持的模型名
if (supportedModels.some((name) => model.name.includes(name))) {
return true
}
// 回退到提供商判断
return supportedProviders.includes(aiSdkId)
}
/**
* 检查模型是否支持原生图片输入
*/
export function supportsImageInput(model: Model): boolean {
return isVisionModel(model)
}
/**
* 检查提供商是否支持大文件上传如Gemini File API
*/
export function supportsLargeFileUpload(model: Model): boolean {
// 基于AI SDK文档以下模型或提供商支持大文件上传
return modelSupportValidator(model, {
supportedModels: ['qwen-long', 'qwen-doc'],
supportedProviders: ['google', 'google-generative-ai', 'google-vertex']
})
}
/**
* 获取提供商特定的文件大小限制
*/
export function getFileSizeLimit(model: Model, fileType: FileType): number {
const provider = getProviderByModel(model)
const aiSdkId = getAiSdkProviderId(provider)
// Anthropic PDF限制32MB
if (aiSdkId === 'anthropic' && fileType === FILE_TYPE.DOCUMENT) {
return 32 * 1024 * 1024 // 32MB
}
// Gemini小文件限制20MB超过此限制会使用File API上传
if (['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)) {
return 20 * 1024 * 1024 // 20MB
}
// Dashscope如果模型支持大文件上传优先使用File API上传
if (aiSdkId === 'dashscope' && supportsLargeFileUpload(model)) {
return 0 // 使用较小的默认值
}
// 其他提供商没有明确限制,使用较大的默认值
// 这与Legacy架构中的实现一致让提供商自行处理文件大小
return Infinity
}

View File

@@ -1,156 +0,0 @@
/**
* 模型基础参数处理模块
* 处理温度、TopP、超时等基础参数的获取逻辑
*/
import { loggerService } from '@logger'
import {
isClaude46SeriesModel,
isClaudeReasoningModel,
isMaxTemperatureOneModel,
isSupportedFlexServiceTier,
isSupportedThinkingTokenClaudeModel,
isSupportTemperatureModel,
isSupportTopPModel,
isTemperatureTopPMutuallyExclusiveModel
} from '@renderer/config/models'
import {
DEFAULT_ASSISTANT_SETTINGS,
getAssistantSettings,
getProviderByModel
} from '@renderer/services/AssistantService'
import { type Assistant, type Model } from '@renderer/types'
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
import { getThinkingBudget } from '../utils/reasoning'
const logger = loggerService.withContext('modelParameters')
/**
* Retrieves the temperature parameter, adapting it based on assistant.settings and model capabilities.
* - Disabled when enableTemperature is off.
* - Disabled for Claude reasoning models when reasoning effort is set (excluding 'default' and 'none').
* - Disabled for models that do not support temperature.
* - Clamped to 1 for models with max temperature of 1.
* Otherwise, returns the temperature value.
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
const enableTemperature = assistant.settings?.enableTemperature ?? DEFAULT_ASSISTANT_SETTINGS.enableTemperature
if (!enableTemperature) {
return undefined
}
// Thinking isn't compatible with temperature or top_k modifications as well as forced tool use.
// See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking#feature-compatibility
if (
isClaudeReasoningModel(model) &&
assistant.settings?.reasoning_effort &&
assistant.settings.reasoning_effort !== 'default' &&
assistant.settings.reasoning_effort !== 'none'
) {
logger.info(`Model ${model.id} does not support reasoning with temperature, disabling temperature`)
return undefined
}
if (!isSupportTemperatureModel(model, assistant)) {
logger.info(`Model ${model.id} does not support temperature, disabling temperature`)
return undefined
}
let temperature = assistant.settings?.temperature ?? DEFAULT_ASSISTANT_SETTINGS.temperature
if (isMaxTemperatureOneModel(model) && temperature > 1) {
logger.info(`Model ${model.id} has max temperature of 1, clamping temperature from ${temperature} to 1`)
temperature = 1
}
if (isTemperatureTopPMutuallyExclusiveModel(model) && assistant.settings?.enableTopP) {
logger.info(`Model ${model.id} only accepts one of temperature and topP, both enabled; keeping temperature`)
}
return temperature
}
/**
* Retrieves the TopP parameter, adapting it based on assistant.settings and model capabilities.
* - Disabled when enableTopP is off.
* - Disabled for models that do not support TopP.
* - Disabled for mutually exclusive models when temperature is enabled.
* - Clamped to [0.95, 1] for Claude reasoning models with reasoning effort set (excluding 'default' and 'none').
* Otherwise, returns the TopP value.
*/
export function getTopP(assistant: Assistant, model: Model): number | undefined {
const enableTopP = assistant.settings?.enableTopP ?? DEFAULT_ASSISTANT_SETTINGS.enableTopP
if (!enableTopP) {
return undefined
}
if (!isSupportTopPModel(model, assistant)) {
logger.info(`Model ${model.id} does not support topP, disabling topP.`)
return undefined
}
if (isTemperatureTopPMutuallyExclusiveModel(model) && assistant.settings?.enableTemperature) {
logger.info(`Model ${model.id} only accepts one of temperature and topP, disabling topP.`)
return undefined
}
let topP = assistant.settings?.topP ?? DEFAULT_ASSISTANT_SETTINGS.topP
// When thinking is enabled, the topP should be between 0.95 and 1
// See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking#feature-compatibility
// NOTE: It depends on the behavior that extended thinking defaults to off, so we clamp the topP value also when reasoning is not 'default'
if (
isClaudeReasoningModel(model) &&
assistant.settings?.reasoning_effort &&
assistant.settings.reasoning_effort !== 'default' &&
assistant.settings.reasoning_effort !== 'none'
) {
const clampedTopP = Math.max(0.95, Math.min(topP, 1))
if (clampedTopP !== topP) {
logger.info(`Claude Model ${model.id} has reasoning enabled, clamping topP from ${topP} to ${clampedTopP}`)
}
topP = clampedTopP
}
return topP
}
/**
* 获取超时设置
*/
export function getTimeout(model: Model): number {
if (isSupportedFlexServiceTier(model)) {
return 15 * 1000 * 60
}
return DEFAULT_TIMEOUT
}
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
// NOTE: ai-sdk会把maxToken和budgetToken加起来
const assistantSettings = getAssistantSettings(assistant)
const enabledMaxTokens = assistantSettings.enableMaxTokens ?? false
let maxTokens = assistantSettings.maxTokens
// If user hasn't enabled enableMaxTokens, return undefined to let the API use its default value.
// Note: Anthropic API requires max_tokens, but that's handled by the Anthropic client with a fallback.
if (!enabledMaxTokens || maxTokens === undefined) {
return undefined
}
const provider = getProviderByModel(model)
// Claude 4.6 uses adaptive thinking (no budgetTokens), so the AI SDK does not add budget back
// to maxOutputTokens. Skip the subtraction to avoid incorrectly reducing max_tokens.
if (
isSupportedThinkingTokenClaudeModel(model) &&
!isClaude46SeriesModel(model) &&
['anthropic', 'aws-bedrock'].includes(provider.type)
) {
const { reasoning_effort: reasoningEffort } = assistantSettings
const budget = getThinkingBudget(maxTokens, reasoningEffort, model.id)
if (budget) {
maxTokens -= budget
}
}
return maxTokens
}

View File

@@ -1,263 +0,0 @@
/**
* 参数构建模块
* 构建AI SDK的流式和非流式参数
*/
import { combineHeaders } from '@ai-sdk/provider-utils'
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { extensionRegistry } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { AppProviderId } from '@renderer/aiCore/types'
import { MAX_TOOL_CALLS, MIN_TOOL_CALLS } from '@renderer/config/constant'
import {
isAnthropicModel,
isFixedReasoningModel,
isGeminiModel,
isGenerateImageModel,
isGrokModel,
isOpenAIModel,
isOpenRouterBuiltInWebSearchModel,
isPureGenerateImageModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getHubModeSystemPrompt } from '@renderer/config/prompts-code-mode'
import { DEFAULT_ASSISTANT_SETTINGS, getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { type Assistant, getEffectiveMcpMode, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { IdleTimeoutController, type IdleTimeoutHandle } from '@renderer/utils/IdleTimeoutController'
import { replacePromptVariables } from '@renderer/utils/prompt'
import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
import type { ModelMessage } from 'ai'
import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory'
import type { ProviderCapabilities } from '../types'
import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
import { addAnthropicHeaders } from './header'
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder')
/**
* Validates and clamps maxToolCalls to valid range
* Falls back to DEFAULT_ASSISTANT_SETTINGS.maxToolCalls if invalid
* @param value - The maxToolCalls value from settings
* @returns Validated maxToolCalls value
*/
function validateMaxToolCalls(value: number | undefined): number {
if (value === undefined || value < MIN_TOOL_CALLS || value > MAX_TOOL_CALLS) {
return DEFAULT_ASSISTANT_SETTINGS.maxToolCalls
}
return value
}
function mapVertexAIGatewayModelToProviderId(model: Model): AppProviderId | undefined {
if (isAnthropicModel(model)) {
return 'anthropic'
}
if (isGeminiModel(model)) {
return 'google'
}
if (isGrokModel(model)) {
return 'xai'
}
if (isOpenAIModel(model)) {
return 'openai'
}
logger.warn(`Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`)
return undefined
}
/**
* 构建 AI SDK 流式参数
* 这是主要的参数构建函数,整合所有转换逻辑
*/
export async function buildStreamTextParams(
sdkMessages: StreamTextParams['messages'] = [],
assistant: Assistant,
provider: Provider,
options: {
mcpTools?: MCPTool[]
allowedTools?: string[]
webSearchProviderId?: string
webSearchConfig?: CherryWebSearchConfig
requestOptions?: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string | undefined>
}
}
): Promise<{
params: StreamTextParams
modelId: string
capabilities: ProviderCapabilities
webSearchPluginConfig?: WebSearchPluginConfig
idleTimeout: IdleTimeoutHandle
}> {
const { mcpTools, requestOptions = {} } = options
// No caller currently provides a custom timeout; defaultTimeout (10 min) is the fallback.
const { signal: externalSignal, timeout = DEFAULT_TIMEOUT, headers: inputHeaders = {} } = requestOptions
// Use an idle timeout that resets every time a stream chunk is received,
// instead of a fixed total timeout that starts from the initial request.
const idleTimeout = new IdleTimeoutController(timeout)
const signals = [idleTimeout.signal]
if (externalSignal) {
signals.push(externalSignal)
}
const finalSignal = AbortSignal.any(signals)
const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider)
// 这三个变量透传出来,交给下面启用插件/中间件
// 也可以在外部构建好再传入buildStreamTextParams
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
assistant.settings?.reasoning_effort !== undefined) ||
isFixedReasoningModel(model)
// 判断是否使用内置搜索
// 条件:没有外部搜索提供商 && (用户开启了内置搜索 || 模型强制使用内置搜索)
const hasExternalSearch = !!options.webSearchProviderId
const enableWebSearch =
!hasExternalSearch &&
((assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar'))
// Validate provider and model support to prevent stale state from triggering urlContext
const enableUrlContext = !!(
assistant.enableUrlContext &&
isSupportUrlContextProvider(provider) &&
!isPureGenerateImageModel(model) &&
(isGeminiModel(model) || isAnthropicModel(model))
)
const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)
const tools = setupToolsConfig(mcpTools, options.allowedTools)
// 构建真正的 providerOptions
const webSearchConfig: CherryWebSearchConfig = {
maxResults: store.getState().websearch.maxResults,
excludeDomains: store.getState().websearch.excludeDomains,
searchWithTime: store.getState().websearch.searchWithTime
}
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
enableReasoning,
enableWebSearch,
enableGenerateImage
})
// Web search + URL context 的工具注入由 plugin 系统处理:
// - webSearchPlugin: 根据 provider 的 toolFactories.webSearch 自动注入
// - urlContextPlugin: 根据 provider 的 toolFactories.urlContext 自动注入
// parameterBuilder 只构建 config传给 plugin
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
if (enableWebSearch) {
if (extensionRegistry.has(aiSdkProviderId)) {
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
} else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) {
const gatewayProviderId = mapVertexAIGatewayModelToProviderId(model)
if (gatewayProviderId) {
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(gatewayProviderId, webSearchConfig, model)
}
}
}
let headers = inputHeaders
if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
const betaHeaders = addAnthropicHeaders(assistant, model)
// Only add the anthropic-beta header if there are actual beta headers to include
if (betaHeaders.length > 0) {
const newBetaHeaders = { 'anthropic-beta': betaHeaders.join(',') }
headers = combineHeaders(headers, newBetaHeaders)
}
}
// 构建基础参数
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
// are extracted from custom parameters and passed directly to streamText()
// instead of being placed in providerOptions
// Get max tool calls from assistant settings
// When enabled, validate and use user-defined value (1-100)
// When disabled, don't pass stopWhen - let AI SDK use its own default
const enableMaxToolCalls = assistant.settings?.enableMaxToolCalls ?? DEFAULT_ASSISTANT_SETTINGS.enableMaxToolCalls
const params: StreamTextParams = {
messages: sdkMessages,
maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
// Include AI SDK standard params extracted from custom parameters
...standardParams,
abortSignal: finalSignal,
headers,
providerOptions,
maxRetries: 0
}
// Only add stopWhen when explicitly enabled and validated
if (enableMaxToolCalls) {
const maxToolCalls = validateMaxToolCalls(assistant.settings?.maxToolCalls)
params.stopWhen = stepCountIs(maxToolCalls)
}
// When disabled, don't pass stopWhen - let AI SDK use its own default
if (tools) {
params.tools = tools
}
let systemPrompt = assistant.prompt ? await replacePromptVariables(assistant.prompt, model.name) : ''
if (getEffectiveMcpMode(assistant) === 'auto') {
const autoModePrompt = getHubModeSystemPrompt()
if (autoModePrompt) {
systemPrompt = systemPrompt ? `${systemPrompt}\n\n${autoModePrompt}` : autoModePrompt
}
}
if (systemPrompt) {
params.system = systemPrompt
}
logger.debug('params', params)
return {
params,
modelId: model.id,
capabilities: { enableReasoning, enableWebSearch, enableGenerateImage, enableUrlContext },
webSearchPluginConfig,
idleTimeout
}
}
/**
* 构建非流式的 generateText 参数
*/
export async function buildGenerateTextParams(
messages: ModelMessage[],
assistant: Assistant,
provider: Provider,
options: {
mcpTools?: MCPTool[]
allowedTools?: string[]
enableTools?: boolean
} = {}
): Promise<any> {
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, provider, options)
}

View File

@@ -1,153 +0,0 @@
import type { Model, Provider } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { getAiSdkProviderId } from '../factory'
// Mock the external dependencies
vi.mock('@cherrystudio/ai-core', () => ({
registerMultipleProviders: vi.fn(() => 4), // Mock successful registration of 4 providers
getProviderMapping: vi.fn((id: string) => {
// Mock dynamic mappings
const mappings: Record<string, string> = {
openrouter: 'openrouter',
'google-vertex': 'google-vertex',
vertexai: 'google-vertex',
bedrock: 'bedrock',
'aws-bedrock': 'bedrock',
zhipu: 'zhipu'
}
return mappings[id]
}),
AiCore: {
isSupported: vi.fn(() => true)
}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
// Mock the provider configs
vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn()
}))
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
info: vi.fn(),
warn: vi.fn(),
error: vi.fn()
})
}
}))
function createTestProvider(id: string, type: string): Provider {
return {
id,
type,
name: `Test ${id}`,
apiKey: 'test-key',
apiHost: 'test-host'
} as Provider
}
function createAzureProvider(id: string, apiVersion?: string, model?: string): Provider {
return {
id,
type: 'azure-openai',
name: `Azure Test ${id}`,
apiKey: 'azure-test-key',
apiHost: 'azure-test-host',
apiVersion,
models: [{ id: model || 'gpt-4' } as Model]
}
}
describe('Integrated Provider Registry', () => {
describe('Provider ID Resolution', () => {
it('should resolve openrouter provider correctly', () => {
const provider = createTestProvider('openrouter', 'openrouter')
const result = getAiSdkProviderId(provider)
expect(result).toBe('openrouter')
})
it('should resolve google-vertex provider correctly', () => {
const provider = createTestProvider('google-vertex', 'vertexai')
const result = getAiSdkProviderId(provider)
expect(result).toBe('google-vertex')
})
it('should resolve bedrock provider correctly', () => {
const provider = createTestProvider('bedrock', 'aws-bedrock')
const result = getAiSdkProviderId(provider)
expect(result).toBe('bedrock')
})
it('should resolve zhipu provider correctly', () => {
const provider = createTestProvider('zhipu', 'zhipu')
const result = getAiSdkProviderId(provider)
expect(result).toBe('zhipu')
})
it('should resolve provider type mapping correctly', () => {
const provider = createTestProvider('vertex-test', 'vertexai')
const result = getAiSdkProviderId(provider)
expect(result).toBe('google-vertex')
})
it('should handle static provider mappings', () => {
const geminiProvider = createTestProvider('gemini', 'gemini')
const result = getAiSdkProviderId(geminiProvider)
expect(result).toBe('google')
})
it('should fallback to provider.id for unknown providers', () => {
const unknownProvider = createTestProvider('unknown-provider', 'unknown-type')
const result = getAiSdkProviderId(unknownProvider)
expect(result).toBe('unknown-provider')
})
it('should handle Azure OpenAI providers correctly', () => {
const azureProvider = createAzureProvider('azure-test', '2024-02-15', 'gpt-4o')
const result = getAiSdkProviderId(azureProvider)
expect(result).toBe('azure')
})
it('should handle Azure OpenAI providers response endpoint correctly', () => {
const azureProvider = createAzureProvider('azure-test', 'v1', 'gpt-4o')
const result = getAiSdkProviderId(azureProvider)
expect(result).toBe('azure-responses')
})
it('should handle Azure provider Claude Models', () => {
const provider = createTestProvider('azure-anthropic', 'anthropic')
const result = getAiSdkProviderId(provider)
expect(result).toBe('azure-anthropic')
})
})
describe('Backward Compatibility', () => {
it('should maintain compatibility with existing providers', () => {
const grokProvider = createTestProvider('grok', 'grok')
const result = getAiSdkProviderId(grokProvider)
expect(result).toBe('xai-responses')
})
})
})

View File

@@ -1,14 +0,0 @@
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const

View File

@@ -1,157 +0,0 @@
/**
* AiHubMix Provider
*
* Multi-backend API gateway that routes models by model ID prefix:
* - claude* -> Anthropic SDK
* - gemini* -> Google SDK
* - others -> OpenAI Responses SDK (default)
*
* All requests include the APP-Code header.
*/
import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal'
import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal'
import { OpenAIChatLanguageModel, OpenAIResponsesLanguageModel, OpenAISpeechModel } from '@ai-sdk/openai/internal'
import {
OpenAICompatibleChatLanguageModel,
OpenAICompatibleEmbeddingModel,
OpenAICompatibleImageModel
} from '@ai-sdk/openai-compatible'
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { FetchFunction } from '@ai-sdk/provider-utils'
import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils'
import { isOpenAIChatCompletionOnlyModel, isOpenAILLMModel } from '@renderer/config/models/openai'
import type { Model } from '@renderer/types'
export const AIHUBMIX_PROVIDER_NAME = 'aihubmix' as const
const APP_CODE_HEADER = { 'APP-Code': 'MLTG2087' }
export interface AihubmixProviderSettings {
apiKey?: string
baseURL?: string
headers?: Record<string, string>
fetch?: FetchFunction
}
export interface AihubmixProvider extends ProviderV3 {
(modelId: string): LanguageModelV3
languageModel(modelId: string): LanguageModelV3
embeddingModel(modelId: string): EmbeddingModelV3
imageModel(modelId: string): ImageModelV3
}
export function createAihubmix(options: AihubmixProviderSettings = {}): AihubmixProvider {
const { baseURL = 'https://aihubmix.com/v1', fetch: customFetch } = options
const resolveApiKey = () =>
loadApiKey({ apiKey: options.apiKey, environmentVariableName: 'AIHUBMIX_API_KEY', description: 'AiHubMix' })
const authHeaders = (): Record<string, string> => ({
Authorization: `Bearer ${resolveApiKey()}`,
'Content-Type': 'application/json',
...APP_CODE_HEADER,
...options.headers
})
const url = ({ path }: { path: string; modelId: string }) => `${withoutTrailingSlash(baseURL)}${path}`
const createAnthropicModel = (modelId: string) => {
const headers = authHeaders()
return new AnthropicMessagesLanguageModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.anthropic`,
baseURL,
headers: () => ({ ...headers, 'x-api-key': resolveApiKey() }),
fetch: customFetch,
supportedUrls: () => ({ 'image/*': [/^https?:\/\/.*$/] })
})
}
const createGeminiModel = (modelId: string) => {
const headers = authHeaders()
return new GoogleGenerativeAILanguageModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.google`,
baseURL: 'https://aihubmix.com/gemini/v1beta',
headers: () => ({ ...headers, 'x-goog-api-key': resolveApiKey() }),
fetch: customFetch,
generateId: () => `${AIHUBMIX_PROVIDER_NAME}-${Date.now()}`,
supportedUrls: () => ({})
})
}
const createOpenAICompatibleChatModel = (modelId: string): LanguageModelV3 =>
new OpenAICompatibleChatLanguageModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.openai-compatible-chat`,
url,
headers: authHeaders,
fetch: customFetch
})
const createOpenAIChatModel = (modelId: string): LanguageModelV3 =>
new OpenAIChatLanguageModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.openai-compatible-chat`,
url,
headers: authHeaders,
fetch: customFetch
})
const createResponsesModel = (modelId: string): LanguageModelV3 =>
new OpenAIResponsesLanguageModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.openai-response`,
url,
headers: authHeaders,
fetch: customFetch,
fileIdPrefixes: ['file-']
})
const createChatModel = (modelId: string): LanguageModelV3 => {
if (modelId.startsWith('claude')) {
return createAnthropicModel(modelId)
}
if (
(modelId.startsWith('gemini') || modelId.startsWith('imagen')) &&
!modelId.endsWith('no-think') &&
!modelId.endsWith('-search') &&
!modelId.includes('embedding')
) {
return createGeminiModel(modelId)
}
const model = { id: modelId } as Model
if (isOpenAILLMModel(model)) {
if (isOpenAIChatCompletionOnlyModel(model)) {
return createOpenAIChatModel(modelId)
}
return createResponsesModel(modelId)
}
return createOpenAICompatibleChatModel(modelId)
}
const provider = (modelId: string) => createChatModel(modelId)
provider.specificationVersion = 'v3' as const
provider.languageModel = createChatModel
provider.embeddingModel = (modelId: string) =>
new OpenAICompatibleEmbeddingModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.embedding`,
url,
headers: authHeaders,
fetch: customFetch
})
provider.imageModel = (modelId: string) =>
new OpenAICompatibleImageModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.image`,
url,
headers: authHeaders,
fetch: customFetch
})
provider.speechModel = (modelId: string) =>
new OpenAISpeechModel(modelId, {
provider: `${AIHUBMIX_PROVIDER_NAME}.speech`,
url,
headers: authHeaders,
fetch: customFetch
})
return provider as AihubmixProvider
}

View File

@@ -1,141 +0,0 @@
/**
* NewAPI Provider
*
* Multi-backend API gateway (One API / New API) that routes models by endpoint_type:
* - anthropic -> Anthropic SDK
* - gemini -> Google SDK
* - openai-response -> OpenAI Responses SDK
* - openai / image-generation -> OpenAI Chat SDK
* - fallback -> OpenAI Compatible SDK
*
* The endpointType is set per-request via provider settings, based on the model's endpoint_type field.
*/
import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal'
import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal'
import { OpenAIResponsesLanguageModel } from '@ai-sdk/openai/internal'
import {
OpenAICompatibleChatLanguageModel,
OpenAICompatibleEmbeddingModel,
OpenAICompatibleImageModel
} from '@ai-sdk/openai-compatible'
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { FetchFunction } from '@ai-sdk/provider-utils'
import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils'
export const NEWAPI_PROVIDER_NAME = 'newapi' as const
export type NewApiEndpointType =
| 'openai'
| 'openai-response'
| 'anthropic'
| 'gemini'
| 'image-generation'
| 'jina-rerank'
export interface NewApiProviderSettings {
apiKey?: string
baseURL?: string
headers?: Record<string, string>
fetch?: FetchFunction
endpointType?: NewApiEndpointType
}
export interface NewApiProvider extends ProviderV3 {
(modelId: string): LanguageModelV3
languageModel(modelId: string): LanguageModelV3
embeddingModel(modelId: string): EmbeddingModelV3
imageModel(modelId: string): ImageModelV3
}
export function createNewApi(options: NewApiProviderSettings = {}): NewApiProvider {
const { baseURL = '', fetch: customFetch, endpointType } = options
const resolveApiKey = () =>
loadApiKey({ apiKey: options.apiKey, environmentVariableName: 'NEWAPI_API_KEY', description: 'NewAPI' })
const authHeaders = (): Record<string, string> => ({
Authorization: `Bearer ${resolveApiKey()}`,
'Content-Type': 'application/json',
...options.headers
})
const url = ({ path }: { path: string; modelId: string }) => `${withoutTrailingSlash(baseURL)}${path}`
const createAnthropicModel = (modelId: string) => {
const headers = authHeaders()
return new AnthropicMessagesLanguageModel(modelId, {
provider: `${NEWAPI_PROVIDER_NAME}.anthropic`,
baseURL,
headers: () => ({ ...headers, 'x-api-key': resolveApiKey() }),
fetch: customFetch,
supportedUrls: () => ({ 'image/*': [/^https?:\/\/.*$/] })
})
}
const createGeminiModel = (modelId: string) => {
const headers = authHeaders()
return new GoogleGenerativeAILanguageModel(modelId, {
provider: `${NEWAPI_PROVIDER_NAME}.google`,
baseURL,
headers: () => ({ ...headers, 'x-goog-api-key': resolveApiKey() }),
fetch: customFetch,
generateId: () => `${NEWAPI_PROVIDER_NAME}-${Date.now()}`,
supportedUrls: () => ({})
})
}
const createResponsesModel = (modelId: string) =>
new OpenAIResponsesLanguageModel(modelId, {
provider: `${NEWAPI_PROVIDER_NAME}.openai-response`,
url,
headers: authHeaders,
fetch: customFetch
})
const createCompatibleModel = (modelId: string) =>
new OpenAICompatibleChatLanguageModel(modelId, {
provider: `${NEWAPI_PROVIDER_NAME}.chat`,
url,
headers: authHeaders,
fetch: customFetch
})
const createChatModel = (modelId: string): LanguageModelV3 => {
switch (endpointType) {
case 'anthropic':
return createAnthropicModel(modelId)
case 'gemini':
return createGeminiModel(modelId)
case 'openai-response':
return createResponsesModel(modelId)
case 'openai':
case 'image-generation':
return createCompatibleModel(modelId)
default:
return createCompatibleModel(modelId)
}
}
const provider = (modelId: string) => createChatModel(modelId)
provider.specificationVersion = 'v3' as const
provider.languageModel = createChatModel
provider.embeddingModel = (modelId: string) =>
new OpenAICompatibleEmbeddingModel(modelId, {
provider: `${NEWAPI_PROVIDER_NAME}.embedding`,
url,
headers: authHeaders,
fetch: customFetch
})
provider.imageModel = (modelId: string) =>
new OpenAICompatibleImageModel(modelId, {
provider: `${NEWAPI_PROVIDER_NAME}.image`,
url,
headers: authHeaders,
fetch: customFetch
})
return provider as NewApiProvider
}

View File

@@ -1,139 +0,0 @@
/**
* Type System Tests for Auto-Extracted Provider Types
*/
import type { AppProviderId } from '@renderer/aiCore/types'
import { describe, expect, expectTypeOf, it } from 'vitest'
import { extensions } from '../index'
describe('Auto-Extracted Type System', () => {
describe('Runtime and Type Consistency', () => {
it('运行时 IDs 应该自动提取到类型系统', () => {
// 从运行时获取所有 IDs包括主 ID 和别名)
const runtimeIds = extensions.flatMap((ext) => ext.getProviderIds())
// 🎯 Zero maintenance - 不再需要手动声明类型!
// 类型系统会自动从 projectExtensions 数组中提取所有 IDs
// 验证主要的 project provider IDs
const expectedMainIds: AppProviderId[] = [
'google-vertex',
'google-vertex-anthropic',
'github-copilot-openai-compatible',
'bedrock',
'perplexity',
'mistral',
'huggingface',
'gateway',
'cerebras',
'ollama'
]
// 验证别名
const expectedAliases: AppProviderId[] = [
'vertexai',
'vertexai-anthropic',
'copilot',
'github-copilot',
'aws-bedrock',
'hf',
'hugging-face',
'ai-gateway'
]
// 验证所有期望的 ID 都存在于运行时
;[...expectedMainIds, ...expectedAliases].forEach((id) => {
expect(runtimeIds).toContain(id)
})
// 验证数量一致
const uniqueRuntimeIds = [...new Set(runtimeIds)]
expect(uniqueRuntimeIds.length).toBeGreaterThanOrEqual(expectedMainIds.length + expectedAliases.length)
})
it('每个 extension 应该至少有一个 provider ID', () => {
extensions.forEach((ext) => {
const ids = ext.getProviderIds()
expect(ids.length).toBeGreaterThan(0)
})
})
})
describe('Type Inference - Auto-Extracted', () => {
// 🎯 Zero maintenance! These tests validate compile-time type inference
// 类型从 projectExtensions 数组自动提取,无需手动维护
it('应该接受核心 provider IDs', () => {
// 编译时类型检查 - AppProviderId 包含所有 core IDs
const coreIds: AppProviderId[] = [
'openai',
'anthropic',
'google',
'azure',
'deepseek',
'xai',
'openai-compatible',
'openrouter',
'cherryin'
]
// 运行时验证(确保类型存在)
expect(coreIds.length).toBeGreaterThan(0)
})
it('应该接受项目特定 provider IDs', () => {
// 编译时类型检查 - 自动从 projectExtensions 提取
const projectIds: AppProviderId[] = [
'google-vertex',
'google-vertex-anthropic',
'github-copilot-openai-compatible',
'bedrock',
'perplexity',
'mistral',
'huggingface',
'gateway',
'cerebras',
'ollama'
]
// 运行时验证
expect(projectIds.length).toBe(10)
})
it('应该接受项目特定 provider 别名', () => {
// 编译时类型检查 - 别名也自动提取
const aliases: AppProviderId[] = [
'vertexai',
'vertexai-anthropic',
'copilot',
'github-copilot',
'aws-bedrock',
'hf',
'hugging-face',
'ai-gateway'
]
// 运行时验证
expect(aliases.length).toBe(8)
})
it('AppProviderId 应该包含项目和核心的所有 IDs', () => {
// 编译时验证 - 统一类型系统测试
// ✅ 项目 IDs 应该在 AppProviderId 中
type Check1 = 'google-vertex' extends AppProviderId ? true : false
type Check2 = 'ollama' extends AppProviderId ? true : false
type Check3 = 'vertexai' extends AppProviderId ? true : false
// ✅ 核心 IDs 也应该在 AppProviderId 中(统一类型系统)
type Check4 = 'openai' extends AppProviderId ? true : false
type Check5 = 'anthropic' extends AppProviderId ? true : false
expectTypeOf<Check1>().toEqualTypeOf<true>()
expectTypeOf<Check2>().toEqualTypeOf<true>()
expectTypeOf<Check3>().toEqualTypeOf<true>()
expectTypeOf<Check4>().toEqualTypeOf<true>()
expectTypeOf<Check5>().toEqualTypeOf<true>()
})
})
})

View File

@@ -1,235 +0,0 @@
/**
* Cherry Studio 项目特定的 Provider Extensions
* 用于支持运行时动态导入的 AI Providers
*/
import type { AmazonBedrockProvider } from '@ai-sdk/amazon-bedrock'
import { type AmazonBedrockProviderSettings, createAmazonBedrock } from '@ai-sdk/amazon-bedrock'
import { type CerebrasProviderSettings, createCerebras } from '@ai-sdk/cerebras'
import { createGateway, type GatewayProviderSettings } from '@ai-sdk/gateway'
import { createVertexAnthropic, type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic/edge'
import { createVertex, type GoogleVertexProvider, type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
import { createGroq, type GroqProviderSettings } from '@ai-sdk/groq'
import { createHuggingFace, type HuggingFaceProviderSettings } from '@ai-sdk/huggingface'
import { createMistral, type MistralProviderSettings } from '@ai-sdk/mistral'
import { createPerplexity, type PerplexityProviderSettings } from '@ai-sdk/perplexity'
import type { ProviderV3 } from '@ai-sdk/provider'
import { createTogetherAI, type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
import { ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider'
import {
createGitHubCopilotOpenAICompatible,
type GitHubCopilotProviderSettings
} from '@opeoginni/github-copilot-openai-compatible'
import { SystemProviderIds } from '@types'
import type { OllamaProviderSettings } from 'ollama-ai-provider-v2'
import { createOllama } from 'ollama-ai-provider-v2'
import { createVoyage, type VoyageProviderSettings } from 'voyage-ai-provider'
import { type AihubmixProviderSettings, createAihubmix } from '../custom/aihubmix-provider'
import { createNewApi, type NewApiProviderSettings } from '../custom/newapi-provider'
/**
* Google Vertex AI Extension
*/
export const GoogleVertexExtension = ProviderExtension.create({
name: 'google-vertex',
aliases: ['vertexai'] as const,
supportsImageGeneration: true,
create: createVertex,
toolFactories: {
webSearch:
(provider: GoogleVertexProvider) =>
(config: NonNullable<Parameters<GoogleVertexProvider['tools']['googleSearch']>[0]>) => ({
tools: { webSearch: provider.tools.googleSearch(config) }
}),
urlContext:
(provider: GoogleVertexProvider) =>
(config: NonNullable<Parameters<GoogleVertexProvider['tools']['urlContext']>[0]>) => ({
tools: { urlContext: provider.tools.urlContext(config) }
})
}
} as const satisfies ProviderExtensionConfig<GoogleVertexProviderSettings, GoogleVertexProvider, 'google-vertex'>)
/**
* Google Vertex AI Anthropic Extension
*/
export const GoogleVertexAnthropicExtension = ProviderExtension.create({
name: 'google-vertex-anthropic',
aliases: ['vertexai-anthropic'] as const,
supportsImageGeneration: true,
create: createVertexAnthropic,
toolFactories: {
webSearch:
(provider: GoogleVertexAnthropicProvider) =>
(config: NonNullable<Parameters<GoogleVertexAnthropicProvider['tools']['webSearch_20250305']>[0]>) => ({
tools: { webSearch: provider.tools.webSearch_20250305(config) }
})
}
} as const satisfies ProviderExtensionConfig<
GoogleVertexProviderSettings,
GoogleVertexAnthropicProvider,
'google-vertex-anthropic'
>)
/**
* GitHub Copilot Extension
*/
export const GitHubCopilotExtension = ProviderExtension.create({
name: 'github-copilot-openai-compatible',
aliases: ['copilot', 'github-copilot'] as const,
supportsImageGeneration: false,
create: (options?: GitHubCopilotProviderSettings) =>
// GitHubCopilot并没有完整的实现ProviderV3
createGitHubCopilotOpenAICompatible(options) as unknown as ProviderV3
} as const satisfies ProviderExtensionConfig<
GitHubCopilotProviderSettings,
ProviderV3,
'github-copilot-openai-compatible'
>)
/**
* Amazon Bedrock Extension
*/
export const BedrockExtension = ProviderExtension.create({
name: 'bedrock',
aliases: ['aws-bedrock'] as const,
supportsImageGeneration: true,
create: createAmazonBedrock,
toolFactories: {
webSearch:
(provider: AmazonBedrockProvider) =>
(config: NonNullable<Parameters<AmazonBedrockProvider['tools']['webSearch_20260209']>[0]>) => ({
tools: { webSearch: provider.tools.webSearch_20260209(config) }
}),
urlContext:
(provider: AmazonBedrockProvider) =>
(config: NonNullable<Parameters<AmazonBedrockProvider['tools']['webFetch_20260209']>[0]>) => ({
tools: { urlContext: provider.tools.webFetch_20260209(config) }
})
}
} as const satisfies ProviderExtensionConfig<AmazonBedrockProviderSettings, AmazonBedrockProvider, 'bedrock'>)
/**
* Perplexity Extension
*/
export const PerplexityExtension = ProviderExtension.create({
name: 'perplexity',
supportsImageGeneration: false,
create: createPerplexity
} as const satisfies ProviderExtensionConfig<PerplexityProviderSettings, ProviderV3, 'perplexity'>)
/**
* Mistral Extension
*/
export const MistralExtension = ProviderExtension.create({
name: 'mistral',
supportsImageGeneration: false,
create: createMistral
} as const satisfies ProviderExtensionConfig<MistralProviderSettings, ProviderV3, 'mistral'>)
/**
* HuggingFace Extension
*/
export const HuggingFaceExtension = ProviderExtension.create({
name: 'huggingface',
aliases: ['hf', 'hugging-face'] as const,
supportsImageGeneration: true,
create: createHuggingFace
} as const satisfies ProviderExtensionConfig<HuggingFaceProviderSettings, ProviderV3, 'huggingface'>)
/**
* Vercel AI Gateway Extension
*/
export const GatewayExtension = ProviderExtension.create({
name: 'gateway',
aliases: ['ai-gateway'] as const,
supportsImageGeneration: true,
create: createGateway
} as const satisfies ProviderExtensionConfig<GatewayProviderSettings, ProviderV3, 'gateway'>)
/**
* Cerebras Extension
*/
export const CerebrasExtension = ProviderExtension.create({
name: 'cerebras',
supportsImageGeneration: false,
create: createCerebras
} as const satisfies ProviderExtensionConfig<CerebrasProviderSettings, ProviderV3, 'cerebras'>)
/**
* Groq Extension
*/
export const GroqExtension = ProviderExtension.create({
name: 'groq',
supportsImageGeneration: false,
create: createGroq
} as const satisfies ProviderExtensionConfig<GroqProviderSettings, ProviderV3, 'groq'>)
/**
* Ollama Extension
*/
export const OllamaExtension = ProviderExtension.create({
name: 'ollama',
supportsImageGeneration: false,
create: (options?: OllamaProviderSettings) => createOllama(options)
} as const satisfies ProviderExtensionConfig<OllamaProviderSettings, ProviderV3, 'ollama'>)
/**
* AiHubMix Extension - multi-backend gateway (claude->anthropic, gemini->google, gpt->openai-responses)
*/
export const AiHubMixExtension = ProviderExtension.create({
name: 'aihubmix',
supportsImageGeneration: true,
create: createAihubmix
} as const satisfies ProviderExtensionConfig<AihubmixProviderSettings, ProviderV3, 'aihubmix'>)
/**
* NewAPI Extension - multi-backend gateway routed by endpoint_type
*/
export const NewApiExtension = ProviderExtension.create({
name: 'newapi',
aliases: ['new-api'] as const,
supportsImageGeneration: true,
create: createNewApi
} as const satisfies ProviderExtensionConfig<NewApiProviderSettings, ProviderV3, 'newapi'>)
/**
* Together AI Extension - chat and image generation
*/
export const TogetherAIExtension = ProviderExtension.create({
name: 'togetherai',
aliases: [SystemProviderIds.together] as const,
supportsImageGeneration: true,
create: createTogetherAI
} as const satisfies ProviderExtensionConfig<TogetherAIProviderSettings, ProviderV3, 'togetherai'>)
/**
* Voyage AI Extension - embeddings and reranking
*/
export const VoyageExtension = ProviderExtension.create({
name: 'voyage',
aliases: [SystemProviderIds.voyageai] as const,
supportsImageGeneration: false,
create: createVoyage
} as const satisfies ProviderExtensionConfig<VoyageProviderSettings, ProviderV3, 'voyage'>)
/**
* 所有项目特定的 Extensions
*/
export const extensions = [
GoogleVertexExtension,
GoogleVertexAnthropicExtension,
GitHubCopilotExtension,
BedrockExtension,
PerplexityExtension,
MistralExtension,
HuggingFaceExtension,
GatewayExtension,
CerebrasExtension,
OllamaExtension,
AiHubMixExtension,
NewApiExtension,
VoyageExtension,
TogetherAIExtension,
GroqExtension
] as const

View File

@@ -1,54 +0,0 @@
import { extensionRegistry } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { type Provider, SystemProviderIds } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import { type AppProviderId, appProviderIds } from '../types'
import { extensions } from './extensions'
const logger = loggerService.withContext('ProviderFactory')
for (const extension of extensions) {
if (!extensionRegistry.has(extension.config.name)) {
extensionRegistry.register(extension)
}
}
/**
* 获取 AI SDK Provider ID
*
* 使用运行时类型安全的 appProviderIds 统一解析
* 特殊处理 Azure 端点检测和 OpenAI API 域名检测
*
* @param provider - Provider 配置对象
* @returns AI SDK 标准 provider ID
*/
export function getAiSdkProviderId(provider: Provider): AppProviderId {
// 1. 特殊处理Azure 的 responses 端点检测(必须在别名解析之前)
if (isAzureOpenAIProvider(provider)) {
return isAzureResponsesEndpoint(provider) ? appProviderIds['azure-responses'] : appProviderIds.azure
}
if (provider.id === SystemProviderIds.grok) {
return appProviderIds['xai-responses']
}
if (provider.id in appProviderIds) {
return appProviderIds[provider.id]
}
if (provider.type !== 'openai' && provider.type in appProviderIds) {
return appProviderIds[provider.type]
}
if (provider.apiHost.includes('api.openai.com')) {
return appProviderIds['openai-chat']
}
logger.warn('Provider ID not found in registered extensions, using as-is', {
providerId: provider.id,
providerType: provider.type,
registeredIds: Object.keys(appProviderIds)
})
return provider.id
}

View File

@@ -1,409 +0,0 @@
import { formatPrivateKey, hasProviderConfig, type StringKeys } from '@cherrystudio/ai-core/provider'
import type { AppProviderId, AppProviderSettingsMap } from '@renderer/aiCore/types'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockApiKey,
getAwsBedrockAuthType,
getAwsBedrockRegion,
getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { getProviderById } from '@renderer/services/ProviderService'
import store from '@renderer/store'
import { type Model, type Provider, SystemProviderIds } from '@renderer/types'
import {
formatApiHost,
formatOllamaApiHost,
formatVertexApiHost,
isWithTrailingSharp,
routeToEndpoint
} from '@renderer/utils/api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isOllamaProvider,
isPerplexityProvider,
isSupportStreamOptionsProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { defaultAppHeaders } from '@shared/utils'
import { cloneDeep, isEmpty } from 'lodash'
import type { ProviderConfig } from '../types'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
// === Types ===
interface BaseConfig {
baseURL: string
apiKey: string
}
interface BuilderContext {
actualProvider: Provider
model: Model
baseConfig: BaseConfig
endpoint?: string
aiSdkProviderId: AppProviderId
}
// === Host Formatting ===
type HostFormatter = {
match: (provider: Provider) => boolean
format: (provider: Provider, appendApiVersion: boolean) => string
}
// WARNING: if any changes are made here, please sync it to src/main/aiCore/provider/providerConfig.ts:formatProviderApiHost
export function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider }
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
}
// Anthropic is special: uses anthropicApiHost as source and syncs both fields
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
return formatted
}
const formatters: HostFormatter[] = [
{
match: (p) => p.id === SystemProviderIds.copilot || p.id === SystemProviderIds.github,
format: (p) => formatApiHost(p.apiHost, false)
},
{ match: isCherryAIProvider, format: (p) => formatApiHost(p.apiHost, false) },
{ match: isPerplexityProvider, format: (p) => formatApiHost(p.apiHost, false) },
{ match: isOllamaProvider, format: (p) => formatOllamaApiHost(p.apiHost) },
{ match: isGeminiProvider, format: (p, av) => formatApiHost(p.apiHost, av, 'v1beta') },
{ match: isAzureOpenAIProvider, format: (p) => formatApiHost(p.apiHost, false) },
{ match: isVertexProvider, format: (p) => formatVertexApiHost(p as Parameters<typeof formatVertexApiHost>[0]) }
]
const formatter = formatters.find((f) => f.match(provider))
formatted.apiHost = formatter
? formatter.format(formatted, appendApiVersion)
: formatApiHost(formatted.apiHost, appendApiVersion)
return formatted
}
// === SDK Config Building ===
type ConfigBuilderEntry = {
match: (provider: Provider, aiSdkProviderId: AppProviderId) => boolean
build: (ctx: BuilderContext) => ProviderConfig | Promise<ProviderConfig>
}
export function providerToAiSdkConfig(
actualProvider: Provider,
model: Model
): ProviderConfig | Promise<ProviderConfig> {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const ctx: BuilderContext = {
actualProvider,
model,
baseConfig: { baseURL, apiKey: actualProvider.apiKey },
endpoint,
aiSdkProviderId
}
const builders: ConfigBuilderEntry[] = [
{ match: (p) => p.id === SystemProviderIds.copilot, build: buildCopilotConfig },
{ match: (p) => p.id === 'cherryai', build: buildCherryAIConfig },
{ match: (p) => p.id === 'anthropic' && p.authType === 'oauth', build: buildAnthropicConfig },
{ match: (p) => isOllamaProvider(p), build: buildOllamaConfig },
{ match: (p) => isAzureOpenAIProvider(p), build: buildAzureConfig },
{ match: (_, id) => id === 'bedrock', build: buildBedrockConfig },
{ match: (_, id) => id === 'google-vertex', build: buildVertexConfig },
{ match: (_, id) => id === 'cherryin', build: buildCherryinConfig },
{ match: (_, id) => id === 'newapi', build: buildNewApiConfig },
{ match: (_, id) => id === 'aihubmix', build: buildAiHubMixConfig }
]
const builder = builders.find((b) => b.match(actualProvider, aiSdkProviderId))
if (builder) {
return builder.build(ctx)
}
// SDK-supported provider → generic config; otherwise → openai-compatible fallback
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
return buildGenericProviderConfig(ctx)
}
return buildOpenAICompatibleConfig(ctx)
}
// === Public API ===
export function getActualProvider(model: Model): Provider {
return adaptProvider({ provider: getProviderByModel(model), model })
}
export function adaptProvider({ provider }: { provider: Provider; model?: Model }): Provider {
return formatProviderApiHost(cloneDeep(provider))
}
export function isModernSdkSupported(provider: Provider): boolean {
return hasProviderConfig(getAiSdkProviderId(provider))
}
// === Config Builders ===
function buildCommonOptions(ctx: BuilderContext) {
const options: Record<string, any> = {
headers: {
...defaultAppHeaders(),
...ctx.actualProvider.extra_headers
}
}
if (ctx.aiSdkProviderId === 'openai') {
options.headers['X-Api-Key'] = ctx.baseConfig.apiKey
}
return options
}
async function buildCopilotConfig(ctx: BuilderContext): Promise<ProviderConfig<'github-copilot-openai-compatible'>> {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const headers = { ...COPILOT_DEFAULT_HEADERS, ...storedHeaders }
const { token } = await window.api.copilot.getToken(headers)
return {
providerId: 'github-copilot-openai-compatible',
endpoint: ctx.endpoint,
providerSettings: {
...ctx.baseConfig,
apiKey: token,
headers: { ...headers, ...ctx.actualProvider.extra_headers },
name: ctx.actualProvider.id
}
}
}
function buildOllamaConfig(ctx: BuilderContext): ProviderConfig<'ollama'> {
const headers: ProviderConfig<'ollama'>['providerSettings']['headers'] = {
...defaultAppHeaders(),
...ctx.actualProvider.extra_headers
}
if (!isEmpty(ctx.baseConfig.apiKey)) {
headers.Authorization = `Bearer ${ctx.baseConfig.apiKey}`
}
return {
providerId: 'ollama',
endpoint: ctx.endpoint,
providerSettings: { ...ctx.baseConfig, headers }
}
}
function buildBedrockConfig(ctx: BuilderContext): ProviderConfig<'bedrock'> {
const authType = getAwsBedrockAuthType()
const region = getAwsBedrockRegion()
const base = { providerId: 'bedrock' as const, endpoint: ctx.endpoint }
if (authType === 'apiKey') {
return { ...base, providerSettings: { ...ctx.baseConfig, region, apiKey: getAwsBedrockApiKey() } }
}
return {
...base,
providerSettings: {
...ctx.baseConfig,
region,
accessKeyId: getAwsBedrockAccessKeyId(),
secretAccessKey: getAwsBedrockSecretAccessKey()
}
}
}
function buildVertexConfig(
ctx: BuilderContext
): ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'> {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(ctx.actualProvider)
// Vertex 上的 Claude 模型走 google-vertex-anthropic variant
const isAnthropic = ctx.aiSdkProviderId === 'google-vertex-anthropic' || ctx.model.id.startsWith('claude')
const baseURL = ctx.baseConfig.baseURL + (isAnthropic ? '/publishers/anthropic/models' : '/publishers/google')
const creds = { ...googleCredentials, privateKey: formatPrivateKey(googleCredentials.privateKey) }
return {
providerId: isAnthropic ? 'google-vertex-anthropic' : 'google-vertex',
endpoint: ctx.endpoint,
providerSettings: { ...ctx.baseConfig, baseURL, project, location, googleCredentials: creds }
} as ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'>
}
function buildCherryinConfig(ctx: BuilderContext): ProviderConfig<'cherryin'> {
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
return {
providerId: 'cherryin',
endpoint: ctx.endpoint,
providerSettings: {
...ctx.baseConfig,
endpointType: ctx.model.endpoint_type,
anthropicBaseURL: cherryinProvider ? cherryinProvider.anthropicApiHost + '/v1' : undefined,
geminiBaseURL: cherryinProvider ? cherryinProvider.apiHost + '/v1beta' : undefined,
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
}
}
}
async function buildCherryAIConfig(ctx: BuilderContext): Promise<ProviderConfig<'openai-compatible'>> {
return {
providerId: 'openai-compatible',
endpoint: ctx.endpoint,
providerSettings: {
...ctx.baseConfig,
name: ctx.actualProvider.id,
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers },
fetch: async (input: RequestInfo | URL, init?: RequestInit) => {
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: init?.body && typeof init.body === 'string' ? JSON.parse(init.body) : undefined
})
return fetch(input, { ...init, headers: { ...init?.headers, ...signature } })
}
}
}
}
function formatAzureBaseURL(baseURL: string, forAnthropic: boolean): string {
// Normalize: strip trailing /v1 and /openai that user may have included
const normalized = baseURL.replace(/\/v1$/, '').replace(/\/openai$/, '')
// Azure OpenAI endpoints need /openai suffix; Azure Anthropic does not
return forAnthropic ? normalized : normalized + '/openai'
}
function buildAzureConfig(
ctx: BuilderContext
): ProviderConfig<'azure'> | ProviderConfig<'azure-responses'> | ProviderConfig<'azure-anthropic'> {
// Azure 上的 Claude 模型走 azure-anthropic variant内部使用 Anthropic SDK
if (ctx.model.id.startsWith('claude')) {
return {
providerId: 'azure-anthropic',
endpoint: ctx.endpoint,
providerSettings: {
...ctx.baseConfig,
baseURL: formatAzureBaseURL(ctx.baseConfig.baseURL, true),
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
}
}
}
const apiVersion = ctx.actualProvider.apiVersion?.trim()
const useResponsesMode = apiVersion && ['preview', 'v1'].includes(apiVersion)
const providerSettings: ProviderConfig<'azure'>['providerSettings'] = {
...ctx.baseConfig,
baseURL: formatAzureBaseURL(ctx.baseConfig.baseURL, false),
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
}
if (apiVersion) {
providerSettings.apiVersion = apiVersion
if (!useResponsesMode) {
providerSettings.useDeploymentBasedUrls = true
}
}
return {
providerId: useResponsesMode ? 'azure-responses' : 'azure',
endpoint: ctx.endpoint,
providerSettings
} as ProviderConfig<'azure'> | ProviderConfig<'azure-responses'>
}
async function buildAnthropicConfig(ctx: BuilderContext): Promise<ProviderConfig<'anthropic'>> {
const oauthToken: string = await window.api.anthropic_oauth.getAccessToken()
return {
providerId: 'anthropic',
endpoint: ctx.endpoint,
providerSettings: {
baseURL: 'https://api.anthropic.com/v1',
apiKey: '',
headers: {
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
Authorization: `Bearer ${oauthToken}`
}
}
}
}
function buildOpenAICompatibleConfig(ctx: BuilderContext): ProviderConfig<'openai-compatible'> {
const commonOptions = buildCommonOptions(ctx)
const includeUsage = isSupportStreamOptionsProvider(ctx.actualProvider)
? store.getState().settings.openAI?.streamOptions?.includeUsage
: undefined
return {
providerId: 'openai-compatible',
endpoint: ctx.endpoint,
providerSettings: { ...ctx.baseConfig, ...commonOptions, name: ctx.actualProvider.id, includeUsage }
}
}
function buildGenericProviderConfig(ctx: BuilderContext): ProviderConfig {
const commonOptions = buildCommonOptions(ctx)
return {
providerId: ctx.aiSdkProviderId as StringKeys<AppProviderSettingsMap>,
endpoint: ctx.endpoint,
providerSettings: { ...ctx.baseConfig, ...commonOptions }
}
}
function buildAiHubMixConfig(ctx: BuilderContext): ProviderConfig<'aihubmix'> {
return {
providerId: 'aihubmix',
endpoint: ctx.endpoint,
providerSettings: {
...ctx.baseConfig,
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
}
}
}
function formatNewApiBaseURL(baseURL: string, endpointType?: string): string {
switch (endpointType) {
case 'gemini':
return formatApiHost(baseURL, true, 'v1beta')
case 'anthropic':
return formatApiHost(baseURL, false)
default:
return formatApiHost(baseURL, true)
}
}
function buildNewApiConfig(ctx: BuilderContext): ProviderConfig<'newapi'> {
const baseURL = formatNewApiBaseURL(ctx.baseConfig.baseURL, ctx.model.endpoint_type)
return {
providerId: 'newapi',
endpoint: ctx.endpoint,
providerSettings: {
...ctx.baseConfig,
baseURL,
endpointType: ctx.model.endpoint_type,
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
}
}
}

View File

@@ -1,333 +0,0 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`listModels > AIHubMix > should convert real AIHubMix response with model_id and model_name 1`] = `
[
{
"description": "Qwen 3.6, the native vision-language Plus series model.",
"group": "aihubmix",
"id": "qwen3.6-plus",
"name": "Qwen3.6 Plus",
"provider": "aihubmix",
},
{
"description": "Claude Sonnet 4.6 delivers frontier intelligence at scale.",
"group": "aihubmix",
"id": "claude-sonnet-4-6",
"name": "Claude Sonnet 4.6",
"provider": "aihubmix",
},
{
"description": "GPT-5.4 is our frontier model for complex professional work.",
"group": "aihubmix",
"id": "gpt-5.4",
"name": "GPT 5.4",
"provider": "aihubmix",
},
{
"description": "A new-generation professional-grade multimodal video-creation model.",
"group": "aihubmix",
"id": "doubao-seedance-2-0-260128",
"name": "Doubao Seedance 2.0 260128",
"provider": "aihubmix",
},
]
`;
exports[`listModels > Gemini > should strip models/ prefix and use displayName from real response 1`] = `
[
{
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"group": "gemini",
"id": "gemini-2.5-flash",
"name": "Gemini 2.5 Flash",
"provider": "gemini",
},
{
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"group": "gemini",
"id": "gemini-2.5-pro",
"name": "Gemini 2.5 Pro",
"provider": "gemini",
},
{
"description": "Gemini 2.0 Flash",
"group": "gemini",
"id": "gemini-2.0-flash",
"name": "Gemini 2.0 Flash",
"provider": "gemini",
},
{
"description": "Stable version of Gemini 2.0 Flash, our fast and versatile multimodal model for scaling across diverse tasks, released in January of 2025.",
"group": "gemini",
"id": "gemini-2.0-flash-001",
"name": "Gemini 2.0 Flash 001",
"provider": "gemini",
},
{
"description": "Stable version of Gemini 2.0 Flash-Lite",
"group": "gemini",
"id": "gemini-2.0-flash-lite-001",
"name": "Gemini 2.0 Flash-Lite 001",
"provider": "gemini",
},
{
"description": "Gemini 2.0 Flash-Lite",
"group": "gemini",
"id": "gemini-2.0-flash-lite",
"name": "Gemini 2.0 Flash-Lite",
"provider": "gemini",
},
]
`;
exports[`listModels > OpenAI-compatible (DeepSeek) > should convert real DeepSeek response 1`] = `
[
{
"group": "deepseek",
"id": "deepseek-chat",
"name": "deepseek-chat",
"owned_by": "deepseek",
"provider": "deepseek",
},
{
"group": "deepseek",
"id": "deepseek-reasoner",
"name": "deepseek-reasoner",
"owned_by": "deepseek",
"provider": "deepseek",
},
]
`;
exports[`listModels > OpenAI-compatible (Groq) > should convert real Groq response with owned_by 1`] = `
[
{
"group": "qwen",
"id": "qwen/qwen3-32b",
"name": "qwen/qwen3-32b",
"owned_by": "Alibaba Cloud",
"provider": "groq",
},
{
"group": "groq",
"id": "groq/compound-mini",
"name": "groq/compound-mini",
"owned_by": "Groq",
"provider": "groq",
},
]
`;
exports[`listModels > OpenAI-compatible (SiliconFlow) > should handle nested slash IDs for group extraction 1`] = `
[
{
"group": "Pro",
"id": "Pro/MiniMaxAI/MiniMax-M2.5",
"name": "Pro/MiniMaxAI/MiniMax-M2.5",
"owned_by": "",
"provider": "silicon",
},
{
"group": "Pro",
"id": "Pro/zai-org/GLM-5",
"name": "Pro/zai-org/GLM-5",
"owned_by": "",
"provider": "silicon",
},
{
"group": "Pro",
"id": "Pro/moonshotai/Kimi-K2.5",
"name": "Pro/moonshotai/Kimi-K2.5",
"owned_by": "",
"provider": "silicon",
},
{
"group": "Pro",
"id": "Pro/zai-org/GLM-4.7",
"name": "Pro/zai-org/GLM-4.7",
"owned_by": "",
"provider": "silicon",
},
{
"group": "deepseek-ai",
"id": "deepseek-ai/DeepSeek-V3.2",
"name": "deepseek-ai/DeepSeek-V3.2",
"owned_by": "",
"provider": "silicon",
},
{
"group": "Pro",
"id": "Pro/deepseek-ai/DeepSeek-V3.2",
"name": "Pro/deepseek-ai/DeepSeek-V3.2",
"owned_by": "",
"provider": "silicon",
},
]
`;
exports[`listModels > OpenRouter > should merge chat and embedding endpoints from real response 1`] = `
[
{
"group": "xiaomi",
"id": "xiaomi/mimo-v2-omni",
"name": "xiaomi/mimo-v2-omni",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "xiaomi",
"id": "xiaomi/mimo-v2-pro",
"name": "xiaomi/mimo-v2-pro",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "minimax",
"id": "minimax/minimax-m2.7",
"name": "minimax/minimax-m2.7",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "openai",
"id": "openai/gpt-5.4-nano",
"name": "openai/gpt-5.4-nano",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "openai",
"id": "openai/gpt-5.4-mini",
"name": "openai/gpt-5.4-mini",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "mistralai",
"id": "mistralai/mistral-small-2603",
"name": "mistralai/mistral-small-2603",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "z-ai",
"id": "z-ai/glm-5-turbo",
"name": "z-ai/glm-5-turbo",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "x-ai",
"id": "x-ai/grok-4.20-multi-agent-beta",
"name": "x-ai/grok-4.20-multi-agent-beta",
"owned_by": null,
"provider": "openrouter",
},
{
"group": "openai",
"id": "openai/text-embedding-3-large",
"name": "openai/text-embedding-3-large",
"owned_by": undefined,
"provider": "openrouter",
},
]
`;
exports[`listModels > PPIO > should merge all three endpoints from real response 1`] = `
[
{
"group": "minimax",
"id": "minimax/minimax-m2.7",
"name": "minimax/minimax-m2.7",
"owned_by": "unknown",
"provider": "ppio",
},
{
"group": "minimax",
"id": "minimax/minimax-m2.5-highspeed",
"name": "minimax/minimax-m2.5-highspeed",
"owned_by": "unknown",
"provider": "ppio",
},
{
"group": "qwen",
"id": "qwen/qwen3.5-27b",
"name": "qwen/qwen3.5-27b",
"owned_by": "unknown",
"provider": "ppio",
},
{
"group": "qwen",
"id": "qwen/qwen3.5-122b-a10b",
"name": "qwen/qwen3.5-122b-a10b",
"owned_by": "unknown",
"provider": "ppio",
},
{
"group": "qwen",
"id": "qwen/qwen3.5-35b-a3b",
"name": "qwen/qwen3.5-35b-a3b",
"owned_by": "unknown",
"provider": "ppio",
},
{
"group": "BAAI",
"id": "BAAI/bge-m3",
"name": "BAAI/bge-m3",
"owned_by": "BAAI",
"provider": "ppio",
},
{
"group": "BAAI",
"id": "BAAI/bge-reranker-v2-m3",
"name": "BAAI/bge-reranker-v2-m3",
"owned_by": "BAAI",
"provider": "ppio",
},
]
`;
exports[`listModels > Together > should use display_name and organization from real response 1`] = `
[
{
"description": null,
"group": "hexgrad",
"id": "hexgrad/Kokoro-82M",
"name": "Kokoro 82M",
"owned_by": "Hexgrad",
"provider": "together",
},
{
"description": null,
"group": "cartesia",
"id": "cartesia/sonic",
"name": "Cartesia Sonic",
"owned_by": "Cartesia",
"provider": "together",
},
{
"description": null,
"group": "black-forest-labs",
"id": "black-forest-labs/FLUX.1-krea-dev",
"name": "FLUX.1 Krea [dev]",
"owned_by": "Black Forest Labs",
"provider": "together",
},
{
"description": null,
"group": "google",
"id": "google/imagen-4.0-preview",
"name": "Google Imagen 4.0 Preview",
"owned_by": "Google",
"provider": "together",
},
{
"description": null,
"group": "cartesia",
"id": "cartesia/sonic-2",
"name": "Cartesia Sonic 2",
"owned_by": "Cartesia",
"provider": "together",
},
]
`;

View File

@@ -1,411 +0,0 @@
/**
* ModelListService conversion tests
* Uses real API responses captured from providers to verify model conversion
*/
import type { Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const mockGetFromApi = vi.fn()
vi.mock('@ai-sdk/provider-utils', () => ({
createJsonResponseHandler: vi.fn(() => 'json-handler'),
createJsonErrorResponseHandler: vi.fn(() => 'error-handler'),
getFromApi: (...args: unknown[]) => mockGetFromApi(...args),
zodSchema: vi.fn((s: unknown) => s)
}))
vi.mock('@renderer/utils', () => ({
formatApiHost: (host: string) => host?.replace(/\/$/, ''),
withoutTrailingSlash: (s: string) => s?.replace(/\/$/, ''),
getLowerBaseModelName: (id: string) => id.toLowerCase()
}))
vi.mock('@renderer/utils/provider', () => ({
isAIGatewayProvider: (p: Provider) => p.id === 'gateway',
isGeminiProvider: (p: Provider) => p.id === 'gemini' || p.type === 'gemini',
isOllamaProvider: (p: Provider) => p.id === 'ollama' || p.type === 'ollama'
}))
vi.mock('@shared/utils', () => ({
defaultAppHeaders: () => ({ 'X-App': 'CherryStudio' })
}))
const { listModels } = await import('../listModels')
// === Real API response fixtures (captured 2026-03-19) ===
// From https://openrouter.ai/api/v1/models (public, no auth)
const REAL_OPENROUTER = {
data: [
{ id: 'xiaomi/mimo-v2-omni', object: 'model', created: 1773863703, owned_by: null },
{ id: 'xiaomi/mimo-v2-pro', object: 'model', created: 1773863643, owned_by: null },
{ id: 'minimax/minimax-m2.7', object: 'model', created: 1773836697, owned_by: null },
{ id: 'openai/gpt-5.4-nano', object: 'model', created: 1773748187, owned_by: null },
{ id: 'openai/gpt-5.4-mini', object: 'model', created: 1773748178, owned_by: null },
{ id: 'mistralai/mistral-small-2603', object: 'model', created: 1773695685, owned_by: null },
{ id: 'z-ai/glm-5-turbo', object: 'model', created: 1773583573, owned_by: null },
{ id: 'x-ai/grok-4.20-multi-agent-beta', object: 'model', created: 1773325367, owned_by: null }
]
}
// From https://api.deepseek.com/v1/models
const REAL_DEEPSEEK = {
object: 'list',
data: [
{ id: 'deepseek-chat', object: 'model', owned_by: 'deepseek' },
{ id: 'deepseek-reasoner', object: 'model', owned_by: 'deepseek' }
]
}
// From https://generativelanguage.googleapis.com/v1beta/models
const REAL_GEMINI = {
models: [
{
name: 'models/gemini-2.5-flash',
displayName: 'Gemini 2.5 Flash',
description:
'Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.'
},
{
name: 'models/gemini-2.5-pro',
displayName: 'Gemini 2.5 Pro',
description: 'Stable release (June 17th, 2025) of Gemini 2.5 Pro'
},
{
name: 'models/gemini-2.0-flash',
displayName: 'Gemini 2.0 Flash',
description: 'Gemini 2.0 Flash'
},
{
name: 'models/gemini-2.0-flash-001',
displayName: 'Gemini 2.0 Flash 001',
description:
'Stable version of Gemini 2.0 Flash, our fast and versatile multimodal model for scaling across diverse tasks, released in January of 2025.'
},
{
name: 'models/gemini-2.0-flash-lite-001',
displayName: 'Gemini 2.0 Flash-Lite 001',
description: 'Stable version of Gemini 2.0 Flash-Lite'
},
{
name: 'models/gemini-2.0-flash-lite',
displayName: 'Gemini 2.0 Flash-Lite',
description: 'Gemini 2.0 Flash-Lite'
}
]
}
// From https://api.together.xyz/v1/models
const REAL_TOGETHER = [
{ id: 'hexgrad/Kokoro-82M', display_name: 'Kokoro 82M', organization: 'Hexgrad', description: null },
{ id: 'cartesia/sonic', display_name: 'Cartesia Sonic', organization: 'Cartesia', description: null },
{
id: 'black-forest-labs/FLUX.1-krea-dev',
display_name: 'FLUX.1 Krea [dev]',
organization: 'Black Forest Labs',
description: null
},
{
id: 'google/imagen-4.0-preview',
display_name: 'Google Imagen 4.0 Preview',
organization: 'Google',
description: null
},
{ id: 'cartesia/sonic-2', display_name: 'Cartesia Sonic 2', organization: 'Cartesia', description: null }
]
// From https://api.siliconflow.cn/v1/models (OpenAI-compatible)
const REAL_SILICONFLOW = {
data: [
{ id: 'Pro/MiniMaxAI/MiniMax-M2.5', object: 'model', owned_by: '' },
{ id: 'Pro/zai-org/GLM-5', object: 'model', owned_by: '' },
{ id: 'Pro/moonshotai/Kimi-K2.5', object: 'model', owned_by: '' },
{ id: 'Pro/zai-org/GLM-4.7', object: 'model', owned_by: '' },
{ id: 'deepseek-ai/DeepSeek-V3.2', object: 'model', owned_by: '' },
{ id: 'Pro/deepseek-ai/DeepSeek-V3.2', object: 'model', owned_by: '' }
]
}
// From https://api.groq.com/openai/v1/models
const REAL_GROQ = {
data: [
{ id: 'qwen/qwen3-32b', object: 'model', created: 1748396646, owned_by: 'Alibaba Cloud' },
{ id: 'groq/compound-mini', object: 'model', created: 1756949707, owned_by: 'Groq' }
]
}
// From https://api.ppinfra.com/v3/openai/models
const REAL_PPIO_CHAT = {
data: [
{ id: 'minimax/minimax-m2.7', object: 'model', owned_by: 'unknown' },
{ id: 'minimax/minimax-m2.5-highspeed', object: 'model', owned_by: 'unknown' },
{ id: 'qwen/qwen3.5-27b', object: 'model', owned_by: 'unknown' },
{ id: 'qwen/qwen3.5-122b-a10b', object: 'model', owned_by: 'unknown' },
{ id: 'qwen/qwen3.5-35b-a3b', object: 'model', owned_by: 'unknown' }
]
}
// From https://aihubmix.com/api/v1/models (custom schema with model_id/model_name)
const REAL_AIHUBMIX = {
data: [
{
model_id: 'qwen3.6-plus',
model_name: 'Qwen3.6 Plus',
developer_id: 13,
desc: 'Qwen 3.6, the native vision-language Plus series model.',
pricing: { cache_read: 0.0282, cache_write: 0.3525, input: 0.282, output: 1.692 },
types: 'llm',
features: 'tools,function_calling,structured_outputs,web,long_context,thinking',
input_modalities: 'text,image,video',
endpoints: '',
max_output: 64000,
context_length: 991000
},
{
model_id: 'claude-sonnet-4-6',
model_name: 'Claude Sonnet 4.6',
developer_id: 2,
desc: 'Claude Sonnet 4.6 delivers frontier intelligence at scale.',
pricing: { cache_read: 0.3, cache_write: 3.75, input: 3, output: 15 },
types: 'llm',
features: 'thinking,tools,function_calling,structured_outputs',
input_modalities: 'text,image',
endpoints: 'chat_completions,gemini_api,claude_api',
max_output: 64000,
context_length: 1000000
},
{
model_id: 'gpt-5.4',
model_name: 'GPT 5.4',
developer_id: 12,
desc: 'GPT-5.4 is our frontier model for complex professional work.',
pricing: { cache_read: 0.25, input: 2.5, output: 15 },
types: 'llm',
features: 'thinking,function_calling,structured_outputs,web,tools',
input_modalities: 'text,image',
endpoints: '',
max_output: 128000,
context_length: 400000
},
{
model_id: 'doubao-seedance-2-0-260128',
model_name: 'Doubao Seedance 2.0 260128',
developer_id: 4,
desc: 'A new-generation professional-grade multimodal video-creation model.',
pricing: { input: 2, output: 0 },
types: 'video',
features: '',
input_modalities: 'image,text',
endpoints: '',
max_output: 0,
context_length: 0
}
],
message: '',
success: true
}
// === Helpers ===
function makeProvider(overrides: Partial<Provider> & { id: string }): Provider {
return {
name: overrides.id,
type: 'openai',
apiKey: 'sk-test',
apiHost: 'https://api.example.com/v1',
models: [],
isSystem: true,
enabled: true,
...overrides
} as Provider
}
function assertValidModels(models: { id: string; name: string; provider: string; group: string }[]) {
expect(models.length).toBeGreaterThan(0)
for (const m of models) {
expect(m.id).toBeTruthy()
expect(typeof m.id).toBe('string')
expect(m.id).toBe(m.id.trim())
expect(m.name).toBeTruthy()
expect(typeof m.provider).toBe('string')
expect(typeof m.group).toBe('string')
}
}
// === Tests ===
beforeEach(() => {
mockGetFromApi.mockReset()
vi.stubGlobal('window', { ...globalThis.window, keyv: { get: vi.fn(), set: vi.fn() } })
})
describe('listModels', () => {
describe('OpenAI-compatible (DeepSeek)', () => {
it('should convert real DeepSeek response', async () => {
mockGetFromApi.mockResolvedValue({ value: REAL_DEEPSEEK })
const models = await listModels(makeProvider({ id: 'deepseek' }))
assertValidModels(models)
expect(models).toMatchSnapshot()
})
})
describe('OpenAI-compatible (SiliconFlow)', () => {
it('should handle nested slash IDs for group extraction', async () => {
mockGetFromApi.mockResolvedValue({ value: REAL_SILICONFLOW })
const models = await listModels(makeProvider({ id: 'silicon' }))
assertValidModels(models)
// "Pro/MiniMaxAI/MiniMax-M2.5" -> group "Pro"
expect(models[0].group).toBe('Pro')
// "deepseek-ai/DeepSeek-V3.2" -> group "deepseek-ai"
expect(models[4].group).toBe('deepseek-ai')
expect(models).toMatchSnapshot()
})
})
describe('OpenAI-compatible (Groq)', () => {
it('should convert real Groq response with owned_by', async () => {
mockGetFromApi.mockResolvedValue({ value: REAL_GROQ })
const models = await listModels(makeProvider({ id: 'groq' }))
assertValidModels(models)
expect(models[0].owned_by).toBe('Alibaba Cloud')
expect(models[1].owned_by).toBe('Groq')
expect(models).toMatchSnapshot()
})
})
describe('Gemini', () => {
it('should strip models/ prefix and use displayName from real response', async () => {
mockGetFromApi.mockResolvedValue({ value: REAL_GEMINI })
const models = await listModels(
makeProvider({ id: 'gemini', type: 'gemini', apiHost: 'https://generativelanguage.googleapis.com/v1beta' })
)
assertValidModels(models)
for (const m of models) {
expect(m.id).not.toMatch(/^models\//)
}
// displayName should be used as name
expect(models[0].name).toBe('Gemini 2.5 Flash')
expect(models[0].id).toBe('gemini-2.5-flash')
expect(models).toMatchSnapshot()
})
})
describe('Together', () => {
it('should use display_name and organization from real response', async () => {
mockGetFromApi.mockResolvedValue({ value: REAL_TOGETHER })
const models = await listModels(makeProvider({ id: 'together' }))
assertValidModels(models)
expect(models[0].name).toBe('Kokoro 82M')
expect(models[0].owned_by).toBe('Hexgrad')
expect(models[0].group).toBe('hexgrad')
// FLUX model with org "Black Forest Labs"
expect(models[2].name).toBe('FLUX.1 Krea [dev]')
expect(models[2].owned_by).toBe('Black Forest Labs')
expect(models).toMatchSnapshot()
})
})
describe('OpenRouter', () => {
it('should merge chat and embedding endpoints from real response', async () => {
mockGetFromApi
.mockResolvedValueOnce({ value: REAL_OPENROUTER })
.mockResolvedValueOnce({ value: { data: [{ id: 'openai/text-embedding-3-large', object: 'model' }] } })
const models = await listModels(makeProvider({ id: 'openrouter' }))
assertValidModels(models)
expect(models).toHaveLength(REAL_OPENROUTER.data.length + 1)
// Slash IDs should produce correct group
expect(models.find((m) => m.id === 'xiaomi/mimo-v2-omni')?.group).toBe('xiaomi')
expect(models.find((m) => m.id === 'openai/gpt-5.4-nano')?.group).toBe('openai')
expect(models.find((m) => m.id === 'x-ai/grok-4.20-multi-agent-beta')?.group).toBe('x-ai')
expect(models).toMatchSnapshot()
})
it('should deduplicate across endpoints', async () => {
mockGetFromApi
.mockResolvedValueOnce({ value: { data: [REAL_OPENROUTER.data[0]] } })
.mockResolvedValueOnce({ value: { data: [REAL_OPENROUTER.data[0]] } })
const models = await listModels(makeProvider({ id: 'openrouter' }))
expect(models).toHaveLength(1)
})
it('should handle embedding endpoint failure', async () => {
mockGetFromApi.mockResolvedValueOnce({ value: REAL_OPENROUTER }).mockRejectedValueOnce(new Error('404 Not Found'))
const models = await listModels(makeProvider({ id: 'openrouter' }))
expect(models).toHaveLength(REAL_OPENROUTER.data.length)
})
})
describe('PPIO', () => {
it('should merge all three endpoints from real response', async () => {
mockGetFromApi
.mockResolvedValueOnce({ value: REAL_PPIO_CHAT })
.mockResolvedValueOnce({ value: { data: [{ id: 'BAAI/bge-m3', object: 'model', owned_by: 'BAAI' }] } })
.mockResolvedValueOnce({
value: { data: [{ id: 'BAAI/bge-reranker-v2-m3', object: 'model', owned_by: 'BAAI' }] }
})
const models = await listModels(makeProvider({ id: 'ppio' }))
assertValidModels(models)
expect(models).toHaveLength(7)
expect(models.find((m) => m.id === 'BAAI/bge-m3')?.group).toBe('BAAI')
expect(models).toMatchSnapshot()
})
it('should handle partial endpoint failures', async () => {
mockGetFromApi
.mockResolvedValueOnce({ value: REAL_PPIO_CHAT })
.mockRejectedValueOnce(new Error('timeout'))
.mockRejectedValueOnce(new Error('timeout'))
const models = await listModels(makeProvider({ id: 'ppio' }))
expect(models).toHaveLength(REAL_PPIO_CHAT.data.length)
})
})
describe('AIHubMix', () => {
it('should convert real AIHubMix response with model_id and model_name', async () => {
mockGetFromApi.mockResolvedValue({ value: REAL_AIHUBMIX })
const models = await listModels(makeProvider({ id: 'aihubmix' }))
assertValidModels(models)
expect(models).toHaveLength(4)
// model_name should be used as name
expect(models[0].name).toBe('Qwen3.6 Plus')
expect(models[0].id).toBe('qwen3.6-plus')
expect(models[0].description).toBe('Qwen 3.6, the native vision-language Plus series model.')
// No slash in ID -> group falls back to provider id
expect(models[0].group).toBe('aihubmix')
expect(models[1].name).toBe('Claude Sonnet 4.6')
expect(models[2].name).toBe('GPT 5.4')
expect(models[3].name).toBe('Doubao Seedance 2.0 260128')
expect(models).toMatchSnapshot()
})
it('should deduplicate by model_id', async () => {
const duped = {
...REAL_AIHUBMIX,
data: [REAL_AIHUBMIX.data[0], REAL_AIHUBMIX.data[0], REAL_AIHUBMIX.data[1]]
}
mockGetFromApi.mockResolvedValue({ value: duped })
const models = await listModels(makeProvider({ id: 'aihubmix' }))
expect(models).toHaveLength(2)
})
})
describe('Unsupported providers', () => {
it.each([
['gateway', { id: 'gateway' }],
['aws-bedrock', { id: 'aws-bedrock' }],
['anthropic', { id: 'anthropic' }],
['vertex-anthropic', { id: 'vertex-anthro', type: 'vertex-anthropic' as any }]
])('should return empty for %s', async (_, overrides) => {
const models = await listModels(makeProvider(overrides as any))
expect(models).toEqual([])
expect(mockGetFromApi).not.toHaveBeenCalled()
})
})
describe('Error handling', () => {
it('should return empty on network error', async () => {
mockGetFromApi.mockRejectedValue(new Error('ECONNREFUSED'))
const models = await listModels(makeProvider({ id: 'openai' }))
expect(models).toEqual([])
})
})
})

View File

@@ -1 +0,0 @@
export { listModels } from './listModels'

View File

@@ -1,393 +0,0 @@
/**
* ModelListService - Unified model listing service
* Uses Strategy Registry pattern for provider-specific model fetching
*/
import {
createJsonErrorResponseHandler,
createJsonResponseHandler,
getFromApi as aiSdkGetFromApi,
zodSchema
} from '@ai-sdk/provider-utils'
import { cacheService } from '@data/CacheService'
import { loggerService } from '@logger'
import type { EndpointType, Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { formatApiHost, withoutTrailingSlash } from '@renderer/utils'
import { isAIGatewayProvider, isGeminiProvider, isOllamaProvider } from '@renderer/utils/provider'
import { defaultAppHeaders } from '@shared/utils'
import * as z from 'zod'
import {
AIHubMixModelsResponseSchema,
GeminiModelsResponseSchema,
GitHubModelsResponseSchema,
NewApiModelsResponseSchema,
OllamaTagsResponseSchema,
OpenAIModelsResponseSchema,
OVMSConfigResponseSchema,
TogetherModelsResponseSchema
} from './schemas'
const logger = loggerService.withContext('ModelListService')
// === Types ===
type ModelFetcher = {
match: (provider: Provider) => boolean
fetch: (provider: Provider, signal?: AbortSignal) => Promise<Model[]>
}
// === API Layer ===
const ApiErrorSchema = z.object({
error: z
.object({
message: z.string().optional(),
code: z.string().optional()
})
.optional(),
message: z.string().optional()
})
type ApiError = z.infer<typeof ApiErrorSchema>
async function getFromApi<T>({
url,
headers,
responseSchema,
abortSignal
}: {
url: string
headers?: Record<string, string>
responseSchema: z.ZodType<T>
abortSignal?: AbortSignal
}): Promise<T> {
const { value } = await aiSdkGetFromApi({
url,
headers,
successfulResponseHandler: createJsonResponseHandler(zodSchema(responseSchema)),
failedResponseHandler: createJsonErrorResponseHandler({
errorSchema: zodSchema(ApiErrorSchema),
errorToMessage: (error: ApiError) => error.error?.message || error.message || 'Unknown error'
}),
abortSignal
})
return value
}
// === Helpers ===
function getApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
}
const lastUsedKey = cacheService.getCasual<string>(keyName)
if (!lastUsedKey) {
cacheService.setCasual(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
cacheService.setCasual(keyName, nextKey)
return nextKey
}
function defaultHeaders(provider: Provider): Record<string, string> {
const apiKey = getApiKey(provider)
return {
...defaultAppHeaders(),
...(apiKey ? { Authorization: `Bearer ${apiKey}`, 'X-Api-Key': apiKey } : {}),
...provider.extra_headers
}
}
function defaultGroup(modelId: string, providerId: string): string {
const parts = modelId.split('/')
return parts.length > 1 ? parts[0] : providerId
}
function toModel(id: string, provider: Provider, extra?: Partial<Model>): Model {
return {
id,
name: extra?.name || id,
provider: provider.id,
group: extra?.group || defaultGroup(id, provider.id),
...extra
}
}
function dedup<T>(items: T[], getId: (item: T) => string | undefined): T[] {
const seen = new Set<string>()
return items.filter((item) => {
const id = getId(item)?.trim()
if (!id || seen.has(id)) return false
seen.add(id)
return true
})
}
function pickPreferredString(values: Array<unknown>): string | undefined {
for (const value of values) {
if (typeof value === 'string') {
const trimmed = value.trim()
if (trimmed.length > 0) {
return trimmed
}
}
}
return undefined
}
// === Fetchers ===
const ollamaFetcher: ModelFetcher = {
match: (p) => isOllamaProvider(p),
fetch: async (provider, signal) => {
const baseUrl = withoutTrailingSlash(provider.apiHost)
.replace(/\/v1$/, '')
.replace(/\/api$/, '')
const response = await getFromApi({
url: `${baseUrl}/api/tags`,
headers: defaultHeaders(provider),
responseSchema: OllamaTagsResponseSchema,
abortSignal: signal
})
return dedup(response.models, (m) => m.name).map((m) => toModel(m.name, provider, { owned_by: 'ollama' }))
}
}
const geminiFetcher: ModelFetcher = {
match: (p) => isGeminiProvider(p),
fetch: async (provider, signal) => {
let baseUrl = withoutTrailingSlash(provider.apiHost)
baseUrl = baseUrl.replace(/\/v1(beta)?$/, '')
const response = await getFromApi({
url: `${baseUrl}/v1beta/models?key=${getApiKey(provider)}`,
headers: { ...defaultAppHeaders(), ...provider.extra_headers },
responseSchema: GeminiModelsResponseSchema,
abortSignal: signal
})
return dedup(response.models, (m) => m.name).map((m) => {
const id = m.name.startsWith('models/') ? m.name.slice(7) : m.name
return toModel(id, provider, { name: m.displayName || id, description: m.description })
})
}
}
const githubFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds.github,
fetch: async (provider, signal) => {
const [catalogResponse, v1Response] = await Promise.all([
getFromApi({
url: 'https://models.github.ai/catalog/models',
headers: defaultHeaders(provider),
responseSchema: GitHubModelsResponseSchema,
abortSignal: signal
}),
getFromApi({
url: 'https://models.github.ai/v1/models',
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
}).catch(() => ({ data: [] as { id: string; owned_by?: string }[] }))
])
const registryModels = catalogResponse.map((m) =>
toModel(m.id, provider, {
name: m.name || m.id,
description: pickPreferredString([m.summary, m.description]),
owned_by: m.publisher
})
)
const v1Models = v1Response.data.map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
return dedup([...registryModels, ...v1Models], (m) => m.id)
}
}
const ovmsFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds.ovms,
fetch: async (provider, signal) => {
const baseUrl = formatApiHost(withoutTrailingSlash(provider.apiHost).replace(/\/v1$/, ''), true, 'v1')
const response = await getFromApi({
url: `${baseUrl}/config`,
headers: defaultHeaders(provider),
responseSchema: OVMSConfigResponseSchema,
abortSignal: signal
})
const entries = Object.entries(response).filter(([, info]) =>
info?.model_version_status?.some((v) => v?.state === 'AVAILABLE')
)
return dedup(entries, ([name]) => name).map(([name]) => toModel(name, provider, { owned_by: 'ovms' }))
}
}
const togetherFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds.together,
fetch: async (provider, signal) => {
const baseUrl = formatApiHost(provider.apiHost)
const response = await getFromApi({
url: `${baseUrl}/models`,
headers: defaultHeaders(provider),
responseSchema: TogetherModelsResponseSchema,
abortSignal: signal
})
return dedup(response, (m) => m.id).map((m) =>
toModel(m.id, provider, {
name: m.display_name || m.id,
description: m.description,
owned_by: m.organization
})
)
}
}
const newApiFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds['new-api'] || p.type === 'new-api' || p.id === SystemProviderIds.cherryin,
fetch: async (provider, signal) => {
const baseUrl = formatApiHost(provider.apiHost)
const response = await getFromApi({
url: `${baseUrl}/models`,
headers: defaultHeaders(provider),
responseSchema: NewApiModelsResponseSchema,
abortSignal: signal
})
return dedup(response.data, (m) => m.id).map((m) =>
toModel(m.id, provider, {
owned_by: m.owned_by,
supported_endpoint_types: m.supported_endpoint_types as EndpointType[] | undefined
})
)
}
}
const openRouterFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds.openrouter,
fetch: async (provider, signal) => {
const [modelsResponse, embedModelsResponse] = await Promise.all([
getFromApi({
url: 'https://openrouter.ai/api/v1/models',
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
}),
getFromApi({
url: 'https://openrouter.ai/api/v1/embeddings/models',
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
}).catch(() => ({ data: [] }))
])
const all = [...modelsResponse.data, ...embedModelsResponse.data]
return dedup(all, (m) => m.id).map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
}
}
const ppioFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds.ppio,
fetch: async (provider, signal) => {
const baseUrl = formatApiHost(provider.apiHost)
const [chat, embed, reranker] = await Promise.all([
getFromApi({
url: `${baseUrl}/models`,
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
}),
getFromApi({
url: `${baseUrl}/models?model_type=embedding`,
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
}).catch(() => ({ data: [] })),
getFromApi({
url: `${baseUrl}/models?model_type=reranker`,
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
}).catch(() => ({ data: [] }))
])
const all = [...chat.data, ...embed.data, ...reranker.data]
return dedup(all, (m) => m.id).map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
}
}
const aiHubMixFetcher: ModelFetcher = {
match: (p) => p.id === SystemProviderIds.aihubmix,
fetch: async (provider, signal) => {
const response = await getFromApi({
url: `https://aihubmix.com/api/v1/models`,
headers: defaultHeaders(provider),
responseSchema: AIHubMixModelsResponseSchema,
abortSignal: signal
})
return dedup(response.data, (m) => m.model_id).map((m) =>
toModel(m.model_id, provider, {
name: m.model_name || m.model_id,
description: m.desc
})
)
}
}
/** Default fallback: OpenAI-compatible /models endpoint */
const openAICompatibleFetcher: ModelFetcher = {
match: () => true,
fetch: async (provider, signal) => {
const baseUrl = formatApiHost(provider.apiHost)
const response = await getFromApi({
url: `${baseUrl}/models`,
headers: defaultHeaders(provider),
responseSchema: OpenAIModelsResponseSchema,
abortSignal: signal
})
return dedup(response.data, (m) => m.id).map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
}
}
// === Registry (order matters: first match wins) ===
const fetchers: ModelFetcher[] = [
aiHubMixFetcher,
ollamaFetcher,
geminiFetcher,
githubFetcher,
ovmsFetcher,
togetherFetcher,
newApiFetcher,
openRouterFetcher,
ppioFetcher,
openAICompatibleFetcher // always-match fallback, must be last
]
// === Unsupported providers (skip before registry lookup) ===
const UNSUPPORTED_PROVIDERS = new Set<string>([SystemProviderIds['aws-bedrock'], SystemProviderIds.anthropic])
function isUnsupported(provider: Provider): boolean {
return isAIGatewayProvider(provider) || UNSUPPORTED_PROVIDERS.has(provider.id) || provider.type === 'vertex-anthropic'
}
// === Public API ===
export async function listModels(provider: Provider, abortSignal?: AbortSignal): Promise<Model[]> {
try {
if (isUnsupported(provider)) {
logger.warn('Provider does not support model listing via listModels', { providerId: provider.id })
return []
}
const fetcher = fetchers.find((f) => f.match(provider))!
return await fetcher.fetch(provider, abortSignal)
} catch (error) {
logger.error('Error listing models:', error as Error, { providerId: provider.id })
return []
}
}

View File

@@ -1,164 +0,0 @@
/**
* API Response Schemas for model listing
* Used exclusively by listModels.ts
*
* All object schemas use z.looseObject() to tolerate unknown fields
* from providers — prevents parse failures when APIs add new fields.
*/
import * as z from 'zod'
// === OpenAI-compatible (also used by OpenRouter, PPIO, etc.) ===
export const OpenAIModelsResponseSchema = z.object({
data: z.array(
z.looseObject({
id: z.string(),
object: z.string().optional().default('model'),
created: z.number().optional(),
owned_by: z.string().optional()
})
),
object: z.string().optional()
})
// === Ollama ===
export const OllamaTagsResponseSchema = z.object({
models: z.array(
z.looseObject({
name: z.string(),
model: z.string().optional(),
modified_at: z.string().optional(),
size: z.number().optional(),
digest: z.string().optional(),
details: z
.looseObject({
parent_model: z.string().optional(),
format: z.string().optional(),
family: z.string().optional(),
families: z.array(z.string()).optional(),
parameter_size: z.string().optional(),
quantization_level: z.string().optional()
})
.optional()
})
)
})
// === Gemini ===
export const GeminiModelsResponseSchema = z.object({
models: z.array(
z.looseObject({
name: z.string(),
displayName: z.string().optional(),
description: z.string().optional(),
version: z.string().optional(),
baseModelId: z.string().optional(),
inputTokenLimit: z.number().optional(),
outputTokenLimit: z.number().optional(),
supportedGenerationMethods: z.array(z.string()).optional()
})
),
nextPageToken: z.string().optional()
})
// === GitHub Models ===
export const GitHubModelsResponseSchema = z.array(
z.looseObject({
id: z.string(),
summary: z.string().optional(),
publisher: z.string().optional(),
name: z.string().optional(),
description: z.string().optional(),
version: z.string().optional()
})
)
// === Together ===
export const TogetherModelsResponseSchema = z.array(
z.looseObject({
id: z.string(),
display_name: z.string().optional(),
organization: z.string().optional(),
description: z.string().optional(),
context_length: z.number().optional(),
pricing: z
.looseObject({
input: z.number().optional(),
output: z.number().optional()
})
.optional()
})
)
// === NewAPI (extends OpenAI with endpoint types) ===
export const NewApiModelsResponseSchema = z.object({
data: z.array(
z.looseObject({
id: z.string(),
object: z.string().optional().default('model'),
created: z.number().optional(),
owned_by: z.string().optional(),
supported_endpoint_types: z
.array(z.string())
.nullable()
.optional()
.transform((v) => v ?? undefined)
})
),
object: z.string().optional()
})
// === OVMS (OpenVINO Model Server) ===
export const OVMSConfigResponseSchema = z.record(
z.string(),
z.object({
model_version_status: z
.array(
z.looseObject({
state: z.string(),
status: z
.looseObject({
error_code: z.string().optional(),
error_message: z.string().optional()
})
.optional()
})
)
.optional()
})
)
// === AIHubMix ===
export const AIHubMixModelsResponseSchema = z.object({
data: z.array(
z.looseObject({
model_id: z.string(),
model_name: z.string().optional(),
developer_id: z.number().optional(),
desc: z.string().optional(),
pricing: z
.looseObject({
cache_read: z.number().optional(),
cache_write: z.number().optional(),
input: z.number().optional(),
output: z.number().optional()
})
.optional(),
types: z.string().optional(),
features: z.string().optional(),
input_modalities: z.string().optional(),
endpoints: z.string().optional(),
max_output: z.number().optional(),
context_length: z.number().optional()
})
),
message: z.string().optional(),
success: z.boolean().optional()
})

View File

@@ -1,140 +0,0 @@
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
import type { Assistant, KnowledgeReference } from '@renderer/types'
import type { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract'
import { REFERENCE_PROMPT } from '@shared/config/prompts'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { isEmpty } from 'lodash'
import * as z from 'zod'
/**
* 知识库搜索工具
* 使用预提取关键词,直接使用插件阶段分析的搜索意图,避免重复分析
*/
export const knowledgeSearchTool = (
assistant: Assistant,
extractedKeywords: KnowledgeExtractResults,
topicId: string,
userMessage?: string
) => {
return tool({
description: `Knowledge base search tool for retrieving information from user's private knowledge base. This searches your local collection of documents, web content, notes, and other materials you have stored.
This tool has been configured with search parameters based on the conversation context:
- Prepared queries: ${extractedKeywords.question.map((q) => `"${q}"`).join(', ')}
- Query rewrite: "${extractedKeywords.rewrite}"
You can use this tool as-is, or provide additionalContext to refine the search focus within the knowledge base.`,
inputSchema: z.object({
additionalContext: z
.string()
.optional()
.describe('Optional additional context or specific focus to enhance the knowledge search')
}),
execute: async ({ additionalContext }) => {
// try {
// 获取助手的知识库配置
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
// 检查是否有知识库
if (!hasKnowledgeBase) {
return []
}
let finalQueries = [...extractedKeywords.question]
let finalRewrite = extractedKeywords.rewrite
if (additionalContext?.trim()) {
// 如果大模型提供了额外上下文,使用更具体的描述
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
finalRewrite = cleanContext
}
}
// 检查是否需要搜索
if (finalQueries[0] === 'not_needed') {
return []
}
// 构建搜索条件
let searchCriteria: { question: string[]; rewrite: string }
if (knowledgeRecognition === 'off') {
// 直接模式:使用用户消息内容
const directContent = userMessage || finalQueries[0] || 'search'
searchCriteria = {
question: [directContent],
rewrite: directContent
}
} else {
// 自动模式:使用意图识别的结果
searchCriteria = {
question: finalQueries,
rewrite: finalRewrite
}
}
// 构建 ExtractResults 对象
const extractResults: ExtractResults = {
websearch: undefined,
knowledge: searchCriteria
}
// 执行知识库搜索
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
id: ref.id,
content: ref.content,
sourceUrl: ref.sourceUrl,
type: ref.type,
file: ref.file,
metadata: ref.metadata
}))
// TODO 在工具函数中添加搜索缓存机制
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
// 返回结果
return knowledgeReferencesData
},
toModelOutput: ({ output: results }) => {
let summary = 'No search needed based on the query analysis.'
if (results.length > 0) {
summary = `Found ${results.length} relevant sources. Use [number] format to cite specific information.`
}
const referenceContent = `\`\`\`json\n${JSON.stringify(results, null, 2)}\n\`\`\``
const fullInstructions = REFERENCE_PROMPT.replace(
'{question}',
"Based on the knowledge references, please answer the user's question with proper citations."
).replace('{references}', referenceContent)
return {
type: 'content',
value: [
{
type: 'text',
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
},
{
type: 'text',
text: summary
},
{
type: 'text',
text: fullInstructions
}
]
}
}
})
}
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
export default knowledgeSearchTool

View File

@@ -1,47 +0,0 @@
import { preferenceService } from '@data/PreferenceService'
import store from '@renderer/store'
import { selectMemoryConfig } from '@renderer/store/memory'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import * as z from 'zod'
import { MemoryProcessor } from '../../services/MemoryProcessor'
/**
* 🧠 基础记忆搜索工具
* AI 可以主动调用的简单记忆搜索
*/
export const memorySearchTool = (assistantId: string) => {
return tool({
description: 'Search through conversation memories and stored facts for relevant context',
inputSchema: z.object({
query: z.string().describe('Search query to find relevant memories'),
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
}),
execute: async ({ query, limit = 5 }) => {
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
if (!globalMemoryEnabled) {
return []
}
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmModel || !memoryConfig.embeddingModel) {
return []
}
const currentUserId = await preferenceService.get('feature.memory.current_user_id')
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistantId, currentUserId)
const memoryProcessor = new MemoryProcessor()
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
if (relevantMemories?.length > 0) {
return relevantMemories
}
return []
}
})
}
export type MemorySearchToolInput = InferToolInput<ReturnType<typeof memorySearchTool>>
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>

View File

@@ -1,204 +0,0 @@
import { webSearchService } from '@renderer/services/WebSearchService'
import type { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import type { ExtractResults } from '@renderer/utils/extract'
import { REFERENCE_PROMPT } from '@shared/config/prompts'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import * as z from 'zod'
/**
* 使用预提取关键词的网络搜索工具
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
*/
export const webSearchToolWithPreExtractedKeywords = (
webSearchProviderId: WebSearchProvider['id'],
extractedKeywords: {
question: string[]
links?: string[]
},
requestId: string
) => {
const webSearchProvider = webSearchService.getWebSearchProvider(webSearchProviderId)
return tool({
description: `Web search tool for finding current information, news, and real-time data from the internet.
This tool has been configured with search parameters based on the conversation context:
- Prepared queries: ${extractedKeywords.question.map((q) => `"${q}"`).join(', ')}${
extractedKeywords.links?.length
? `
- Relevant URLs: ${extractedKeywords.links.join(', ')}`
: ''
}
You can use this tool as-is to search with the prepared queries, or provide additionalContext to refine or replace the search terms.`,
inputSchema: z.object({
additionalContext: z
.string()
.optional()
.describe('Optional additional context, keywords, or specific focus to enhance the search')
}),
execute: async ({ additionalContext }) => {
let finalQueries = [...extractedKeywords.question]
if (additionalContext?.trim()) {
// 如果大模型提供了额外上下文,使用更具体的描述
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
}
}
let searchResults: WebSearchProviderResponse = {
query: '',
results: []
}
// 检查是否需要搜索
if (finalQueries[0] === 'not_needed') {
return searchResults
}
// 构建 ExtractResults 结构用于 processWebsearch
const extractResults: ExtractResults = {
websearch: {
question: finalQueries,
links: extractedKeywords.links
}
}
searchResults = await webSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
return searchResults
},
toModelOutput: ({ output: results }) => {
let summary = 'No search needed based on the query analysis.'
if (results.query && results.results.length > 0) {
summary = `Found ${results.results.length} relevant sources. Use [number] format to cite specific information.`
}
const citationData = results.results.map((result, index) => ({
number: index + 1,
title: result.title,
content: result.content,
url: result.url
}))
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
const fullInstructions = REFERENCE_PROMPT.replace(
'{question}',
"Based on the search results, please answer the user's question with proper citations."
).replace('{references}', referenceContent)
return {
type: 'content',
value: [
{
type: 'text',
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
},
{
type: 'text',
text: summary
},
{
type: 'text',
text: fullInstructions
}
]
}
}
})
}
// export const webSearchToolWithExtraction = (
// webSearchProviderId: WebSearchProvider['id'],
// requestId: string,
// assistant: Assistant
// ) => {
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
// return tool({
// name: 'web_search_with_extraction',
// description: 'Search the web for information with automatic keyword extraction from user messages',
// inputSchema: z.object({
// userMessage: z.object({
// content: z.string().describe('The main content of the message'),
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
// }),
// lastAnswer: z.object({
// content: z.string().describe('The main content of the message'),
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
// })
// }),
// outputSchema: z.object({
// extractedKeywords: z.object({
// question: z.array(z.string()),
// links: z.array(z.string()).optional()
// }),
// searchResults: z.array(
// z.object({
// query: z.string(),
// results: WebSearchProviderResult
// })
// )
// }),
// execute: async ({ userMessage, lastAnswer }) => {
// const lastUserMessage: Message = {
// id: requestId,
// role: userMessage.role,
// assistantId: assistant.id,
// topicId: 'temp',
// createdAt: new Date().toISOString(),
// status: UserMessageStatus.SUCCESS,
// blocks: []
// }
// const lastAnswerMessage: Message | undefined = lastAnswer
// ? {
// id: requestId + '_answer',
// role: lastAnswer.role,
// assistantId: assistant.id,
// topicId: 'temp',
// createdAt: new Date().toISOString(),
// status: UserMessageStatus.SUCCESS,
// blocks: []
// }
// : undefined
// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
// shouldWebSearch: true,
// shouldKnowledgeSearch: false,
// lastAnswer: lastAnswerMessage
// })
// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
// return 'No search needed or extraction failed'
// }
// const searchQueries = extractResults.websearch.question
// const searchResults: Array<{ query: string; results: any }> = []
// for (const query of searchQueries) {
// // 构建单个查询的ExtractResults结构
// const queryExtractResults: ExtractResults = {
// websearch: {
// question: [query],
// links: extractResults.websearch.links
// }
// }
// const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
// searchResults.push({
// query,
// results: response
// })
// }
// return { extractedKeywords: extractResults.websearch, searchResults }
// }
// })
// }
// export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>

View File

@@ -1,656 +0,0 @@
/**
* AI SDK Span Adapter
*
* 将 AI SDK 的 telemetry 数据转换为现有的 SpanEntity 格式
* 注意 AI SDK 的层级结构ai.xxx 是一个层级ai.xxx.xxx 是对应层级下的子集
*/
import { loggerService } from '@logger'
import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
import type { Span } from '@opentelemetry/api'
import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
const logger = loggerService.withContext('AiSdkSpanAdapter')
export interface AiSdkSpanData {
span: Span
topicId?: string
modelName?: string
}
// 扩展接口用于访问span的内部数据
interface SpanWithInternals extends Span {
_spanProcessor?: any
_attributes?: Record<string, any>
_events?: any[]
name?: string
startTime?: [number, number]
endTime?: [number, number] | null
status?: { code: SpanStatusCode; message?: string }
kind?: SpanKind
ended?: boolean
parentSpanId?: string
links?: any[]
}
export class AiSdkSpanAdapter {
/**
* 将 AI SDK span 转换为 SpanEntity 格式
*/
static convertToSpanEntity(spanData: AiSdkSpanData): SpanEntity {
const { span, topicId, modelName } = spanData
const spanContext = span.spanContext()
// 尝试从不同方式获取span数据
const spanWithInternals = span as SpanWithInternals
let attributes: Record<string, any> = {}
let events: any[] = []
let spanName = 'unknown'
let spanStatus = { code: SpanStatusCode.UNSET }
let spanKind = SpanKind.INTERNAL
let startTime: [number, number] = [0, 0]
let endTime: [number, number] | null = null
let ended = false
let parentSpanId = ''
let links: any[] = []
// 详细记录span的结构信息用于调试
logger.debug('Debugging span structure', {
hasInternalAttributes: !!spanWithInternals._attributes,
hasGetAttributes: typeof (span as any).getAttributes === 'function',
spanKeys: Object.keys(span),
spanInternalKeys: Object.keys(spanWithInternals),
spanContext: span.spanContext(),
// 尝试获取所有可能的属性路径
attributesPath1: spanWithInternals._attributes,
attributesPath2: (span as any).attributes,
attributesPath3: (span as any)._spanData?.attributes,
attributesPath4: (span as any).resource?.attributes
})
// 尝试多种方式获取attributes
if (spanWithInternals._attributes) {
attributes = spanWithInternals._attributes
logger.debug('Found attributes via _attributes', { attributeCount: Object.keys(attributes).length })
} else if (typeof (span as any).getAttributes === 'function') {
attributes = (span as any).getAttributes()
logger.debug('Found attributes via getAttributes()', { attributeCount: Object.keys(attributes).length })
} else if ((span as any).attributes) {
attributes = (span as any).attributes
logger.debug('Found attributes via direct attributes property', {
attributeCount: Object.keys(attributes).length
})
} else if ((span as any)._spanData?.attributes) {
attributes = (span as any)._spanData.attributes
logger.debug('Found attributes via _spanData.attributes', { attributeCount: Object.keys(attributes).length })
} else {
// 尝试从span的其他属性获取
logger.warn('无法获取span attributes尝试备用方法', {
availableKeys: Object.keys(span),
spanType: span.constructor.name
})
}
// 获取其他属性
if (spanWithInternals._events) {
events = spanWithInternals._events
}
if (spanWithInternals.name) {
spanName = spanWithInternals.name
}
if (spanWithInternals.status) {
spanStatus = spanWithInternals.status
}
if (spanWithInternals.kind !== undefined) {
spanKind = spanWithInternals.kind
}
if (spanWithInternals.startTime) {
startTime = spanWithInternals.startTime
}
if (spanWithInternals.endTime) {
endTime = spanWithInternals.endTime
}
if (spanWithInternals.ended !== undefined) {
ended = spanWithInternals.ended
}
if (spanWithInternals.parentSpanId) {
parentSpanId = spanWithInternals.parentSpanId
}
// 兜底:尝试从 attributes 中读取我们注入的父信息
if (!parentSpanId && attributes['trace.parentSpanId']) {
parentSpanId = attributes['trace.parentSpanId']
}
if (spanWithInternals.links) {
links = spanWithInternals.links
}
// 提取 AI SDK 特有的数据
const tokenUsage = this.extractTokenUsage(attributes)
const { inputs, outputs } = this.extractInputsOutputs(attributes)
const formattedSpanName = this.formatSpanName(spanName)
const spanTag = this.extractSpanTag(spanName, attributes)
const typeSpecificData = this.extractSpanTypeSpecificData(attributes)
// 详细记录转换过程
const operationId = attributes['ai.operationId']
logger.debug('Converting AI SDK span to SpanEntity', {
spanName: spanName,
operationId,
spanTag,
hasTokenUsage: !!tokenUsage,
hasInputs: !!inputs,
hasOutputs: !!outputs,
hasTypeSpecificData: Object.keys(typeSpecificData).length > 0,
attributesCount: Object.keys(attributes).length,
topicId,
modelName,
spanId: spanContext.spanId,
traceId: spanContext.traceId
})
if (tokenUsage) {
logger.debug('Token usage data found', {
spanName: spanName,
operationId,
usage: tokenUsage,
spanId: spanContext.spanId
})
}
if (inputs || outputs) {
logger.debug('Input/Output data extracted', {
spanName: spanName,
operationId,
hasInputs: !!inputs,
hasOutputs: !!outputs,
inputKeys: inputs ? Object.keys(inputs) : [],
outputKeys: outputs ? Object.keys(outputs) : [],
spanId: spanContext.spanId
})
}
if (Object.keys(typeSpecificData).length > 0) {
logger.debug('Type-specific data extracted', {
spanName: spanName,
operationId,
typeSpecificKeys: Object.keys(typeSpecificData),
spanId: spanContext.spanId
})
}
// 转换为 SpanEntity 格式
const spanEntity: SpanEntity = {
id: spanContext.spanId,
name: formattedSpanName,
parentId: parentSpanId,
traceId: spanContext.traceId,
status: this.convertSpanStatus(spanStatus.code),
kind: this.convertSpanKind(spanKind),
attributes: {
...this.filterRelevantAttributes(attributes),
...typeSpecificData,
inputs: inputs,
outputs: outputs,
tags: spanTag,
modelName: modelName || this.extractModelFromAttributes(attributes) || ''
},
isEnd: ended,
events: events,
startTime: this.convertTimestamp(startTime),
endTime: endTime ? this.convertTimestamp(endTime) : null,
links: links,
topicId: topicId,
usage: tokenUsage,
modelName: modelName || this.extractModelFromAttributes(attributes)
}
logger.debug('AI SDK span successfully converted to SpanEntity', {
spanName: spanName,
operationId,
spanId: spanContext.spanId,
traceId: spanContext.traceId,
topicId,
entityId: spanEntity.id,
hasUsage: !!spanEntity.usage,
status: spanEntity.status,
tags: spanEntity.attributes?.tags
})
return spanEntity
}
/**
* 从 AI SDK attributes 中提取 token usage
* 支持多种格式:
* - AI SDK 标准格式: ai.usage.completionTokens, ai.usage.promptTokens
* - 完整usage对象格式: ai.usage (JSON字符串或对象)
*/
private static extractTokenUsage(attributes: Record<string, any>): TokenUsage | undefined {
logger.debug('Extracting token usage from attributes', {
attributeKeys: Object.keys(attributes),
usageRelatedKeys: Object.keys(attributes).filter((key) => key.includes('usage') || key.includes('token')),
fullAttributes: attributes
})
const inputsTokenKeys = [
// base span
'ai.usage.promptTokens',
// LLM span
'gen_ai.usage.input_tokens'
]
const outputTokenKeys = [
// base span
'ai.usage.completionTokens',
// LLM span
'gen_ai.usage.output_tokens'
]
const promptTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
const completionTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
if (completionTokens !== undefined || promptTokens !== undefined) {
const usage: TokenUsage = {
prompt_tokens: Number(promptTokens) || 0,
completion_tokens: Number(completionTokens) || 0,
total_tokens: (Number(promptTokens) || 0) + (Number(completionTokens) || 0)
}
logger.debug('Extracted token usage from AI SDK standard attributes', {
usage,
foundStandardAttributes: {
'ai.usage.completionTokens': completionTokens,
'ai.usage.promptTokens': promptTokens
}
})
return usage
}
// 对于不包含token usage的spans如tool calls这是正常的
logger.debug('No token usage found in span attributes (normal for tool calls)', {
availableKeys: Object.keys(attributes),
usageKeys: Object.keys(attributes).filter((key) => key.includes('usage') || key.includes('token'))
})
return undefined
}
/**
* 从 AI SDK attributes 中提取 inputs 和 outputs
* 根据AI SDK文档按不同span类型精确映射
*/
private static extractInputsOutputs(attributes: Record<string, any>): { inputs?: any; outputs?: any } {
const operationId = attributes['ai.operationId']
let inputs: any = undefined
let outputs: any = undefined
logger.debug('Extracting inputs/outputs by operation type', {
operationId,
availableKeys: Object.keys(attributes).filter(
(key) => key.includes('prompt') || key.includes('response') || key.includes('toolCall')
)
})
// 根据AI SDK文档按操作类型提取数据
switch (operationId) {
case 'ai.generateText':
case 'ai.streamText':
// 顶层LLM spans: ai.prompt 包含输入
inputs = {
prompt: this.parseAttributeValue(attributes['ai.prompt'])
}
outputs = this.extractLLMOutputs(attributes)
break
case 'ai.generateText.doGenerate':
case 'ai.streamText.doStream':
// Provider spans: ai.prompt.messages 包含详细输入
inputs = {
messages: this.parseAttributeValue(attributes['ai.prompt.messages']),
tools: this.parseAttributeValue(attributes['ai.prompt.tools']),
toolChoice: this.parseAttributeValue(attributes['ai.prompt.toolChoice'])
}
outputs = this.extractProviderOutputs(attributes)
break
case 'ai.toolCall':
// Tool call spans: 工具参数和结果
inputs = {
toolName: attributes['ai.toolCall.name'],
toolId: attributes['ai.toolCall.id'],
args: this.parseAttributeValue(attributes['ai.toolCall.args'])
}
outputs = {
result: this.parseAttributeValue(attributes['ai.toolCall.result'])
}
break
default:
// 回退到通用逻辑
inputs = this.extractGenericInputs(attributes)
outputs = this.extractGenericOutputs(attributes)
break
}
logger.debug('Extracted inputs/outputs', {
operationId,
hasInputs: !!inputs,
hasOutputs: !!outputs,
inputKeys: inputs ? Object.keys(inputs) : [],
outputKeys: outputs ? Object.keys(outputs) : []
})
return { inputs, outputs }
}
/**
* 提取LLM顶层spans的输出
*/
private static extractLLMOutputs(attributes: Record<string, any>): any {
const outputs: any = {}
if (attributes['ai.response.text']) {
outputs.text = attributes['ai.response.text']
}
if (attributes['ai.response.toolCalls']) {
outputs.toolCalls = this.parseAttributeValue(attributes['ai.response.toolCalls'])
}
if (attributes['ai.response.finishReason']) {
outputs.finishReason = attributes['ai.response.finishReason']
}
if (attributes['ai.settings.maxOutputTokens']) {
outputs.maxOutputTokens = attributes['ai.settings.maxOutputTokens']
}
return Object.keys(outputs).length > 0 ? outputs : undefined
}
/**
* 提取Provider spans的输出
*/
private static extractProviderOutputs(attributes: Record<string, any>): any {
const outputs: any = {}
if (attributes['ai.response.text']) {
outputs.text = attributes['ai.response.text']
}
if (attributes['ai.response.toolCalls']) {
outputs.toolCalls = this.parseAttributeValue(attributes['ai.response.toolCalls'])
}
if (attributes['ai.response.finishReason']) {
outputs.finishReason = attributes['ai.response.finishReason']
}
// doStream特有的性能指标
if (attributes['ai.response.msToFirstChunk']) {
outputs.msToFirstChunk = attributes['ai.response.msToFirstChunk']
}
if (attributes['ai.response.msToFinish']) {
outputs.msToFinish = attributes['ai.response.msToFinish']
}
if (attributes['ai.response.avgCompletionTokensPerSecond']) {
outputs.avgCompletionTokensPerSecond = attributes['ai.response.avgCompletionTokensPerSecond']
}
return Object.keys(outputs).length > 0 ? outputs : undefined
}
/**
* 通用输入提取(回退逻辑)
*/
private static extractGenericInputs(attributes: Record<string, any>): any {
const inputKeys = ['ai.prompt', 'ai.prompt.messages', 'ai.request', 'inputs']
for (const key of inputKeys) {
if (attributes[key]) {
return this.parseAttributeValue(attributes[key])
}
}
return undefined
}
/**
* 通用输出提取(回退逻辑)
*/
private static extractGenericOutputs(attributes: Record<string, any>): any {
const outputKeys = ['ai.response.text', 'ai.response', 'ai.output', 'outputs']
for (const key of outputKeys) {
if (attributes[key]) {
return this.parseAttributeValue(attributes[key])
}
}
return undefined
}
/**
* 解析属性值,处理字符串化的 JSON
*/
private static parseAttributeValue(value: any): any {
if (typeof value === 'string') {
try {
return JSON.parse(value)
} catch (e) {
return value
}
}
return value
}
/**
* 格式化 span 名称,处理 AI SDK 的层级结构
*/
private static formatSpanName(name: string): string {
// AI SDK 的 span 名称可能是 ai.generateText, ai.streamText.doStream 等
// 保持原始名称,但可以添加一些格式化逻辑
if (name.startsWith('ai.')) {
return name
}
return name
}
/**
* 从AI SDK operationId中提取精确的span标签
*/
private static extractSpanTag(name: string, attributes: Record<string, any>): string {
const operationId = attributes['ai.operationId']
logger.debug('Extracting span tag', {
spanName: name,
operationId,
operationName: attributes['operation.name']
})
// 根据AI SDK文档的operationId精确分类
switch (operationId) {
case 'ai.generateText':
return 'LLM-GENERATE'
case 'ai.streamText':
return 'LLM-STREAM'
case 'ai.generateText.doGenerate':
return 'PROVIDER-GENERATE'
case 'ai.streamText.doStream':
return 'PROVIDER-STREAM'
case 'ai.toolCall':
return 'TOOL-CALL'
case 'ai.generateImage':
return 'IMAGE'
case 'ai.embed':
return 'EMBEDDING'
default:
// 回退逻辑基于span名称
if (name.includes('generateText') || name.includes('streamText')) {
return 'LLM'
}
if (name.includes('generateImage')) {
return 'IMAGE'
}
if (name.includes('embed')) {
return 'EMBEDDING'
}
if (name.includes('toolCall')) {
return 'TOOL'
}
// 最终回退
return attributes['ai.operationType'] || attributes['operation.type'] || 'AI_SDK'
}
}
/**
* 根据span类型提取特定的额外数据
*/
private static extractSpanTypeSpecificData(attributes: Record<string, any>): Record<string, any> {
const operationId = attributes['ai.operationId']
const specificData: Record<string, any> = {}
switch (operationId) {
case 'ai.generateText':
case 'ai.streamText':
// LLM顶层spans的特定数据
if (attributes['ai.settings.maxOutputTokens']) {
specificData.maxOutputTokens = attributes['ai.settings.maxOutputTokens']
}
if (attributes['resource.name']) {
specificData.functionId = attributes['resource.name']
}
break
case 'ai.generateText.doGenerate':
case 'ai.streamText.doStream':
// Provider spans的特定数据
if (attributes['ai.model.id']) {
specificData.providerId = attributes['ai.model.provider'] || 'unknown'
specificData.modelId = attributes['ai.model.id']
}
// doStream特有的性能数据
if (operationId === 'ai.streamText.doStream') {
if (attributes['ai.response.msToFirstChunk']) {
specificData.msToFirstChunk = attributes['ai.response.msToFirstChunk']
}
if (attributes['ai.response.msToFinish']) {
specificData.msToFinish = attributes['ai.response.msToFinish']
}
if (attributes['ai.response.avgCompletionTokensPerSecond']) {
specificData.tokensPerSecond = attributes['ai.response.avgCompletionTokensPerSecond']
}
}
break
case 'ai.toolCall':
// Tool call spans的特定数据
specificData.toolName = attributes['ai.toolCall.name']
specificData.toolId = attributes['ai.toolCall.id']
// 根据文档tool call可能有不同的操作类型
if (attributes['operation.name']) {
specificData.operationName = attributes['operation.name']
}
break
default:
// 通用的AI SDK属性
if (attributes['ai.telemetry.functionId']) {
specificData.telemetryFunctionId = attributes['ai.telemetry.functionId']
}
if (attributes['ai.telemetry.metadata']) {
specificData.telemetryMetadata = this.parseAttributeValue(attributes['ai.telemetry.metadata'])
}
break
}
// 添加通用的操作标识
if (operationId) {
specificData.operationType = operationId
}
if (attributes['operation.name']) {
specificData.operationName = attributes['operation.name']
}
logger.debug('Extracted type-specific data', {
operationId,
specificDataKeys: Object.keys(specificData),
specificData
})
return specificData
}
/**
* 从属性中提取模型名称
*/
private static extractModelFromAttributes(attributes: Record<string, any>): string | undefined {
return (
attributes['ai.model.id'] ||
attributes['ai.model'] ||
attributes['model.id'] ||
attributes['model'] ||
attributes['modelName']
)
}
/**
* 过滤相关的属性,移除不需要的系统属性
*/
private static filterRelevantAttributes(attributes: Record<string, any>): Record<string, any> {
const filtered: Record<string, any> = {}
// 保留有用的属性,过滤掉已经单独处理的属性
const excludeKeys = ['ai.usage', 'ai.prompt', 'ai.response', 'ai.input', 'ai.output', 'inputs', 'outputs']
Object.entries(attributes).forEach(([key, value]) => {
if (!excludeKeys.includes(key)) {
filtered[key] = value
}
})
return filtered
}
/**
* 转换 span 状态
*/
private static convertSpanStatus(statusCode?: SpanStatusCode): string {
switch (statusCode) {
case SpanStatusCode.OK:
return 'OK'
case SpanStatusCode.ERROR:
return 'ERROR'
case SpanStatusCode.UNSET:
default:
return 'UNSET'
}
}
/**
* 转换 span 类型
*/
private static convertSpanKind(kind?: SpanKind): string {
switch (kind) {
case SpanKind.INTERNAL:
return 'INTERNAL'
case SpanKind.CLIENT:
return 'CLIENT'
case SpanKind.SERVER:
return 'SERVER'
case SpanKind.PRODUCER:
return 'PRODUCER'
case SpanKind.CONSUMER:
return 'CONSUMER'
default:
return 'INTERNAL'
}
}
/**
* 转换时间戳格式
*/
private static convertTimestamp(timestamp: [number, number] | number): number {
if (Array.isArray(timestamp)) {
// OpenTelemetry 高精度时间戳 [seconds, nanoseconds]
return timestamp[0] * 1000 + timestamp[1] / 1000000
}
return timestamp
}
}

View File

@@ -1,53 +0,0 @@
import type { Span } from '@opentelemetry/api'
import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
import { describe, expect, it, vi } from 'vitest'
import { AiSdkSpanAdapter } from '../AiSdkSpanAdapter'
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
info: vi.fn(),
warn: vi.fn()
})
}
}))
describe('AiSdkSpanAdapter', () => {
const createMockSpan = (attributes: Record<string, unknown>): Span => {
const span = {
spanContext: () => ({
traceId: 'trace-id',
spanId: 'span-id'
}),
_attributes: attributes,
_events: [],
name: 'test span',
status: { code: SpanStatusCode.OK },
kind: SpanKind.CLIENT,
startTime: [0, 0] as [number, number],
endTime: [0, 1] as [number, number],
ended: true,
parentSpanId: '',
links: []
}
return span as unknown as Span
}
it('maps prompt and completion usage tokens to the correct fields', () => {
const attributes = {
'ai.usage.promptTokens': 321,
'ai.usage.completionTokens': 654
}
const span = createMockSpan(attributes)
const result = AiSdkSpanAdapter.convertToSpanEntity({ span })
expect(result.usage).toBeDefined()
expect(result.usage?.prompt_tokens).toBe(321)
expect(result.usage?.completion_tokens).toBe(654)
expect(result.usage?.total_tokens).toBe(975)
})
})

View File

@@ -1,156 +0,0 @@
/**
* Type Tests for Merged Provider Types
*
* These tests validate that the auto-extraction and merging of provider types works correctly.
* They use type-level assertions to ensure compile-time type safety.
*/
import { describe, expectTypeOf, it } from 'vitest'
import type { AppProviderId, AppProviderSettingsMap } from '../merged'
import { appProviderIds } from '../merged'
describe('Unified Provider Types', () => {
describe('appProviderIds literal access', () => {
it('should return canonical IDs with literal types', () => {
// 别名 → 基础名
expectTypeOf(appProviderIds.vertexai).toEqualTypeOf<'google-vertex'>()
// 变体 → 自身(自反映射)
expectTypeOf(appProviderIds['openai-chat']).toEqualTypeOf<'openai-chat'>()
})
})
describe('AppProviderId - All Providers', () => {
it('should include all core extension names', () => {
type Check1 = 'openai' extends AppProviderId ? true : false
type Check2 = 'anthropic' extends AppProviderId ? true : false
type Check3 = 'google' extends AppProviderId ? true : false
type Check4 = 'azure' extends AppProviderId ? true : false
type Check5 = 'deepseek' extends AppProviderId ? true : false
type Check6 = 'xai' extends AppProviderId ? true : false
expectTypeOf<Check1>().toEqualTypeOf<true>()
expectTypeOf<Check2>().toEqualTypeOf<true>()
expectTypeOf<Check3>().toEqualTypeOf<true>()
expectTypeOf<Check4>().toEqualTypeOf<true>()
expectTypeOf<Check5>().toEqualTypeOf<true>()
expectTypeOf<Check6>().toEqualTypeOf<true>()
})
it('should include all project extension names', () => {
type Check1 = 'google-vertex' extends AppProviderId ? true : false
type Check2 = 'bedrock' extends AppProviderId ? true : false
type Check3 = 'github-copilot-openai-compatible' extends AppProviderId ? true : false
type Check4 = 'perplexity' extends AppProviderId ? true : false
type Check5 = 'mistral' extends AppProviderId ? true : false
type Check6 = 'huggingface' extends AppProviderId ? true : false
type Check7 = 'gateway' extends AppProviderId ? true : false
type Check8 = 'cerebras' extends AppProviderId ? true : false
type Check9 = 'ollama' extends AppProviderId ? true : false
expectTypeOf<Check1>().toEqualTypeOf<true>()
expectTypeOf<Check2>().toEqualTypeOf<true>()
expectTypeOf<Check3>().toEqualTypeOf<true>()
expectTypeOf<Check4>().toEqualTypeOf<true>()
expectTypeOf<Check5>().toEqualTypeOf<true>()
expectTypeOf<Check6>().toEqualTypeOf<true>()
expectTypeOf<Check7>().toEqualTypeOf<true>()
expectTypeOf<Check8>().toEqualTypeOf<true>()
expectTypeOf<Check9>().toEqualTypeOf<true>()
})
it('should include all aliases (core + project)', () => {
// Core aliases
type Check2 = 'claude' extends AppProviderId ? true : false
// Project aliases
type Check3 = 'vertexai' extends AppProviderId ? true : false
type Check4 = 'aws-bedrock' extends AppProviderId ? true : false
type Check5 = 'copilot' extends AppProviderId ? true : false
type Check6 = 'github-copilot' extends AppProviderId ? true : false
type Check7 = 'hf' extends AppProviderId ? true : false
type Check8 = 'hugging-face' extends AppProviderId ? true : false
type Check9 = 'ai-gateway' extends AppProviderId ? true : false
expectTypeOf<Check2>().toEqualTypeOf<true>()
expectTypeOf<Check3>().toEqualTypeOf<true>()
expectTypeOf<Check4>().toEqualTypeOf<true>()
expectTypeOf<Check5>().toEqualTypeOf<true>()
expectTypeOf<Check6>().toEqualTypeOf<true>()
expectTypeOf<Check7>().toEqualTypeOf<true>()
expectTypeOf<Check8>().toEqualTypeOf<true>()
expectTypeOf<Check9>().toEqualTypeOf<true>()
})
})
describe('AppProviderId', () => {
it('should merge core and project IDs', () => {
// Core providers
type Check1 = 'openai' extends AppProviderId ? true : false
type Check2 = 'anthropic' extends AppProviderId ? true : false
type Check3 = 'google' extends AppProviderId ? true : false
type Check4 = 'azure' extends AppProviderId ? true : false
type Check5 = 'xai' extends AppProviderId ? true : false
// Project providers
type Check6 = 'google-vertex' extends AppProviderId ? true : false
type Check7 = 'bedrock' extends AppProviderId ? true : false
type Check8 = 'ollama' extends AppProviderId ? true : false
expectTypeOf<Check1>().toEqualTypeOf<true>()
expectTypeOf<Check2>().toEqualTypeOf<true>()
expectTypeOf<Check3>().toEqualTypeOf<true>()
expectTypeOf<Check4>().toEqualTypeOf<true>()
expectTypeOf<Check5>().toEqualTypeOf<true>()
expectTypeOf<Check6>().toEqualTypeOf<true>()
expectTypeOf<Check7>().toEqualTypeOf<true>()
expectTypeOf<Check8>().toEqualTypeOf<true>()
})
it('should accept string for dynamic providers', () => {
type Check = string extends AppProviderId ? true : false
expectTypeOf<Check>().toEqualTypeOf<true>()
})
})
describe('AppProviderSettingsMap', () => {
it('should map core provider IDs to their settings', () => {
// OpenAI settings should have OpenAI-specific fields
type OpenAISettings = AppProviderSettingsMap['openai']
type HasBaseURL = 'baseURL' extends keyof OpenAISettings ? true : false
type HasApiKey = 'apiKey' extends keyof OpenAISettings ? true : false
expectTypeOf<HasBaseURL>().toEqualTypeOf<true>()
expectTypeOf<HasApiKey>().toEqualTypeOf<true>()
})
it('should map project provider IDs to their settings', () => {
// Project providers should have settings
type VertexSettings = AppProviderSettingsMap['google-vertex']
type BedrockSettings = AppProviderSettingsMap['bedrock']
type OllamaSettings = AppProviderSettingsMap['ollama']
// These should not be never
type VertexNotNever = [VertexSettings] extends [never] ? false : true
type BedrockNotNever = [BedrockSettings] extends [never] ? false : true
type OllamaNotNever = [OllamaSettings] extends [never] ? false : true
expectTypeOf<VertexNotNever>().toEqualTypeOf<true>()
expectTypeOf<BedrockNotNever>().toEqualTypeOf<true>()
expectTypeOf<OllamaNotNever>().toEqualTypeOf<true>()
})
it('should map aliases to same settings as main ID', () => {
type OpenRouterByName = AppProviderSettingsMap['openrouter']
type OpenRouterByAlias = AppProviderSettingsMap['tokenflux']
expectTypeOf<OpenRouterByName>().toEqualTypeOf<OpenRouterByAlias>()
// Vertex AI aliases should have the same settings
type VertexByName = AppProviderSettingsMap['google-vertex']
type VertexByAlias = AppProviderSettingsMap['vertexai']
expectTypeOf<VertexByName>().toEqualTypeOf<VertexByAlias>()
})
})
})

View File

@@ -1,76 +0,0 @@
/**
* This type definition file is only for renderer.
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
* (ai-core package is set as browser-enviroment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { StringKeys } from '@cherrystudio/ai-core/provider'
import type { AppProviderSettingsMap, AppRuntimeConfig } from './merged'
/**
* Provider 配置
* 基于 RuntimeConfig用于构建 provider 实例的基础配置
*/
export type ProviderConfig<T extends StringKeys<AppProviderSettingsMap> = StringKeys<AppProviderSettingsMap>> = Omit<
AppRuntimeConfig<T>,
'plugins' | 'provider'
> & {
/**
* API endpoint path extracted from baseURL
* Used for identifying image generation endpoints and other special cases
* @example 'chat/completions', 'images/generations', 'predict'
*/
endpoint?: string
}
export type { AppProviderId, AppProviderSettingsMap, AppRuntimeConfig } from './merged'
export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'
/**
* Model capability flags computed from model properties and assistant settings.
* Used by provider-specific option builders to decide which parameters to include.
*/
export interface ProviderCapabilities {
/**
* Whether reasoning/thinking parameters should be sent to the provider.
*
* True when the model supports reasoning control (thinking token or reasoning effort)
* AND the user has configured `reasoning_effort` (not `undefined`),
* or when the model is a fixed reasoning model (e.g. DeepSeek R1).
*
* Note: This can be `true` even when `reasoning_effort` is `'none'` — in that case,
* providers should explicitly disable thinking (e.g. Ollama sets `think: false`).
*/
enableReasoning: boolean
/**
* Whether provider-native web search should be enabled.
* True when no external search provider is configured AND the model supports built-in web search.
*/
enableWebSearch: boolean
/**
* Whether the model should generate images inline.
* True when the model supports image generation AND the assistant has it enabled.
*/
enableGenerateImage: boolean
/**
* Whether provider-native URL context should be enabled.
* True when the assistant has it enabled, the provider supports it,
* and the model is compatible (currently Gemini or Anthropic).
*/
enableUrlContext: boolean
}
/**
* Result of completions operation
* Simple interface with getText method to retrieve the generated text
*/
export type CompletionsResult = {
getText: () => string
usage?: any
}

View File

@@ -1,101 +0,0 @@
/**
* Application-Level Provider Type Merge Point
*/
import type { RuntimeConfig } from '@cherrystudio/ai-core/core'
import type { ModelConfig } from '@cherrystudio/ai-core/core/models/types'
import type { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime'
import type {
ExtensionConfigToIdResolutionMap,
ExtensionToSettingsMap,
ExtractProviderIds,
ProviderExtensionConfig,
StringKeys,
UnionToIntersection
} from '@cherrystudio/ai-core/provider'
import { coreExtensions } from '@cherrystudio/ai-core/provider'
import { extensions } from '../provider/extensions'
/**
* All provider extensions merged into one array
*/
const allExtensions = [...coreExtensions, ...extensions] as const
type AllExtensionConfigs = (typeof allExtensions)[number]['config']
// ==================== Unified Application Types ====================
/**
* Complete Application Provider ID Type
*/
type KnownAppProviderId = ExtractProviderIds<AllExtensionConfigs>
export type AppProviderId = KnownAppProviderId | (string & {})
/**
* Application Provider Settings Map
* 使用 UnionToIntersection 将所有 extension 的 settings map 合并为单一对象类型
*/
export type AppProviderSettingsMap = UnionToIntersection<ExtensionToSettingsMap<(typeof allExtensions)[number]>>
// ==================== Runtime Utilities ====================
/**
* Check if a provider ID belongs to the registered extensions
*/
export function isRegisteredProviderId(id: string): boolean {
return allExtensions.some((ext) => ext.hasProviderId(id))
}
/**
* Get all registered provider IDs (for debugging/logging)
*/
export function getAllProviderIds(): string[] {
return allExtensions.flatMap((ext) => ext.getProviderIds())
}
type ProviderIdsMap = UnionToIntersection<ExtensionConfigToIdResolutionMap<AllExtensionConfigs>>
/**
* 应用层 Provider IDs 常量
*/
function buildAppProviderIds(): ProviderIdsMap {
const map = {} as ProviderIdsMap
allExtensions.forEach((ext) => {
const config = ext.config as ProviderExtensionConfig<any, any, KnownAppProviderId>
const name = config.name
;(map as Record<string, KnownAppProviderId>)[name] = name
if (config.aliases) {
config.aliases.forEach((alias) => {
;(map as Record<string, KnownAppProviderId>)[alias] = name
})
}
if (config.variants) {
config.variants.forEach((variant) => {
// 变体自反映射:'azure-responses' -> 'azure-responses'
const variantId = `${name}-${variant.suffix}` as KnownAppProviderId
;(map as Record<string, KnownAppProviderId>)[variantId] = variantId
})
}
})
return map
}
export const appProviderIds = buildAppProviderIds()
export type AppModelConfig<T extends StringKeys<AppProviderSettingsMap> = StringKeys<AppProviderSettingsMap>> =
ModelConfig<T, AppProviderSettingsMap>
/**
* 应用层运行时配置 - 支持完整的 App provider IDs 和 settings
*/
export type AppRuntimeConfig<T extends StringKeys<AppProviderSettingsMap> = StringKeys<AppProviderSettingsMap>> =
RuntimeConfig<AppProviderSettingsMap, T>
/**
* 应用层运行时执行器 - 支持完整的 App provider IDs 和 settings
*/
export type AppRuntimeExecutor = RuntimeExecutor<AppProviderSettingsMap>

View File

@@ -1,29 +0,0 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import type { MCPTool } from '@renderer/types'
import type { Assistant, Message } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
/**
* AI SDK 中间件配置项(用于插件构建)
*
* 注意provider 和 model 不在此接口中。
* 它们是 AiProvider 的固有属性(构造时确定),
* 由 AiProvider 内部注入到 buildPlugins避免调用方遗漏。
*/
export interface AiSdkMiddlewareConfig {
streamOutput: boolean
onChunk?: (chunk: Chunk) => void
assistant?: Assistant
enableReasoning: boolean
isPromptToolUse: boolean
isSupportedToolUse: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
enableUrlContext: boolean
mcpTools?: MCPTool[]
uiMessages?: Message[]
webSearchPluginConfig?: WebSearchPluginConfig
urlContextConfig?: Record<string, any>
knowledgeRecognition?: 'off' | 'on'
mcpMode?: string
}

View File

@@ -1,656 +0,0 @@
/**
* extractAiSdkStandardParams Unit Tests
* Tests for extracting AI SDK standard parameters from custom parameters
*/
import { describe, expect, it, vi } from 'vitest'
import { extractAiSdkStandardParams } from '../options'
// Mock logger to prevent errors
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
// Mock settings store
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
// Mock hooks to prevent uuid errors
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(() => ({}))
}))
// Mock uuid to prevent errors
vi.mock('uuid', () => ({
v4: vi.fn(() => 'test-uuid')
}))
// Mock AssistantService to prevent uuid errors
vi.mock('@renderer/services/AssistantService', () => ({
getDefaultAssistant: vi.fn(() => ({
id: 'test-assistant',
name: 'Test Assistant',
settings: {}
})),
getDefaultTopic: vi.fn(() => ({
id: 'test-topic',
assistantId: 'test-assistant',
createdAt: new Date().toISOString()
}))
}))
// Mock provider service
vi.mock('@renderer/services/ProviderService', () => ({
getProviderById: vi.fn(() => ({
id: 'test-provider',
name: 'Test Provider'
}))
}))
// Mock config modules
vi.mock('@renderer/config/models', async (importOriginal) => {
const actual: any = await importOriginal()
return {
...actual,
isOpenAIModel: vi.fn(() => false),
isQwenMTModel: vi.fn(() => false),
isSupportFlexServiceTierModel: vi.fn(() => false),
isSupportVerbosityModel: vi.fn(() => false),
getModelSupportedVerbosity: vi.fn(() => [])
}
})
vi.mock('@renderer/config/translate', () => ({
mapLanguageToQwenMTModel: vi.fn()
}))
vi.mock('@renderer/utils/provider', () => ({
isSupportServiceTierProvider: vi.fn(() => false),
isSupportVerbosityProvider: vi.fn(() => false)
}))
describe('extractAiSdkStandardParams', () => {
describe('Positive cases - Standard parameters extraction', () => {
it('should extract all AI SDK standard parameters', () => {
const customParams = {
maxOutputTokens: 1000,
temperature: 0.7,
topP: 0.9,
topK: 40,
presencePenalty: 0.5,
frequencyPenalty: 0.3,
stopSequences: ['STOP', 'END'],
seed: 42
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 1000,
temperature: 0.7,
topP: 0.9,
topK: 40,
presencePenalty: 0.5,
frequencyPenalty: 0.3,
stopSequences: ['STOP', 'END'],
seed: 42
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract single standard parameter', () => {
const customParams = {
temperature: 0.8
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.8
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract topK parameter', () => {
const customParams = {
topK: 50
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topK: 50
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract frequencyPenalty parameter', () => {
const customParams = {
frequencyPenalty: 0.6
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
frequencyPenalty: 0.6
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract presencePenalty parameter', () => {
const customParams = {
presencePenalty: 0.4
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
presencePenalty: 0.4
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract stopSequences parameter', () => {
const customParams = {
stopSequences: ['HALT', 'TERMINATE']
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
stopSequences: ['HALT', 'TERMINATE']
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract seed parameter', () => {
const customParams = {
seed: 12345
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
seed: 12345
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract maxOutputTokens parameter', () => {
const customParams = {
maxOutputTokens: 2048
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 2048
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract topP parameter', () => {
const customParams = {
topP: 0.95
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topP: 0.95
})
expect(result.providerParams).toStrictEqual({})
})
})
describe('Negative cases - Provider-specific parameters', () => {
it('should place all non-standard parameters in providerParams', () => {
const customParams = {
customParam: 'value',
anotherParam: 123,
thirdParam: true
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
customParam: 'value',
anotherParam: 123,
thirdParam: true
})
})
it('should place single provider-specific parameter in providerParams', () => {
const customParams = {
reasoningEffort: 'high'
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
reasoningEffort: 'high'
})
})
it('should place model-specific parameter in providerParams', () => {
const customParams = {
thinking: { type: 'enabled', budgetTokens: 5000 }
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
thinking: { type: 'enabled', budgetTokens: 5000 }
})
})
it('should place serviceTier in providerParams', () => {
const customParams = {
serviceTier: 'auto'
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
serviceTier: 'auto'
})
})
it('should place textVerbosity in providerParams', () => {
const customParams = {
textVerbosity: 'high'
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
textVerbosity: 'high'
})
})
})
describe('Mixed parameters', () => {
it('should correctly separate mixed standard and provider-specific parameters', () => {
const customParams = {
temperature: 0.7,
topK: 40,
customParam: 'custom_value',
reasoningEffort: 'medium',
frequencyPenalty: 0.5,
seed: 999
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40,
frequencyPenalty: 0.5,
seed: 999
})
expect(result.providerParams).toStrictEqual({
customParam: 'custom_value',
reasoningEffort: 'medium'
})
})
it('should handle complex mixed parameters with nested objects', () => {
const customParams = {
topP: 0.9,
presencePenalty: 0.3,
thinking: { type: 'enabled', budgetTokens: 5000 },
stopSequences: ['STOP'],
serviceTier: 'auto',
maxOutputTokens: 4096
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topP: 0.9,
presencePenalty: 0.3,
stopSequences: ['STOP'],
maxOutputTokens: 4096
})
expect(result.providerParams).toStrictEqual({
thinking: { type: 'enabled', budgetTokens: 5000 },
serviceTier: 'auto'
})
})
it('should handle all standard params with some provider params', () => {
const customParams = {
maxOutputTokens: 2000,
temperature: 0.8,
topP: 0.95,
topK: 50,
presencePenalty: 0.6,
frequencyPenalty: 0.4,
stopSequences: ['END', 'DONE'],
seed: 777,
customApiParam: 'value',
anotherCustomParam: 123
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 2000,
temperature: 0.8,
topP: 0.95,
topK: 50,
presencePenalty: 0.6,
frequencyPenalty: 0.4,
stopSequences: ['END', 'DONE'],
seed: 777
})
expect(result.providerParams).toStrictEqual({
customApiParam: 'value',
anotherCustomParam: 123
})
})
})
describe('Edge cases', () => {
it('should handle empty object', () => {
const customParams = {}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({})
})
it('should handle zero values for numeric parameters', () => {
const customParams = {
temperature: 0,
topK: 0,
seed: 0
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0,
topK: 0,
seed: 0
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle negative values for numeric parameters', () => {
const customParams = {
presencePenalty: -0.5,
frequencyPenalty: -0.3,
seed: -1
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
presencePenalty: -0.5,
frequencyPenalty: -0.3,
seed: -1
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle empty arrays for stopSequences', () => {
const customParams = {
stopSequences: []
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
stopSequences: []
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle null values in mixed parameters', () => {
const customParams = {
temperature: 0.7,
customNull: null,
topK: 40
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
customNull: null
})
})
it('should handle undefined values in mixed parameters', () => {
const customParams = {
temperature: 0.7,
customUndefined: undefined,
topK: 40
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
customUndefined: undefined
})
})
it('should handle boolean values for standard parameters', () => {
const customParams = {
temperature: 0.7,
customBoolean: false,
topK: 40
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
customBoolean: false
})
})
it('should handle very large numeric values', () => {
const customParams = {
maxOutputTokens: 999999,
seed: 2147483647,
topK: 10000
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 999999,
seed: 2147483647,
topK: 10000
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle decimal values with high precision', () => {
const customParams = {
temperature: 0.123456789,
topP: 0.987654321,
presencePenalty: 0.111111111
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.123456789,
topP: 0.987654321,
presencePenalty: 0.111111111
})
expect(result.providerParams).toStrictEqual({})
})
})
describe('Case sensitivity', () => {
it('should NOT extract parameters with incorrect case - uppercase first letter', () => {
const customParams = {
Temperature: 0.7,
TopK: 40,
FrequencyPenalty: 0.5
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
Temperature: 0.7,
TopK: 40,
FrequencyPenalty: 0.5
})
})
it('should NOT extract parameters with incorrect case - all uppercase', () => {
const customParams = {
TEMPERATURE: 0.7,
TOPK: 40,
SEED: 42
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
TEMPERATURE: 0.7,
TOPK: 40,
SEED: 42
})
})
it('should NOT extract parameters with incorrect case - all lowercase', () => {
const customParams = {
maxoutputtokens: 1000,
frequencypenalty: 0.5,
stopsequences: ['STOP']
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
maxoutputtokens: 1000,
frequencypenalty: 0.5,
stopsequences: ['STOP']
})
})
it('should correctly extract exact case match while rejecting incorrect case', () => {
const customParams = {
temperature: 0.7,
Temperature: 0.8,
TEMPERATURE: 0.9,
topK: 40,
TopK: 50
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
Temperature: 0.8,
TEMPERATURE: 0.9,
TopK: 50
})
})
})
describe('Parameter name variations', () => {
it('should NOT extract similar but incorrect parameter names', () => {
const customParams = {
temp: 0.7, // should not match temperature
top_k: 40, // should not match topK
max_tokens: 1000, // should not match maxOutputTokens
freq_penalty: 0.5 // should not match frequencyPenalty
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
temp: 0.7,
top_k: 40,
max_tokens: 1000,
freq_penalty: 0.5
})
})
it('should NOT extract snake_case versions of standard parameters', () => {
const customParams = {
top_k: 40,
top_p: 0.9,
presence_penalty: 0.5,
frequency_penalty: 0.3,
stop_sequences: ['STOP'],
max_output_tokens: 1000
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
top_k: 40,
top_p: 0.9,
presence_penalty: 0.5,
frequency_penalty: 0.3,
stop_sequences: ['STOP'],
max_output_tokens: 1000
})
})
it('should extract exact camelCase parameters only', () => {
const customParams = {
topK: 40, // correct
top_k: 50, // incorrect
topP: 0.9, // correct
top_p: 0.8, // incorrect
frequencyPenalty: 0.5, // correct
frequency_penalty: 0.4 // incorrect
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topK: 40,
topP: 0.9,
frequencyPenalty: 0.5
})
expect(result.providerParams).toStrictEqual({
top_k: 50,
top_p: 0.8,
frequency_penalty: 0.4
})
})
})
})

View File

@@ -1,28 +0,0 @@
/**
* image.ts Unit Tests
* Tests for Gemini image generation utilities
*/
import { describe, expect, it } from 'vitest'
import { buildGeminiGenerateImageParams } from '../image'
describe('image utils', () => {
describe('buildGeminiGenerateImageParams', () => {
it('should return correct response modalities', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toEqual({
responseModalities: ['TEXT', 'IMAGE']
})
})
it('should return an object with responseModalities property', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toHaveProperty('responseModalities')
expect(Array.isArray(result.responseModalities)).toBe(true)
expect(result.responseModalities).toHaveLength(2)
})
})
})

View File

@@ -1,573 +0,0 @@
/**
* mcp.ts Unit Tests
* Tests for MCP tools configuration and conversion utilities
*/
import type { MCPTool } from '@renderer/types'
import type { Tool } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { convertMcpToolsToAiSdkTools, hasMultimodalContent, mcpResultToTextSummary, setupToolsConfig } from '../mcp'
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/utils/mcp-tools', () => ({
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
isToolAutoApproved: vi.fn(() => false),
callMCPTool: vi.fn(async () => ({
content: [{ type: 'text', text: 'Tool executed successfully' }],
isError: false
}))
}))
vi.mock('@renderer/utils/userConfirmation', () => ({
requestToolConfirmation: vi.fn(async () => true),
sendToolApprovalNotification: vi.fn(),
setToolIdToNameMapping: vi.fn(),
confirmSameNameTools: vi.fn()
}))
describe('mcp utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('setupToolsConfig', () => {
it('should return undefined when no MCP tools provided', () => {
const result = setupToolsConfig()
expect(result).toBeUndefined()
})
it('should return undefined when empty MCP tools array provided', () => {
const result = setupToolsConfig([])
expect(result).toBeUndefined()
})
it('should convert MCP tools to AI SDK tools format', () => {
const mcpTools: MCPTool[] = [
{
id: 'test-tool-1',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-tool',
description: 'A test tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string' }
}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
// Tools are now keyed by id (which includes serverId suffix) for uniqueness
expect(Object.keys(result!)).toEqual(['test-tool-1'])
expect(result!['test-tool-1']).toHaveProperty('description')
expect(result!['test-tool-1']).toHaveProperty('inputSchema')
expect(result!['test-tool-1']).toHaveProperty('execute')
})
it('should handle multiple MCP tools', () => {
const mcpTools: MCPTool[] = [
{
id: 'tool1-id',
serverId: 'server1',
serverName: 'server1',
name: 'tool1',
description: 'First tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
},
{
id: 'tool2-id',
serverId: 'server2',
serverName: 'server2',
name: 'tool2',
description: 'Second tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toHaveLength(2)
// Tools are keyed by id for uniqueness
expect(Object.keys(result!)).toEqual(['tool1-id', 'tool2-id'])
})
})
describe('convertMcpToolsToAiSdkTools', () => {
it('should convert single MCP tool to AI SDK tool', () => {
const mcpTools: MCPTool[] = [
{
id: 'get-weather-id',
serverId: 'weather-server',
serverName: 'weather-server',
name: 'get-weather',
description: 'Get weather information',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
location: { type: 'string' }
},
required: ['location']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
// Tools are keyed by id for uniqueness when multiple server instances exist
expect(Object.keys(result)).toEqual(['get-weather-id'])
const tool = result['get-weather-id'] as Tool
expect(tool.description).toBe('Get weather information')
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should handle tool without description', () => {
const mcpTools: MCPTool[] = [
{
id: 'no-desc-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'no-desc-tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['no-desc-tool-id'])
const tool = result['no-desc-tool-id'] as Tool
expect(tool.description).toBe('Tool from test-server')
})
it('should convert empty tools array', () => {
const result = convertMcpToolsToAiSdkTools([])
expect(result).toEqual({})
})
it('should handle complex input schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'complex-tool-id',
serverId: 'server',
serverName: 'server',
name: 'complex-tool',
description: 'Tool with complex schema',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
tags: {
type: 'array',
items: { type: 'string' }
},
metadata: {
type: 'object',
properties: {
key: { type: 'string' }
}
}
},
required: ['name']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['complex-tool-id'])
const tool = result['complex-tool-id'] as Tool
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should preserve tool id with special characters', () => {
const mcpTools: MCPTool[] = [
{
id: 'special-tool-id',
serverId: 'server',
serverName: 'server',
name: 'tool_with-special.chars',
description: 'Special chars tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
// Tools are keyed by id for uniqueness
expect(Object.keys(result)).toEqual(['special-tool-id'])
})
it('should handle multiple tools with different schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'string-tool-id',
serverId: 'server1',
serverName: 'server1',
name: 'string-tool',
description: 'String tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' }
}
}
},
{
id: 'number-tool-id',
serverId: 'server2',
serverName: 'server2',
name: 'number-tool',
description: 'Number tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
count: { type: 'number' }
}
}
},
{
id: 'boolean-tool-id',
serverId: 'server3',
serverName: 'server3',
name: 'boolean-tool',
description: 'Boolean tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
enabled: { type: 'boolean' }
}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
// Tools are keyed by id for uniqueness
expect(Object.keys(result).sort()).toEqual(['boolean-tool-id', 'number-tool-id', 'string-tool-id'])
expect(result['string-tool-id']).toBeDefined()
expect(result['number-tool-id']).toBeDefined()
expect(result['boolean-tool-id']).toBeDefined()
})
})
describe('hasMultimodalContent', () => {
it('should return false for pure text content', () => {
expect(hasMultimodalContent({ content: [{ type: 'text', text: 'hello' }] })).toBe(false)
})
it('should return true for image content', () => {
expect(hasMultimodalContent({ content: [{ type: 'image', data: 'base64...', mimeType: 'image/png' }] })).toBe(
true
)
})
it('should return true for audio content', () => {
expect(hasMultimodalContent({ content: [{ type: 'audio', data: 'base64...', mimeType: 'audio/mp3' }] })).toBe(
true
)
})
it('should return true for resource with blob', () => {
expect(
hasMultimodalContent({
content: [{ type: 'resource', resource: { blob: 'base64...', mimeType: 'image/png', uri: 'file://a.png' } }]
})
).toBe(true)
})
it('should return false for resource without blob', () => {
expect(
hasMultimodalContent({
content: [{ type: 'resource', resource: { text: 'plain text', uri: 'file://a.txt' } }]
})
).toBe(false)
})
it('should return true for mixed content with at least one multimodal item', () => {
expect(
hasMultimodalContent({
content: [
{ type: 'text', text: 'hello' },
{ type: 'image', data: 'base64...', mimeType: 'image/png' }
]
})
).toBe(true)
})
it('should return false for empty content array', () => {
expect(hasMultimodalContent({ content: [] })).toBe(false)
})
it('should return false for null/undefined result', () => {
expect(hasMultimodalContent(null as any)).toBe(false)
expect(hasMultimodalContent(undefined as any)).toBe(false)
})
it('should return false when content is not an array', () => {
expect(hasMultimodalContent({ content: 'not-array' } as any)).toBe(false)
})
})
describe('mcpResultToTextSummary', () => {
it('should extract text from text content', () => {
expect(mcpResultToTextSummary({ content: [{ type: 'text', text: 'hello world' }] })).toBe('hello world')
})
it('should replace image with placeholder', () => {
expect(mcpResultToTextSummary({ content: [{ type: 'image', data: 'base64...', mimeType: 'image/jpeg' }] })).toBe(
'[Image: image/jpeg, delivered to user]'
)
})
it('should use default mimeType for image without mimeType', () => {
expect(mcpResultToTextSummary({ content: [{ type: 'image', data: 'base64...' }] })).toBe(
'[Image: image/png, delivered to user]'
)
})
it('should replace audio with placeholder', () => {
expect(mcpResultToTextSummary({ content: [{ type: 'audio', data: 'base64...', mimeType: 'audio/wav' }] })).toBe(
'[Audio: audio/wav, delivered to user]'
)
})
it('should use default mimeType for audio without mimeType', () => {
expect(mcpResultToTextSummary({ content: [{ type: 'audio', data: 'base64...' }] })).toBe(
'[Audio: audio/mp3, delivered to user]'
)
})
it('should replace resource with blob with placeholder', () => {
expect(
mcpResultToTextSummary({
content: [
{ type: 'resource', resource: { blob: 'base64...', mimeType: 'application/pdf', uri: 'file://doc.pdf' } }
]
})
).toBe('[Resource: application/pdf, uri=file://doc.pdf, delivered to user]')
})
it('should use resource text when no blob', () => {
expect(
mcpResultToTextSummary({
content: [{ type: 'resource', resource: { text: 'resource content', uri: 'file://a.txt' } }]
})
).toBe('resource content')
})
it('should JSON.stringify unknown content types', () => {
const item = { type: 'unknown' as any, foo: 'bar' }
expect(mcpResultToTextSummary({ content: [item] })).toBe(JSON.stringify(item))
})
it('should join multiple content parts with newline', () => {
const result = mcpResultToTextSummary({
content: [
{ type: 'text', text: 'Description' },
{ type: 'image', data: 'base64...', mimeType: 'image/png' }
]
})
expect(result).toBe('Description\n[Image: image/png, delivered to user]')
})
it('should JSON.stringify result when content is missing', () => {
expect(mcpResultToTextSummary(null as any)).toBe('null')
expect(mcpResultToTextSummary({} as any)).toBe('{}')
})
it('should handle empty text gracefully', () => {
expect(mcpResultToTextSummary({ content: [{ type: 'text' }] })).toBe('')
})
})
describe('tool execution', () => {
it('should execute tool with user confirmation', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'test-exec-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-exec-tool',
description: 'Test execution tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['test-exec-tool-id'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
})
it('should handle user cancellation', async () => {
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
const mcpTools: MCPTool[] = [
{
id: 'cancelled-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'cancelled-tool',
description: 'Tool to cancel',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['cancelled-tool-id'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).not.toHaveBeenCalled()
expect(result).toEqual({
content: [
{
type: 'text',
text: 'User declined to execute tool "cancelled-tool".'
}
],
isError: false
})
})
it('should handle tool execution error', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
const mcpTools: MCPTool[] = [
{
id: 'error-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'error-tool',
description: 'Tool that errors',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['error-tool-id'] as Tool
await expect(
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
).rejects.toEqual({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
})
it('should auto-approve when enabled', async () => {
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(isToolAutoApproved).mockReturnValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'auto-approve-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'auto-approve-tool',
description: 'Auto-approved tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['auto-approve-tool-id'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
expect(requestToolConfirmation).not.toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
})
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -1,288 +0,0 @@
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { getReasoningEffort } from '../reasoning'
// Mock logger
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
warn: vi.fn(),
info: vi.fn(),
error: vi.fn()
})
}
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
vi.mock('@renderer/store/assistants', () => {
const mockAssistantsSlice = {
name: 'assistants',
reducer: vi.fn((state = { entities: {}, ids: [] }) => state),
actions: {
updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' }))
}
}
return {
default: mockAssistantsSlice.reducer,
updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' })),
assistantsSlice: mockAssistantsSlice
}
})
// Mock provider service
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: (model: Model) => ({
id: model.provider,
name: 'Poe',
type: 'openai'
}),
getAssistantSettings: (assistant: Assistant) => assistant.settings || {}
}))
describe('Poe Provider Reasoning Support', () => {
const createPoeModel = (id: string): Model => ({
id,
name: id,
provider: SystemProviderIds.poe,
group: 'poe'
})
const createAssistant = (reasoning_effort?: ReasoningEffortOption, maxTokens?: number): Assistant => ({
id: 'test-assistant',
name: 'Test Assistant',
emoji: '🤖',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {
reasoning_effort,
maxTokens
}
})
describe('GPT-5 Series Models', () => {
it('should return reasoning_effort in extra_body for GPT-5 model with low effort', () => {
const model = createPoeModel('gpt-5')
const assistant = createAssistant('low')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
reasoning_effort: 'low'
}
})
})
it('should return reasoning_effort in extra_body for GPT-5 model with medium effort', () => {
const model = createPoeModel('gpt-5')
const assistant = createAssistant('medium')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
reasoning_effort: 'medium'
}
})
})
it('should return reasoning_effort in extra_body for GPT-5 model with high effort', () => {
const model = createPoeModel('gpt-5')
const assistant = createAssistant('high')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
reasoning_effort: 'high'
}
})
})
it('should convert auto to medium for GPT-5 model in extra_body', () => {
const model = createPoeModel('gpt-5')
const assistant = createAssistant('auto')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
reasoning_effort: 'medium'
}
})
})
it('should return reasoning_effort in extra_body for GPT-5.1 model', () => {
const model = createPoeModel('gpt-5.1')
const assistant = createAssistant('medium')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
reasoning_effort: 'medium'
}
})
})
})
describe('Claude Models', () => {
it('should return thinking_budget in extra_body for Claude 3.7 Sonnet', () => {
const model = createPoeModel('claude-3.7-sonnet')
const assistant = createAssistant('medium', 4096)
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('extra_body')
expect(result.extra_body).toHaveProperty('thinking_budget')
expect(typeof result.extra_body?.thinking_budget).toBe('number')
expect(result.extra_body?.thinking_budget).toBeGreaterThan(0)
})
it('should return thinking_budget in extra_body for Claude Sonnet 4', () => {
const model = createPoeModel('claude-sonnet-4')
const assistant = createAssistant('high', 8192)
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('extra_body')
expect(result.extra_body).toHaveProperty('thinking_budget')
expect(typeof result.extra_body?.thinking_budget).toBe('number')
})
it('should calculate thinking_budget based on effort ratio and maxTokens', () => {
const model = createPoeModel('claude-3.7-sonnet')
const assistant = createAssistant('low', 4096)
const result = getReasoningEffort(assistant, model)
expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
})
})
describe('Gemini Models', () => {
it('should return thinking_budget in extra_body for Gemini 2.5 Flash', () => {
const model = createPoeModel('gemini-2.5-flash')
const assistant = createAssistant('medium')
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('extra_body')
expect(result.extra_body).toHaveProperty('thinking_budget')
expect(typeof result.extra_body?.thinking_budget).toBe('number')
})
it('should return thinking_budget in extra_body for Gemini 2.5 Pro', () => {
const model = createPoeModel('gemini-2.5-pro')
const assistant = createAssistant('high')
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('extra_body')
expect(result.extra_body).toHaveProperty('thinking_budget')
})
it('should use -1 for auto effort', () => {
const model = createPoeModel('gemini-2.5-flash')
const assistant = createAssistant('auto')
const result = getReasoningEffort(assistant, model)
expect(result.extra_body?.thinking_budget).toBe(-1)
})
it('should calculate thinking_budget for non-auto effort', () => {
const model = createPoeModel('gemini-2.5-flash')
const assistant = createAssistant('low')
const result = getReasoningEffort(assistant, model)
expect(typeof result.extra_body?.thinking_budget).toBe('number')
})
})
describe('No Reasoning Effort', () => {
it('should return empty object when reasoning_effort is not set', () => {
const model = createPoeModel('gpt-5')
const assistant = createAssistant(undefined)
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
it('should return empty object when reasoning_effort is "none"', () => {
const model = createPoeModel('gpt-5')
const assistant = createAssistant('none')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
})
describe('Non-Reasoning Models', () => {
it('should return empty object for non-reasoning models', () => {
const model = createPoeModel('gpt-4')
const assistant = createAssistant('medium')
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
})
describe('Edge Cases: Models Without Token Limit Configuration', () => {
it('should return empty object for Claude models without token limit configuration', () => {
const model = createPoeModel('claude-unknown-variant')
const assistant = createAssistant('medium', 4096)
const result = getReasoningEffort(assistant, model)
// Should return empty object when token limit is not found
expect(result).toEqual({})
expect(result.extra_body?.thinking_budget).toBeUndefined()
})
it('should return empty object for unmatched Poe reasoning models', () => {
// A hypothetical reasoning model that doesn't match GPT-5, Claude, or Gemini
const model = createPoeModel('some-reasoning-model')
// Make it appear as a reasoning model by giving it a name that won't match known categories
const assistant = createAssistant('medium')
const result = getReasoningEffort(assistant, model)
// Should return empty object for unmatched models
expect(result).toEqual({})
})
it('should fallback to -1 for Gemini models without token limit', () => {
// Use a Gemini model variant that won't match any token limit pattern
// The current regex patterns cover gemini-.*-flash.*$ and gemini-.*-pro.*$
// so we need a model that matches isSupportedThinkingTokenGeminiModel but not THINKING_TOKEN_MAP
const model = createPoeModel('gemini-2.5-flash')
const assistant = createAssistant('auto')
const result = getReasoningEffort(assistant, model)
// For 'auto' effort, should use -1
expect(result.extra_body?.thinking_budget).toBe(-1)
})
it('should enforce minimum 1024 token floor for Claude models', () => {
const model = createPoeModel('claude-3.7-sonnet')
// Use very small maxTokens to test the minimum floor
const assistant = createAssistant('low', 100)
const result = getReasoningEffort(assistant, model)
expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
})
it('should handle undefined maxTokens for Claude models', () => {
const model = createPoeModel('claude-3.7-sonnet')
const assistant = createAssistant('medium', undefined)
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('extra_body')
expect(result.extra_body).toHaveProperty('thinking_budget')
expect(typeof result.extra_body?.thinking_budget).toBe('number')
expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
})
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -1,405 +0,0 @@
/**
* websearch.ts Unit Tests
* Tests for web search parameters generation utilities
*/
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
// Mock dependencies
vi.mock('@renderer/config/models', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
}))
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
}))
describe('websearch utils', () => {
describe('getWebSearchParams', () => {
it('should return enhancement params for hunyuan provider', () => {
const model: Model = {
id: 'hunyuan-model',
name: 'Hunyuan Model',
provider: 'hunyuan'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_enhancement: true,
citation: true,
search_info: true
})
})
it('should return search params for dashscope provider', () => {
const model: Model = {
id: 'qwen-model',
name: 'Qwen Model',
provider: 'dashscope'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_search: true,
search_options: {
forced_search: true
}
})
})
it('should return web_search_options for OpenAI web search models', () => {
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
web_search_options: {}
})
})
it('should return extra_body with web_search for poe provider', () => {
const model: Model = {
id: 'Gemini-3-Flash',
name: 'Gemini 3 Flash',
provider: 'poe'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
extra_body: {
web_search: true
}
})
})
it('should return empty object for other providers', () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
it('should return empty object for custom provider', () => {
const model: Model = {
id: 'custom-model',
name: 'Custom Model',
provider: 'custom-provider'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
})
describe('buildProviderBuiltinWebSearchConfig', () => {
const defaultWebSearchConfig: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
describe('openai provider', () => {
it('should return low search context size for low maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 20,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'low'
}
})
})
it('should return medium search context size for medium maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
it('should return high search context size for high maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 80,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'high'
}
})
})
it('should use medium for deep research models regardless of maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
})
describe('openai-chat provider', () => {
it('should return correct search context size', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
it('should handle deep research models', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
})
describe('anthropic provider', () => {
it('should return anthropic search options with maxUses', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
it('should include blockedDomains when excludeDomains provided', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 30,
excludeDomains: ['example.com', 'test.com']
}
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
expect(result).toEqual({
anthropic: {
maxUses: 30,
blockedDomains: ['example.com', 'test.com']
}
})
})
it('should not include blockedDomains when empty', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
})
describe('xai provider', () => {
it('should return xai-responses search options with enableImageUnderstanding when no excludeDomains', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result).toEqual({
'xai-responses': {
webSearch: { enableImageUnderstanding: true },
xSearch: { enableImageUnderstanding: true }
}
})
})
it('should include excludedDomains when excludeDomains provided', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 40,
excludeDomains: ['site1.com', 'site2.com']
}
const result = buildProviderBuiltinWebSearchConfig('xai', config)
expect(result).toEqual({
'xai-responses': {
webSearch: {
enableImageUnderstanding: true,
excludedDomains: ['site1.com', 'site2.com']
},
xSearch: { enableImageUnderstanding: true }
}
})
})
it('should limit excluded domains to 5', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 40,
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
}
const result = buildProviderBuiltinWebSearchConfig('xai', config)
expect(result?.['xai-responses']?.webSearch?.excludedDomains).toHaveLength(5)
})
})
describe('openrouter provider', () => {
it('should return openrouter plugins config', () => {
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 50
}
]
}
})
})
it('should respect custom maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 75,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 75
}
]
}
})
})
})
describe('unsupported provider', () => {
it('should return empty object for unsupported provider', () => {
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
expect(result).toEqual({})
})
it('should return empty object for google provider', () => {
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
expect(result).toEqual({})
})
})
describe('edge cases', () => {
it('should handle maxResults at boundary values', () => {
// Test boundary at 33 (low/medium)
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
expect(result33?.openai?.searchContextSize).toBe('low')
// Test boundary at 34 (medium)
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
expect(result34?.openai?.searchContextSize).toBe('medium')
// Test boundary at 66 (medium)
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
expect(result66?.openai?.searchContextSize).toBe('medium')
// Test boundary at 67 (high)
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
expect(result67?.openai?.searchContextSize).toBe('high')
})
it('should handle zero maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('low')
})
it('should handle very large maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('high')
})
})
})
})

View File

@@ -1,5 +0,0 @@
export function buildGeminiGenerateImageParams(): Record<string, any> {
return {
responseModalities: ['TEXT', 'IMAGE']
}
}

View File

@@ -1,197 +0,0 @@
import { loggerService } from '@logger'
import store from '@renderer/store'
import type { MCPCallToolResponse, MCPTool, MCPToolResponse } from '@renderer/types'
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
import {
confirmSameNameTools,
requestToolConfirmation,
sendToolApprovalNotification,
setToolIdToNameMapping
} from '@renderer/utils/userConfirmation'
import { type Tool, type ToolSet } from 'ai'
import { jsonSchema, tool } from 'ai'
import type { JSONSchema7 } from 'json-schema'
const logger = loggerService.withContext('MCP-utils')
// Setup tools configuration based on provided parameters
export function setupToolsConfig(
mcpTools?: MCPTool[],
allowedTools?: string[]
): Record<string, Tool<any, any>> | undefined {
let tools: ToolSet = {}
if (!mcpTools?.length) {
return undefined
}
tools = convertMcpToolsToAiSdkTools(mcpTools, allowedTools)
return tools
}
/**
* 检查 MCP 工具调用结果是否包含可能携带大体积 base64 数据的多模态内容。
* 包括 image、audio 以及含 blob 的 resource 类型。
*/
export function hasMultimodalContent(result: MCPCallToolResponse): boolean {
return (
Array.isArray(result?.content) &&
result.content.some(
(item) => item.type === 'image' || item.type === 'audio' || (item.type === 'resource' && !!item.resource?.blob)
)
)
}
/**
* 将 MCP 工具调用结果转换为纯文本摘要,把图片/音频/resource blob 替换为文本占位描述,
* 避免 base64 数据超出消息大小限制(如 kimi 的 4MB 限制)。
*/
export function mcpResultToTextSummary(result: MCPCallToolResponse): string {
if (!result || !result.content || !Array.isArray(result.content)) {
return JSON.stringify(result)
}
const parts: string[] = []
for (const item of result.content) {
switch (item.type) {
case 'text':
parts.push(item.text || '')
break
case 'image':
parts.push(`[Image: ${item.mimeType || 'image/png'}, delivered to user]`)
break
case 'audio':
parts.push(`[Audio: ${item.mimeType || 'audio/mp3'}, delivered to user]`)
break
case 'resource':
if (item.resource?.blob) {
parts.push(
`[Resource: ${item.resource.mimeType || 'application/octet-stream'}, uri=${item.resource.uri || 'unknown'}, delivered to user]`
)
} else {
parts.push(item.resource?.text || JSON.stringify(item))
}
break
default:
parts.push(JSON.stringify(item))
break
}
}
return parts.join('\n')
}
/**
* 将 MCPTool 转换为 AI SDK 工具格式
*/
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[], allowedTools?: string[]): ToolSet {
const tools: ToolSet = {}
for (const mcpTool of mcpTools) {
// Use mcpTool.id (which includes serverId suffix) to ensure uniqueness
// when multiple instances of the same MCP server type are configured
tools[mcpTool.id] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params, { toolCallId }) => {
// 检查是否启用自动批准
const server = getMcpServerByTool(mcpTool)
let isAutoApproveEnabled = isToolAutoApproved(mcpTool, server, allowedTools)
// For hub invoke/exec, resolve the underlying tool and check its server's auto-approve config
if (
!isAutoApproveEnabled &&
mcpTool.serverId === 'hub' &&
(mcpTool.name === 'invoke' || mcpTool.name === 'exec')
) {
const underlyingToolName = (params as Record<string, unknown>)?.name as string | undefined
if (underlyingToolName) {
try {
const resolved = await window.api.mcp.resolveHubTool(underlyingToolName)
if (resolved) {
const underlyingServer = store.getState().mcp.servers.find((s) => s.id === resolved.serverId)
if (underlyingServer) {
isAutoApproveEnabled = !underlyingServer.disabledAutoApproveTools?.includes(resolved.toolName)
}
}
} catch (err) {
logger.warn('Failed to resolve hub tool for auto-approve check', err as Error)
}
}
}
let confirmed = true
if (!isAutoApproveEnabled) {
// Register mapping so confirmSameNameTools can batch-confirm pending tools.
// For hub invoke/exec, use the underlying tool name so tools targeting the
// same underlying server+tool are grouped together.
const mappingName =
mcpTool.serverId === 'hub' && (mcpTool.name === 'invoke' || mcpTool.name === 'exec')
? ((params as Record<string, unknown>)?.name as string) || mcpTool.name
: mcpTool.name
setToolIdToNameMapping(toolCallId, mappingName)
// Send system notification for tool approval
sendToolApprovalNotification(mcpTool.name)
// 请求用户确认
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
confirmed = await requestToolConfirmation(toolCallId)
if (confirmed) {
// Auto-confirm other pending tools with the same name
confirmSameNameTools(mappingName)
}
}
if (!confirmed) {
// 用户拒绝执行工具
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// 用户确认或自动批准,执行工具
logger.debug(`Executing tool: ${mcpTool.name}`)
// 创建适配的 MCPToolResponse 对象
const toolResponse: MCPToolResponse = {
id: toolCallId,
tool: mcpTool,
arguments: params,
status: 'pending',
toolCallId
}
const result = await callMCPTool(toolResponse)
// 返回结果AI SDK 会处理序列化
if (result.isError) {
return Promise.reject(result)
}
// 返回工具执行结果
return result
},
// 将多模态结果 (image/audio/resource blob) 转为文本摘要,避免 base64 超出消息大小限制。
// 图片/音频已通过 IMAGE_COMPLETE chunk 展示给用户。
// TODO: 待 AI SDK 支持 provider 感知后,可按 provider 返回 media 格式。
toModelOutput(rawOutput: unknown) {
// rawOutput 来自上方 execute 的 return result类型始终为 MCPCallToolResponse
// mcpResultToTextSummary 内部已有 null/content 校验,不会因意外输入崩溃
const result = rawOutput as MCPCallToolResponse
return { type: 'text' as const, value: mcpResultToTextSummary(result) }
}
})
}
return tools
}

View File

@@ -1,605 +0,0 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { loggerService } from '@logger'
import {
getModelSupportedVerbosity,
isAnthropicModel,
isGeminiModel,
isGrokModel,
isOpenAIModel,
isQwenMTModel,
isReasoningModel,
isSupportFlexServiceTierModel,
isSupportVerbosityModel
} from '@renderer/config/models'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getProviderById } from '@renderer/services/ProviderService'
import {
type Assistant,
type GroqServiceTier,
GroqServiceTiers,
type GroqSystemProvider,
isGroqServiceTier,
isGroqSystemProvider,
isOpenAIServiceTier,
isTranslateAssistant,
type Model,
type NotGroqProvider,
type OpenAIServiceTier,
OpenAIServiceTiers,
type Provider,
type ServiceTier,
SystemProviderIds
} from '@renderer/types'
import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider'
import type { JSONValue } from 'ai'
import { t } from 'i18next'
import { merge } from 'lodash'
import type { OllamaProviderOptions } from 'ollama-ai-provider-v2'
import { addAnthropicHeaders } from '../prepareParams/header'
import { getAiSdkProviderId } from '../provider/factory'
import type { ProviderCapabilities } from '../types'
import { buildGeminiGenerateImageParams } from './image'
import {
getAnthropicReasoningParams,
getBedrockReasoningParams,
getCustomParameters,
getGeminiReasoningParams,
getOllamaReasoningParams,
getOpenAIReasoningParams,
getReasoningEffort,
getXAIReasoningParams
} from './reasoning'
import { getWebSearchParams } from './websearch'
const logger = loggerService.withContext('aiCore.utils.options')
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
if (
!isOpenAIServiceTier(serviceTier) ||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
if (
!isGroqServiceTier(serviceTier) ||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
return undefined
}
// 处理不同供应商需要 fallback 到默认值的情况
if (isGroqSystemProvider(provider)) {
return toGroqServiceTier(model, serviceTierSetting)
} else {
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
return toOpenAIServiceTier(model, serviceTierSetting)
}
}
function getVerbosity(model: Model): OpenAIVerbosity {
if (!isSupportVerbosityModel(model) || !isSupportVerbosityProvider(getProviderById(model.provider)!)) {
return undefined
}
const openAI = getStoreSetting('openAI')
const userVerbosity = openAI.verbosity
if (userVerbosity) {
const supportedVerbosity = getModelSupportedVerbosity(model)
// Use user's verbosity if supported, otherwise use the first supported option
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
return verbosity
}
return undefined
}
/**
* Extract AI SDK standard parameters from custom parameters
* These parameters should be passed directly to streamText() instead of providerOptions
*/
export function extractAiSdkStandardParams(customParams: Record<string, any>): {
standardParams: Partial<Record<AiSdkParam, any>>
providerParams: Record<string, any>
} {
const standardParams: Partial<Record<AiSdkParam, any>> = {}
const providerParams: Record<string, any> = {}
for (const [key, value] of Object.entries(customParams)) {
if (isAiSdkParam(key)) {
standardParams[key] = value
} else {
providerParams[key] = value
}
}
return { standardParams, providerParams }
}
/**
* 构建 AI SDK 的 providerOptions
* 按 provider 类型分离,保持类型安全
* 返回格式:{
* providerOptions: { 'providerId': providerOptions },
* standardParams: { topK, frequencyPenalty, presencePenalty, stopSequences, seed }
* }
*
* Custom parameters are split into two categories:
* 1. AI SDK standard parameters (topK, frequencyPenalty, etc.) - returned separately to be passed to streamText()
* 2. Provider-specific parameters - merged into providerOptions
*/
export function buildProviderOptions(
assistant: Assistant,
model: Model,
actualProvider: Provider,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): {
providerOptions: Record<string, Record<string, JSONValue>>
standardParams: Partial<Record<AiSdkParam, any>>
} {
const rawProviderId = getAiSdkProviderId(actualProvider)
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities, rawProviderId })
// 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {}
const serviceTier = getServiceTier(model, actualProvider)
const textVerbosity = getVerbosity(model)
// 根据 provider ID 构建特定选项
switch (rawProviderId) {
case 'openai':
case 'openai-chat':
case 'azure':
case 'azure-responses':
case 'huggingface':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
break
case 'anthropic':
case 'azure-anthropic':
case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
case 'google':
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'xai':
case 'xai-responses':
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
break
case 'bedrock':
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
break
case 'cherryin':
providerSpecificOptions = buildCherryInProviderOptions(
assistant,
model,
capabilities,
actualProvider,
serviceTier,
textVerbosity
)
break
case SystemProviderIds.ollama:
providerSpecificOptions = buildOllamaProviderOptions(assistant, model, capabilities)
break
case SystemProviderIds.gateway:
providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity)
break
case 'deepseek':
case 'openrouter':
case 'openai-compatible':
default:
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities)
// Merge serviceTier and textVerbosity
providerSpecificOptions = {
...providerSpecificOptions,
[rawProviderId]: {
...providerSpecificOptions[rawProviderId],
serviceTier,
textVerbosity
}
}
break
}
logger.debug('Built providerSpecificOptions', { providerSpecificOptions })
/**
* Retrieve custom parameters and separate standard parameters from provider-specific parameters.
*/
const customParams = getCustomParameters(assistant)
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
logger.debug('Extracted standardParams and providerParams', { standardParams, providerParams })
/**
* Get the actual AI SDK provider ID(s) from the already-built providerSpecificOptions.
* For proxy providers (cherryin, aihubmix, newapi), this will be the actual SDK provider (e.g., 'google', 'openai', 'anthropic')
* For regular providers, this will be the provider itself
*/
const actualAiSdkProviderIds = Object.keys(providerSpecificOptions)
const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params
// For openai-compatible providers, auto-convert reasoning_effort (snake_case) to reasoningEffort (camelCase).
// The AI SDK's openai-compatible provider overwrites reasoning_effort to undefined,
// but accepts reasoningEffort. See: https://github.com/CherryHQ/cherry-studio/issues/11987
if (primaryAiSdkProviderId === 'openai-compatible' && 'reasoning_effort' in providerParams) {
if (!('reasoningEffort' in providerParams)) {
providerParams.reasoningEffort = providerParams.reasoning_effort
}
delete providerParams.reasoning_effort
}
/**
* Merge custom parameters into providerSpecificOptions.
* Simple logic:
* 1. If key is in actualAiSdkProviderIds → merge directly (user knows the actual AI SDK provider ID)
* 2. If key == rawProviderId:
* - If it's gateway/ollama → preserve (they need their own config for routing/options)
* - Otherwise → map to primary (this is a proxy provider like cherryin)
* 3. Otherwise → treat as regular parameter, merge to primary provider
*
* Example:
* - User writes `cherryin: { opt: 'val' }` → mapped to `google: { opt: 'val' }` (case 2, proxy)
* - User writes `gateway: { order: [...] }` → stays as `gateway: { order: [...] }` (case 2, routing config)
* - User writes `google: { opt: 'val' }` → stays as `google: { opt: 'val' }` (case 1)
* - User writes `customKey: 'val'` → merged to `google: { customKey: 'val' }` (case 3)
*/
for (const key of Object.keys(providerParams)) {
if (actualAiSdkProviderIds.includes(key)) {
// Case 1: Key is an actual AI SDK provider ID - merge directly
providerSpecificOptions = {
...providerSpecificOptions,
[key]: {
...providerSpecificOptions[key],
...providerParams[key]
}
}
} else if (key === rawProviderId && !actualAiSdkProviderIds.includes(rawProviderId)) {
// Case 2: Key is the current provider (not in actualAiSdkProviderIds, so it's a proxy or special provider)
// Gateway is special: it needs routing config preserved
if (key === SystemProviderIds.gateway) {
// Preserve gateway config for routing
providerSpecificOptions = {
...providerSpecificOptions,
[key]: {
...providerSpecificOptions[key],
...providerParams[key]
}
}
} else {
// Proxy provider (cherryin, etc.) - map to actual AI SDK provider
providerSpecificOptions = {
...providerSpecificOptions,
[primaryAiSdkProviderId]: {
...providerSpecificOptions[primaryAiSdkProviderId],
...providerParams[key]
}
}
}
} else {
// Case 3: Regular parameter - merge to primary provider
providerSpecificOptions = {
...providerSpecificOptions,
[primaryAiSdkProviderId]: {
...providerSpecificOptions[primaryAiSdkProviderId],
[key]: providerParams[key]
}
}
}
}
logger.debug('Final providerSpecificOptions after merging providerParams', { providerSpecificOptions })
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
return {
providerOptions: providerSpecificOptions,
standardParams
}
}
/**
* 构建 OpenAI 特定的 providerOptions
*/
function buildOpenAIProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>,
serviceTier: OpenAIServiceTier,
textVerbosity?: OpenAIVerbosity
): Record<string, OpenAIResponsesProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI 推理参数
if (enableReasoning) {
const reasoningParams = getOpenAIReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams,
// TODO: Remove this workaround after migrating to @ai-sdk/open-responses (#13462)
// Bypass @ai-sdk/openai's model ID allowlist for reasoning detection.
// Third-party providers often use non-canonical model IDs (e.g., "openai/gpt-5.2")
// that fail the SDK's startsWith() checks, causing reasoning params to be silently dropped.
...(isReasoningModel(model) && { forceReasoning: true })
}
}
const provider = getProviderById(model.provider)
if (!provider) {
throw new Error(`Provider ${model.provider} not found`)
}
if (isSupportVerbosityModel(model) && isSupportVerbosityProvider(provider)) {
const openAI = getStoreSetting<'openAI'>('openAI')
const userVerbosity = openAI?.verbosity
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
const supportedVerbosity = getModelSupportedVerbosity(model)
// Use user's verbosity if supported, otherwise use the first supported option
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
providerOptions = {
...providerOptions,
textVerbosity: verbosity
}
}
}
// TODO: 支持配置是否在服务端持久化
providerOptions = {
...providerOptions,
serviceTier,
textVerbosity,
store: false
}
return {
openai: providerOptions
}
}
/**
* 构建 Anthropic 特定的 providerOptions
*/
function buildAnthropicProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): Record<string, AnthropicProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: AnthropicProviderOptions = {}
// Anthropic 推理参数
if (enableReasoning) {
const reasoningParams = getAnthropicReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams
}
}
return {
anthropic: {
...providerOptions
}
}
}
/**
* 构建 Gemini 特定的 providerOptions
*/
function buildGeminiProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): Record<string, GoogleGenerativeAIProviderOptions> {
const { enableReasoning, enableGenerateImage } = capabilities
let providerOptions: GoogleGenerativeAIProviderOptions = {}
// Gemini 推理参数
if (enableReasoning) {
const reasoningParams = getGeminiReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams
}
}
if (enableGenerateImage) {
providerOptions = {
...providerOptions,
...buildGeminiGenerateImageParams()
}
}
return {
google: {
...providerOptions
}
}
}
function buildXAIProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): Record<string, XaiProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
if (enableReasoning) {
const reasoningParams = getXAIReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams
}
}
return {
xai: {
...providerOptions
}
}
}
function buildCherryInProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>,
actualProvider: Provider,
serviceTier: OpenAIServiceTier,
textVerbosity: OpenAIVerbosity
): Record<string, OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions> {
switch (actualProvider.type) {
case 'openai':
return buildGenericProviderOptions('cherryin', assistant, model, capabilities)
case 'openai-response':
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
case 'anthropic':
return buildAnthropicProviderOptions(assistant, model, capabilities)
case 'gemini':
return buildGeminiProviderOptions(assistant, model, capabilities)
default:
return buildGenericProviderOptions('cherryin', assistant, model, capabilities)
}
}
/**
* Build Bedrock providerOptions
*/
function buildBedrockProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): Record<string, BedrockProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: BedrockProviderOptions = {}
if (enableReasoning) {
const reasoningParams = getBedrockReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams
}
}
const betaHeaders = addAnthropicHeaders(assistant, model)
if (betaHeaders.length > 0) {
providerOptions.anthropicBeta = betaHeaders
}
return {
bedrock: providerOptions
}
}
function buildOllamaProviderOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): Record<string, OllamaProviderOptions> {
const { enableReasoning } = capabilities
let options = {}
if (enableReasoning) {
options = {
...options,
...getOllamaReasoningParams(assistant, model)
}
}
return { ollama: options }
}
/**
* 构建通用的 providerOptions用于其他 provider
*/
function buildGenericProviderOptions(
providerId: string,
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
): Record<string, any> {
const { enableWebSearch } = capabilities
let providerOptions: Record<string, any> = {}
const reasoningParams = getReasoningEffort(assistant, model)
logger.debug('reasoningParams', reasoningParams)
providerOptions = {
...providerOptions,
...reasoningParams
}
if (enableWebSearch) {
const webSearchParams = getWebSearchParams(model)
providerOptions = merge({}, providerOptions, webSearchParams)
}
// 特殊处理 Qwen MT
if (isQwenMTModel(model)) {
if (isTranslateAssistant(assistant)) {
const targetLanguage = assistant.targetLanguage
const translationOptions = {
source_lang: 'auto',
target_lang: mapLanguageToQwenMTModel(targetLanguage)
} as const
if (!translationOptions.target_lang) {
throw new Error(t('translate.error.not_supported', { language: targetLanguage.value }))
}
providerOptions.translation_options = translationOptions
} else {
throw new Error(t('translate.error.chat_qwen_mt'))
}
}
return {
[providerId]: providerOptions
}
}
function buildAIGatewayOptions(
assistant: Assistant,
model: Model,
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>,
serviceTier: OpenAIServiceTier,
textVerbosity?: OpenAIVerbosity
): Record<
string,
| OpenAIResponsesProviderOptions
| AnthropicProviderOptions
| GoogleGenerativeAIProviderOptions
| Record<string, unknown>
> {
if (isAnthropicModel(model)) {
return buildAnthropicProviderOptions(assistant, model, capabilities)
} else if (isOpenAIModel(model)) {
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
} else if (isGeminiModel(model)) {
return buildGeminiProviderOptions(assistant, model, capabilities)
} else if (isGrokModel(model)) {
return buildXAIProviderOptions(assistant, model, capabilities)
} else {
return buildGenericProviderOptions('openai-compatible', assistant, model, capabilities)
}
}

View File

@@ -1,999 +0,0 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getModelSupportedReasoningEffortOptions,
isClaude46SeriesModel,
isDeepSeekHybridInferenceModel,
isDoubaoSeed18Model,
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGemini3ThinkingTokenModel,
isGrok4FastReasoningModel,
isOpenAIDeepResearchModel,
isOpenAIModel,
isOpenAIOpenWeightModel,
isOpenAIReasoningModel,
isQwen35to39Model,
isQwenAlwaysThinkModel,
isQwenReasoningModel,
isReasoningModel,
isSupportedReasoningEffortGrokModel,
isSupportedReasoningEffortModel,
isSupportedReasoningEffortOpenAIModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenHunyuanModel,
isSupportedThinkingTokenKimiModel,
isSupportedThinkingTokenMiMoModel,
isSupportedThinkingTokenModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenZhipuModel,
isSupportNoneReasoningEffortModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
import type { OpenAIReasoningEffort, OpenAIReasoningSummary } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { toInteger } from 'lodash'
import type { OllamaProviderOptions } from 'ollama-ai-provider-v2'
const logger = loggerService.withContext('reasoning')
type ReasoningEffortOptionalParams = {
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string; enabled?: boolean } | OpenAI.Reasoning
reasoningEffort?: OpenAIReasoningEffort
// WARN: This field will be overwrite to undefined by aisdk if the provider is openai-compatible. Use reasoningEffort instead.
reasoning_effort?: OpenAIReasoningEffort
enable_thinking?: boolean
thinking_budget?: number
incremental_output?: boolean
enable_reasoning?: boolean
// nvidia, etc.
chat_template_kwargs?: {
thinking?: boolean
enable_thinking?: boolean
thinking_budget?: number
}
extra_body?: {
google?: {
thinking_config: {
thinking_budget: number
include_thoughts?: boolean
}
}
thinking?: {
type: 'enabled' | 'disabled'
}
thinking_budget?: number
reasoning_effort?: OpenAIReasoningEffort
}
disable_reasoning?: boolean
// Add any other potential reasoning-related keys here if they exist
}
// The function is only for generic provider. May extract some logics to independent provider
export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
const provider = getProviderByModel(model)
const modelId = getLowerBaseModelName(model.id)
if (provider.id === 'groq') {
return {}
}
if (!isReasoningModel(model)) {
return {}
}
if (isOpenAIDeepResearchModel(model)) {
return {
reasoning_effort: 'medium'
}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
// reasoningEffort is not set, no extra reasoning setting
// Generally, for every model which supports reasoning control, the reasoning effort won't be undefined.
// It's for some reasoning models that don't support reasoning control, such as deepseek reasoner.
if (!reasoningEffort || reasoningEffort === 'default') {
return {}
}
// Handle 'none' reasoningEffort. It's explicitly off.
if (reasoningEffort === 'none') {
// openrouter: use reasoning
if (model.provider === SystemProviderIds.openrouter) {
if (isSupportNoneReasoningEffortModel(model) && reasoningEffort === 'none') {
return { reasoning: { effort: 'none' } }
}
return { reasoning: { enabled: false, exclude: true } }
}
// nvidia: must use chat_template_kwargs
// Since limited documentation, it's hard to find what parameters should be set
// only part of mainstream oss model covered, all verified by nvidia api
if (model.provider === SystemProviderIds.nvidia) {
if (isSupportedThinkingTokenQwenModel(model)) {
return { chat_template_kwargs: { enable_thinking: false } }
} else if (isDeepSeekHybridInferenceModel(model)) {
return { chat_template_kwargs: { thinking: false } }
} else if (isSupportedThinkingTokenKimiModel(model)) {
return { chat_template_kwargs: { thinking: false } }
} else if (isSupportedThinkingTokenZhipuModel(model)) {
return { chat_template_kwargs: { enable_thinking: false } }
}
}
// providers that use enable_thinking
if (
(isSupportEnableThinkingProvider(provider) &&
(isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model))) ||
(provider.id === SystemProviderIds.dashscope &&
(isDeepSeekHybridInferenceModel(model) ||
isSupportedThinkingTokenZhipuModel(model) ||
isSupportedThinkingTokenKimiModel(model)))
) {
return { enable_thinking: false }
}
// together
if (provider.id === SystemProviderIds.together) {
return { reasoning: { enabled: false } }
}
// gemini
if (isSupportedThinkingTokenGeminiModel(model)) {
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return {
extra_body: {
google: {
thinking_config: {
thinking_budget: 0
}
}
}
}
} else {
logger.warn(`Model ${model.id} cannot disable reasoning. Fallback to empty reasoning param.`)
return {}
}
}
// use thinking, doubao, zhipu, etc.
if (
isSupportedThinkingTokenDoubaoModel(model) ||
isSupportedThinkingTokenZhipuModel(model) ||
isSupportedThinkingTokenMiMoModel(model) ||
isSupportedThinkingTokenKimiModel(model)
) {
if (provider.id === SystemProviderIds.cerebras) {
return {
disable_reasoning: true
}
}
return { thinking: { type: 'disabled' } }
}
// Deepseek, default behavior is non-thinking
if (isDeepSeekHybridInferenceModel(model)) {
return {}
}
// GPT 5.1, GPT 5.2, or newer
if (isSupportNoneReasoningEffortModel(model)) {
return {
reasoningEffort: 'none'
}
}
// Qwen 3.5 without direct enable_thinking
// https://huggingface.co/Qwen/Qwen3.5-397B-A17B#instruct-or-non-thinking-mode
if (isQwen35to39Model(model)) {
return {
chat_template_kwargs: {
enable_thinking: false
}
}
}
logger.warn(`Model ${model.id} doesn't match any disable reasoning behavior. Fallback to empty reasoning param.`)
return {}
}
// reasoningEffort有效的情况
// https://creator.poe.com/docs/external-applications/openai-compatible-api#additional-considerations
// Poe provider - supports custom bot parameters via extra_body
if (provider.id === SystemProviderIds.poe) {
if (isOpenAIReasoningModel(model)) {
return {
extra_body: {
reasoning_effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
}
}
}
// Claude models use thinking_budget parameter in extra_body
if (isSupportedThinkingTokenClaudeModel(model)) {
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(model.id)
const maxTokens = assistant.settings?.maxTokens
if (!tokenLimit) {
logger.warn(
`No token limit configuration found for Claude model "${model.id}" on Poe provider. ` +
`Reasoning effort setting "${reasoningEffort}" will not be applied.`
)
return {}
}
let budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
budgetTokens = Math.floor(Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)))
return {
extra_body: {
thinking_budget: budgetTokens
}
}
}
// Gemini models use thinking_budget parameter in extra_body
if (isSupportedThinkingTokenGeminiModel(model)) {
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(model.id)
let budgetTokens: number | undefined
if (tokenLimit && reasoningEffort !== 'auto') {
budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
} else if (!tokenLimit && reasoningEffort !== 'auto') {
logger.warn(
`No token limit configuration found for Gemini model "${model.id}" on Poe provider. ` +
`Using auto (-1) instead of requested effort "${reasoningEffort}".`
)
}
return {
extra_body: {
thinking_budget: budgetTokens ?? -1
}
}
}
// Poe reasoning model not in known categories (GPT-5, Claude, Gemini)
logger.warn(
`Poe provider reasoning model "${model.id}" does not match known categories ` +
`(GPT-5, Claude, Gemini). Reasoning effort setting "${reasoningEffort}" will not be applied.`
)
return {}
}
// OpenRouter models
if (model.provider === SystemProviderIds.openrouter) {
// Grok 4 Fast doesn't support effort levels, always use enabled: true
if (isGrok4FastReasoningModel(model)) {
return {
reasoning: {
enabled: true // Ignore effort level, just enable reasoning
}
}
}
// Other OpenRouter models that support effort levels
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
return {
reasoning: {
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
}
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(modelId)
let budgetTokens: number | undefined
if (tokenLimit) {
budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
}
// nvidia: must use chat_template_kwargs
// Since limited documentation, it's hard to find what parameters should be set
// only part of mainstream oss model covered, all verified by nvidia api
if (model.provider === SystemProviderIds.nvidia) {
if (isSupportedThinkingTokenQwenModel(model)) {
const enableThinkingConfig = isQwenAlwaysThinkModel(model) ? {} : { enable_thinking: true }
return {
chat_template_kwargs: {
...enableThinkingConfig,
thinking_budget: budgetTokens
}
}
} else if (isDeepSeekHybridInferenceModel(model)) {
return { chat_template_kwargs: { thinking: true } }
} else if (isSupportedThinkingTokenKimiModel(model)) {
return { chat_template_kwargs: { thinking: true } }
} else if (isSupportedThinkingTokenZhipuModel(model)) {
return { chat_template_kwargs: { enable_thinking: true } }
}
}
// See https://docs.siliconflow.cn/cn/api-reference/chat-completions/chat-completions
if (model.provider === SystemProviderIds.silicon) {
if (
isDeepSeekHybridInferenceModel(model) ||
isSupportedThinkingTokenZhipuModel(model) ||
isSupportedThinkingTokenQwenModel(model) ||
isSupportedThinkingTokenHunyuanModel(model)
) {
return {
enable_thinking: true,
// Hard-encoded maximum, only for silicon
thinking_budget: budgetTokens ? toInteger(Math.max(budgetTokens, 32768)) : undefined
}
}
return {}
}
// DeepSeek hybrid inference models, v3.1 and maybe more in the future
// 不同的 provider 有不同的思考控制方式,在这里统一解决
if (isDeepSeekHybridInferenceModel(model)) {
if (isSystemProvider(provider)) {
switch (provider.id) {
case SystemProviderIds.dashscope:
return {
enable_thinking: true,
incremental_output: true
}
// TODO: 支持 new-api类型
case SystemProviderIds['new-api']:
case SystemProviderIds.cherryin: {
return {
extra_body: {
thinking: {
type: 'enabled' // auto is invalid
}
}
}
}
case SystemProviderIds.hunyuan:
case SystemProviderIds['tencent-cloud-ti']:
case SystemProviderIds.doubao:
case SystemProviderIds.deepseek:
case SystemProviderIds.aihubmix:
case SystemProviderIds.sophnet:
case SystemProviderIds.ppio:
case SystemProviderIds.dmxapi:
return {
thinking: {
type: 'enabled' // auto is invalid
}
}
case SystemProviderIds.openrouter:
case SystemProviderIds.together:
return {
reasoning: {
enabled: true
}
}
default:
break
}
}
logger.warn(
`Use default thinking options for provider ${provider.name} as DeepSeek v3.1+ thinking control method is unknown`
)
return {
thinking: {
type: 'enabled'
}
}
}
// OpenRouter models, use reasoning
// FIXME: duplicated openrouter handling. remove one
if (model.provider === SystemProviderIds.openrouter) {
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
return {
reasoning: {
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
}
}
}
}
// https://help.aliyun.com/zh/model-studio/deep-thinking
if (provider.id === SystemProviderIds.dashscope) {
// For dashscope: Qwen, DeepSeek, and GLM models use enable_thinking to control thinking
// No effort, only on/off
if (
isQwenReasoningModel(model) ||
isSupportedThinkingTokenZhipuModel(model) ||
isSupportedThinkingTokenKimiModel(model)
) {
return {
enable_thinking: true,
thinking_budget: budgetTokens
}
}
}
// https://docs.together.ai/reference/chat-completions-1#body-reasoning-effort
if (provider.id === SystemProviderIds.together) {
let adjustedReasoningEffort: 'low' | 'medium' | 'high' = 'medium'
switch (reasoningEffort) {
case 'minimal':
adjustedReasoningEffort = 'low'
break
case 'xhigh':
adjustedReasoningEffort = 'high'
break
case 'auto':
adjustedReasoningEffort = 'medium'
break
default:
adjustedReasoningEffort = reasoningEffort
break
}
return {
// Only low, medium, high
reasoningEffort: adjustedReasoningEffort,
reasoning: { enabled: true }
}
}
// Qwen models, use enable_thinking
if (isQwenReasoningModel(model)) {
const supportEnableThinking = isSupportEnableThinkingProvider(provider)
const enableThinkingConfig = isQwenAlwaysThinkModel(model) ? {} : { enable_thinking: true }
if (supportEnableThinking) {
return {
...enableThinkingConfig,
thinking_budget: budgetTokens
}
} else {
return {
chat_template_kwargs: {
...enableThinkingConfig,
thinking_budget: budgetTokens
}
}
}
}
// Hunyuan models, use enable_thinking
if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(provider)) {
return {
enable_thinking: true
}
}
// Grok models/Perplexity models/OpenAI models, use reasoning_effort
if (isSupportedReasoningEffortModel(model)) {
// 检查模型是否支持所选选项
const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default')
if (supportedOptions?.includes(reasoningEffort)) {
return {
reasoningEffort
}
} else {
// 如果不支持fallback到第一个支持的值
return {
reasoningEffort: supportedOptions?.[0]
}
}
}
// gemini series, openai compatible api
if (isSupportedThinkingTokenGeminiModel(model)) {
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
if (isGemini3ThinkingTokenModel(model)) {
return {
reasoningEffort
}
}
if (reasoningEffort === 'auto') {
return {
extra_body: {
google: {
thinking_config: {
thinking_budget: -1,
include_thoughts: true
}
}
}
}
}
return {
extra_body: {
google: {
thinking_config: {
thinking_budget: budgetTokens ?? -1,
include_thoughts: true
}
}
}
}
}
// Claude models, openai compatible api
if (isSupportedThinkingTokenClaudeModel(model)) {
const maxTokens = assistant.settings?.maxTokens
return {
thinking: {
type: 'enabled',
budget_tokens: budgetTokens
? Math.floor(Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)))
: undefined
}
}
}
// Use thinking, doubao, zhipu, etc.
if (isSupportedThinkingTokenDoubaoModel(model)) {
if (isDoubaoSeedAfter251015(model) || isDoubaoSeed18Model(model)) {
return { reasoningEffort }
}
if (reasoningEffort === 'high') {
return { thinking: { type: 'enabled' } }
}
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
return { thinking: { type: 'auto' } }
}
// 其他情况不带 thinking 字段
return {}
}
if (isSupportedThinkingTokenZhipuModel(model)) {
if (provider.id === SystemProviderIds.cerebras) {
return {}
}
return { thinking: { type: 'enabled' } }
}
if (isSupportedThinkingTokenMiMoModel(model) || isSupportedThinkingTokenKimiModel(model)) {
return {
thinking: { type: 'enabled' }
}
}
// Default case: no special thinking settings
return {}
}
/**
* Get OpenAI reasoning parameters
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
* For official OpenAI provider only
*/
export function getOpenAIReasoningParams(
assistant: Assistant,
model: Model
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
if (!isReasoningModel(model)) {
return {}
}
let reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort || reasoningEffort === 'default') {
return {}
}
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
reasoningEffort = 'medium'
}
// 非OpenAI模型但是Provider类型是responses/azure openai的情况
if (!isOpenAIModel(model)) {
return {
reasoningEffort
}
}
const openAI = getStoreSetting('openAI')
const summaryText = openAI.summaryText
let reasoningSummary: OpenAIReasoningSummary = undefined
if (model.id.includes('o1-pro')) {
reasoningSummary = undefined
} else {
reasoningSummary = summaryText
}
// OpenAI 推理参数
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
reasoningEffort,
reasoningSummary
}
}
return {}
}
// Conservative fallback token limit for models not in THINKING_TOKEN_MAP.
const FALLBACK_TOKEN_LIMIT = { min: 1024, max: 16384 }
function computeBudgetTokens(
tokenLimit: { min: number; max: number },
effortRatio: number,
maxTokens?: number
): number {
const budget = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
const capped = maxTokens !== undefined ? Math.min(budget, maxTokens) : budget
return Math.max(1024, capped)
}
export function getThinkingBudget(
maxTokens: number | undefined,
reasoningEffort: string | undefined,
modelId: string
): number | undefined {
if (reasoningEffort === undefined || reasoningEffort === 'none') {
return undefined
}
const tokenLimit = findTokenLimit(modelId)
if (!tokenLimit) {
return undefined
}
return computeBudgetTokens(tokenLimit, EFFORT_RATIO[reasoningEffort], maxTokens)
}
// Compute a fallback budgetTokens using a conservative token limit when
// findTokenLimit() cannot determine the model's actual limit. This ensures
// { type: 'enabled' } always carries a valid budget, which is required by
// the Claude Agent SDK and the Anthropic Messages API.
function getFallbackBudgetTokens(reasoningEffort: string | undefined): number {
const effortRatio = EFFORT_RATIO[reasoningEffort ?? 'high'] ?? EFFORT_RATIO.high
return computeBudgetTokens(FALLBACK_TOKEN_LIMIT, effortRatio)
}
/**
* Get Anthropic reasoning parameters.
* Extracted from AnthropicAPIClient logic.
*
* Returns different parameter shapes depending on the model:
* - **Claude 4.6**: `{ thinking: { type: 'adaptive' }, effort: 'low' | 'medium' | 'high' | 'max' }`
* Uses the new adaptive thinking API with effort-based control.
* - **Other Claude models** (4.0, 4.1, 4.5, etc.): `{ thinking: { type: 'enabled', budgetTokens: number } }`
* Uses the classic thinking API with explicit token budget.
*/
export function getAnthropicReasoningParams(
assistant: Assistant,
model: Model
): Pick<AnthropicProviderOptions, 'thinking' | 'effort'> {
if (!isReasoningModel(model)) {
return {}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort || reasoningEffort === 'default') {
return {}
}
if (reasoningEffort === 'none') {
return {
thinking: {
type: 'disabled'
}
}
}
// Claude reasoning parameters
if (isSupportedThinkingTokenClaudeModel(model)) {
// Claude 4.6 uses adaptive thinking + effort parameters
// Map reasoningEffort to Claude 4.6 supported effort values
if (isClaude46SeriesModel(model)) {
// Claude 4.6 supports: low, medium, high, max
// Mapping rules: default/none -> no effort (uses default high)
// minimal/low -> low
// medium -> medium
// high -> high
// xhigh -> max
const effortMap = {
default: undefined,
auto: undefined,
minimal: 'low',
low: 'low',
medium: 'medium',
high: 'high',
xhigh: 'max'
} as const satisfies Record<Exclude<ReasoningEffortOption, 'none'>, AnthropicProviderOptions['effort']>
const effort = effortMap[reasoningEffort]
return effort ? { thinking: { type: 'adaptive' }, effort } : { thinking: { type: 'adaptive' } }
}
// Other Claude models continue using enabled + budgetTokens
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
thinking: {
type: 'enabled',
budgetTokens: budgetTokens ?? getFallbackBudgetTokens(reasoningEffort)
}
}
} else {
// 其他使用claude端點的模型比如Kimi,Minimax等等
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getThinkingBudget(maxTokens, reasoningEffort, model.id)
// Always include budgetTokens to prevent Claude Agent SDK from converting
// { type: 'enabled' } into '--thinking adaptive', which non-Anthropic
// upstream providers do not support (they only accept 'enabled'/'disabled').
return { thinking: { type: 'enabled', budgetTokens: budgetTokens ?? getFallbackBudgetTokens(reasoningEffort) } }
}
}
type GoogleThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel {
switch (reasoningEffort) {
case 'auto':
case 'default':
return undefined
case 'none':
return 'minimal'
case 'minimal':
return 'minimal'
case 'low':
return 'low'
case 'medium':
return 'medium'
case 'high':
case 'xhigh':
return 'high'
default:
// Enforce all possible values are handled
reasoningEffort satisfies never
return undefined
}
}
/**
* 获取 Gemini 推理参数
* 从 GeminiAPIClient 中提取的逻辑
* 注意Gemini/GCP 端点所使用的 thinkingBudget 等参数应该按照驼峰命名法传递
* 而在 Google 官方提供的 OpenAI 兼容端点中则使用蛇形命名法 thinking_budget
*/
export function getGeminiReasoningParams(
assistant: Assistant,
model: Model
): Pick<GoogleGenerativeAIProviderOptions, 'thinkingConfig'> {
if (!isReasoningModel(model) || !isSupportedThinkingTokenGeminiModel(model)) {
return {}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort || reasoningEffort === 'default') {
return {}
}
let thinkingLevel: GoogleThinkingLevel | null = null
const includeThoughts = reasoningEffort !== 'none'
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
if (isGemini3ThinkingTokenModel(model)) {
thinkingLevel = mapToGeminiThinkingLevel(reasoningEffort)
if (thinkingLevel === 'minimal' && getLowerBaseModelName(model.id).includes('pro')) {
thinkingLevel = 'low'
}
}
if (thinkingLevel !== null) {
// Gemini 3 branch. thinkingLevel can be undefined (auto) or a specific level.
return {
thinkingConfig: {
includeThoughts,
thinkingLevel
}
}
} else {
// Old models
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (reasoningEffort === 'auto') {
return {
thinkingConfig: {
includeThoughts,
thinkingBudget: -1
}
}
}
if (reasoningEffort === 'none') {
return {
thinkingConfig: {
includeThoughts,
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
}
}
}
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
const budget = Math.floor((max - min) * effortRatio + min)
return {
thinkingConfig: {
includeThoughts,
...(budget > 0 ? { thinkingBudget: budget } : {})
}
}
}
}
/**
* Get XAI-specific reasoning parameters
* This function should only be called for XAI provider models
* @param assistant - The assistant configuration
* @param model - The model being used
* @returns XAI-specific reasoning parameters
*/
export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick<XaiProviderOptions, 'reasoningEffort'> {
if (!isSupportedReasoningEffortGrokModel(model)) {
return {}
}
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
switch (reasoningEffort) {
case 'auto':
case 'minimal':
case 'medium':
return { reasoningEffort: 'low' }
case 'low':
case 'high':
return { reasoningEffort }
case 'xhigh':
return { reasoningEffort: 'high' }
case 'default':
case 'none':
default:
return {}
}
}
/**
* Get Bedrock reasoning parameters
*/
export function getBedrockReasoningParams(
assistant: Assistant,
model: Model
): Pick<BedrockProviderOptions, 'reasoningConfig'> {
if (!isReasoningModel(model)) {
return {}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined || reasoningEffort === 'default') {
return {}
}
if (reasoningEffort === 'none') {
return {
reasoningConfig: {
type: 'disabled'
}
}
}
// Only apply thinking budget for Claude reasoning models
if (!isSupportedThinkingTokenClaudeModel(model)) {
return {}
}
// Claude 4.6 uses adaptive thinking + maxReasoningEffort
if (isClaude46SeriesModel(model)) {
const effortMap = {
auto: undefined,
minimal: 'low',
low: 'low',
medium: 'medium',
high: 'high',
xhigh: 'max'
} as const satisfies Record<
Exclude<ReasoningEffortOption, 'none' | 'default'>,
NonNullable<BedrockProviderOptions['reasoningConfig']>['maxReasoningEffort']
>
const maxReasoningEffort = effortMap[reasoningEffort]
return maxReasoningEffort
? { reasoningConfig: { type: 'adaptive', maxReasoningEffort } }
: { reasoningConfig: { type: 'adaptive' } }
}
// Other Claude models use enabled + budgetTokens
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
reasoningConfig: {
type: 'enabled',
budgetTokens: budgetTokens
}
}
}
/**
* Get Ollama reasoning parameters
* Handles the `think` parameter for Ollama models
*
* - GPT-OSS models: accept 'low' | 'medium' | 'high' string values
* - Other models: boolean only (true/false)
*/
export function getOllamaReasoningParams(assistant: Assistant, model: Model): Pick<OllamaProviderOptions, 'think'> {
const reasoningEffort = assistant.settings?.reasoning_effort
if (isOpenAIOpenWeightModel(model)) {
// gpt-oss models accept 'low' | 'medium' | 'high' string values
if (reasoningEffort === 'low' || reasoningEffort === 'medium' || reasoningEffort === 'high') {
return { think: reasoningEffort }
} else if (reasoningEffort === 'none') {
return { think: false }
}
return { think: true }
}
// Other models: boolean only. undefined defaults to true (user enabled reasoning)
return { think: reasoningEffort !== 'none' }
}
/**
* 获取自定义参数
* 从 assistant 设置中提取自定义参数
*/
export function getCustomParameters(assistant: Assistant): Record<string, any> {
return (
assistant?.settings?.customParameters?.reduce((acc, param) => {
if (!param.name?.trim()) {
return acc
}
// Parse JSON type parameters
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
// The UI stores JSON type params as strings (e.g., '{"key":"value"}')
// This function parses them into objects before sending to the API
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
return { ...acc, [param.name]: undefined }
}
try {
return { ...acc, [param.name]: JSON.parse(value) }
} catch {
return { ...acc, [param.name]: value }
}
}
return {
...acc,
[param.name]: param.value
}
}, {}) || {}
)
}
/**
* Get reasoning tag name based on model ID
* Used for extractReasoningMiddleware configuration
*/
export function getReasoningTagName(modelId: string | undefined): string {
const tagName = {
reasoning: 'reasoning',
think: 'think',
thought: 'thought',
seedThink: 'seed:think'
}
if (modelId?.includes('gpt-oss')) return tagName.reasoning
if (modelId?.includes('gemini')) return tagName.thought
if (modelId?.includes('seed-oss-36b')) return tagName.seedThink
return tagName.think
}

View File

@@ -1,125 +0,0 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
import type { AppProviderId } from '@renderer/aiCore/types'
import { isOpenAIDeepResearchModel, isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
export function getWebSearchParams(model: Model): Record<string, any> {
if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
if (model.provider === 'dashscope') {
return {
enable_search: true,
search_options: {
forced_search: true
}
}
}
// https://creator.poe.com/docs/external-applications/openai-compatible-api#using-custom-parameters-with-extra_body
if (model.provider === 'poe') {
return {
extra_body: {
web_search: true
}
}
}
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
return {
web_search_options: {}
}
}
return {}
}
/**
* range in [0, 100]
* @param maxResults
*/
function mapMaxResultToOpenAIContextSize(
maxResults: number
): NonNullable<WebSearchPluginConfig['openai']>['searchContextSize'] {
if (maxResults <= 33) return 'low'
if (maxResults <= 66) return 'medium'
return 'high'
}
export function buildProviderBuiltinWebSearchConfig(
providerId: AppProviderId,
webSearchConfig: CherryWebSearchConfig,
model?: Model
): WebSearchPluginConfig | undefined {
switch (providerId) {
case 'azure-responses':
case 'openai': {
const searchContextSize = isOpenAIDeepResearchModel(model)
? 'medium'
: mapMaxResultToOpenAIContextSize(webSearchConfig.maxResults)
return {
openai: {
searchContextSize
}
}
}
case 'openai-chat': {
const searchContextSize = isOpenAIDeepResearchModel(model)
? 'medium'
: mapMaxResultToOpenAIContextSize(webSearchConfig.maxResults)
return {
'openai-chat': {
searchContextSize
}
}
}
case 'anthropic': {
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
const anthropicSearchOptions: NonNullable<WebSearchPluginConfig['anthropic']> = {
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}
return {
anthropic: anthropicSearchOptions
}
}
case 'xai':
case 'xai-responses': {
const excludeDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
const xaiWebConfig: NonNullable<NonNullable<WebSearchPluginConfig['xai-responses']>['webSearch']> = {
enableImageUnderstanding: true
}
if (excludeDomains.length > 0) {
xaiWebConfig.excludedDomains = excludeDomains.slice(0, 5)
}
return {
'xai-responses': {
webSearch: xaiWebConfig,
xSearch: { enableImageUnderstanding: true }
}
}
}
case 'openrouter': {
return {
openrouter: {
plugins: [
{
id: 'web',
max_results: webSearchConfig.maxResults
}
]
}
}
}
case 'cherryin': {
const _providerId =
{ 'openai-response': 'openai', openai: 'openai-chat' }[model?.endpoint_type ?? ''] ?? model?.endpoint_type
return buildProviderBuiltinWebSearchConfig(_providerId, webSearchConfig, model)
}
default: {
return {}
}
}
}

View File

@@ -179,6 +179,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
try {
if (mode === 'aihubmix_image_generate') {
// TODO(renderer/aiCore-cleanup): the remaining Gemini/Ideogram/custom fetch branches should move behind main AI/image IPC so this page no longer owns provider-specific transport logic.
if (painting.model.startsWith('imagen-')) {
const requestId = uuid()
activeImageRequestIdRef.current = requestId

View File

@@ -1,3 +1,4 @@
/** TODO(renderer/aiCore-cleanup): replace these temporary mirrored tool response types with shared/main-owned contracts once knowledge/web/memory tools stop depending on legacy aiCore definitions. */
export interface KnowledgeSearchToolInput {
additionalContext?: string
}

View File

@@ -11,6 +11,7 @@ import {
isVertexProvider
} from '@renderer/utils/provider'
/** TODO(renderer/aiCore-cleanup): converge this host normalization helper with the remaining main/provider config formatter so we can delete the old aiCore provider config copy safely. */
interface HostFormatter {
match: (provider: Provider) => boolean
format: (provider: Provider, appendApiVersion: boolean) => string

View File

@@ -1,6 +1,7 @@
import { findTokenLimit } from '@renderer/config/models'
import { EFFORT_RATIO } from '@renderer/types'
/** TODO(renderer/aiCore-cleanup): only `getThinkingBudget` is extracted here. Migrate or delete the remaining reasoning helpers from the old renderer aiCore module after code/CLI flows are fully decoupled. */
const FALLBACK_TOKEN_LIMIT = { min: 1024, max: 16384 }
function computeBudgetTokens(