mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-07-04 21:15:35 +08:00
refactor(renderer): remove legacy aiCore layer
Signed-off-by: suyao <sy20010504@gmail.com>
This commit is contained in:
@@ -82,6 +82,7 @@ export interface AiImageRequest extends AiBaseRequest {
|
||||
numInferenceSteps?: number
|
||||
guidanceScale?: number
|
||||
promptEnhancement?: boolean
|
||||
/** TODO(renderer/aiCore-cleanup): wire personGeneration through to the underlying image runtime once the main image contract formally supports it end-to-end. */
|
||||
personGeneration?: string
|
||||
}
|
||||
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
# Cherry Studio AI Provider 技术架构文档 (新方案)
|
||||
|
||||
## 1. 核心设计理念与目标
|
||||
|
||||
本架构旨在重构 Cherry Studio 的 AI Provider(现称为 `aiCore`)层,以实现以下目标:
|
||||
|
||||
- **职责清晰**:明确划分各组件的职责,降低耦合度。
|
||||
- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。
|
||||
- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。
|
||||
- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。
|
||||
- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。
|
||||
|
||||
核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。
|
||||
|
||||
## 2. 核心组件详解
|
||||
|
||||
### 2.1. `aiCore` (原 `AiProvider` 文件夹)
|
||||
|
||||
这是整个 AI 功能的核心模块。
|
||||
|
||||
#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`)
|
||||
|
||||
- **职责**:作为特定 AI Provider SDK 的纯粹适配层。
|
||||
- **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。
|
||||
- **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。
|
||||
- 例如:SDK 的 `delta.content` -> `TextDeltaChunk`;SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`;SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。
|
||||
- **关键**:`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `<think>` 或 `<tool_use>` 标签。
|
||||
- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。
|
||||
|
||||
#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口)
|
||||
|
||||
- 定义了所有 `XxxApiClient` 必须实现的接口,如:
|
||||
- `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
|
||||
- `getRequestTransformer(): RequestTransformer<TSdkParams>`
|
||||
- `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
|
||||
- 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。
|
||||
|
||||
#### 2.1.3. `ApiClientFactory.ts`
|
||||
|
||||
- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。
|
||||
|
||||
#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
|
||||
|
||||
- **职责**:作为所有 AI 相关业务功能的统一入口。
|
||||
- 提供面向应用的高层接口,例如:
|
||||
- `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
|
||||
- `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
|
||||
- `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
|
||||
- 未来可能的 `generateImage(prompt: string): Promise<ImageResult>` 等。
|
||||
- **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`。
|
||||
- **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据,实现流式UI更新。
|
||||
- **封装特定任务的提示工程 (Prompt Engineering)**:
|
||||
- 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`。
|
||||
- **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。
|
||||
- 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。
|
||||
- **将 `Promise` 的 `resolve` 和 `reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`。
|
||||
- **优势**:
|
||||
- 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。
|
||||
- **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。
|
||||
- **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。
|
||||
|
||||
#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`)
|
||||
|
||||
- 定义核心的、Provider 无关的内部请求结构,例如:
|
||||
- `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。
|
||||
- `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。
|
||||
|
||||
### 2.2. `middleware`
|
||||
|
||||
中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。
|
||||
|
||||
**核心组件包括:**
|
||||
|
||||
- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类,用于动态构建中间件链。它支持从基础链开始,根据条件添加、插入、替换或移除中间件。
|
||||
- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。
|
||||
- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。
|
||||
- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。
|
||||
|
||||
#### 2.2.1. `middlewareTypes.ts`
|
||||
|
||||
- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance` 和 `_coreRequest`)、`MiddlewareAPI`、`CompletionsMiddleware` 等。
|
||||
|
||||
#### 2.2.2. 核心中间件 (`middleware/core/`)
|
||||
|
||||
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
|
||||
- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
|
||||
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream<RawSdkChunk>`。
|
||||
- **`RawSdkChunk`**:指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`,Gemini 的 `GenerateContentResponse` 中的部分等)。
|
||||
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream<RawSdkChunk>`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream<Chunk>`。
|
||||
|
||||
#### 2.2.3. 特性中间件 (`middleware/feat/`)
|
||||
|
||||
这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。
|
||||
|
||||
- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<think>...</think>` 文本内嵌标签,生成 `ThinkingDeltaChunk` 和 `ThinkingCompleteChunk`。
|
||||
- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<tool_use>...</tool_use>` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。
|
||||
|
||||
#### 2.2.4. 核心处理中间件 (`middleware/core/`)
|
||||
|
||||
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
|
||||
- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
|
||||
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。
|
||||
- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。
|
||||
- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。
|
||||
- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。
|
||||
- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。
|
||||
- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。
|
||||
|
||||
#### 2.2.5. 通用中间件 (`middleware/common/`)
|
||||
|
||||
- **`LoggingMiddleware.ts`**: 请求和响应日志。
|
||||
- **`AbortHandlerMiddleware.ts`**: 处理请求中止。
|
||||
- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。
|
||||
- **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。
|
||||
- **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。
|
||||
- 在流结束时,发送包含最终累加信息的完成信号。
|
||||
|
||||
### 2.3. `types/chunk.ts`
|
||||
|
||||
- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。
|
||||
|
||||
## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例)
|
||||
|
||||
```markdown
|
||||
**应用层 (例如 UI 组件)**
|
||||
||
|
||||
\\/
|
||||
**`AiProvider.completions` (`aiCore/index.ts`)**
|
||||
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
|
||||
||
|
||||
\\/
|
||||
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
|
||||
(接收构建好的链、ApiClient实例、原始SDK方法,开始按序执行中间件)
|
||||
||
|
||||
\\/
|
||||
**[ 预处理阶段中间件 ]**
|
||||
(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|
||||
|| (Context 中准备好 SDK 请求参数)
|
||||
\\/
|
||||
**[ 处理阶段中间件 ]**
|
||||
(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|
||||
|| (处理各种特性和Chunk类型)
|
||||
\\/
|
||||
**[ SDK调用阶段中间件 ]**
|
||||
(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|
||||
|| (输出: 标准化的应用层Chunk流)
|
||||
\\/
|
||||
**`FinalChunkConsumerMiddleware` (核心)**
|
||||
(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理)
|
||||
||
|
||||
\\/
|
||||
**`AiProvider.completions` 返回 `Promise<CompletionsResult>`**
|
||||
```
|
||||
|
||||
## 4. 建议的文件/目录结构
|
||||
|
||||
```
|
||||
src/renderer/src/
|
||||
└── aiCore/
|
||||
├── clients/
|
||||
│ ├── openai/
|
||||
│ ├── gemini/
|
||||
│ ├── anthropic/
|
||||
│ ├── BaseApiClient.ts
|
||||
│ ├── ApiClientFactory.ts
|
||||
│ ├── AihubmixAPIClient.ts
|
||||
│ ├── index.ts
|
||||
│ └── types.ts
|
||||
├── middleware/
|
||||
│ ├── common/
|
||||
│ ├── core/
|
||||
│ ├── feat/
|
||||
│ ├── builder.ts
|
||||
│ ├── composer.ts
|
||||
│ ├── index.ts
|
||||
│ ├── register.ts
|
||||
│ ├── schemas.ts
|
||||
│ ├── types.ts
|
||||
│ └── utils.ts
|
||||
├── types/
|
||||
│ ├── chunk.ts
|
||||
│ └── ...
|
||||
└── index.ts
|
||||
```
|
||||
|
||||
## 5. 迁移和实施建议
|
||||
|
||||
- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。
|
||||
- **优先定义核心类型**:`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。
|
||||
- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。
|
||||
- **强化中间件**:让中间件承担起更多解析和特性处理的责任。
|
||||
- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。
|
||||
|
||||
此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。
|
||||
|
||||
## 6. 迁移策略与实施建议
|
||||
|
||||
本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。
|
||||
|
||||
**目标架构核心组件回顾:**
|
||||
|
||||
与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。
|
||||
|
||||
**迁移步骤:**
|
||||
|
||||
**Phase 0: 准备工作和类型定义**
|
||||
|
||||
1. **定义核心数据结构 (TypeScript 类型):**
|
||||
- `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。
|
||||
- `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`):定义所有可能的通用Chunk类型。
|
||||
- 为其他API(翻译、总结)定义类似的 `CoreXxxRequest` (Type)。
|
||||
2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。
|
||||
3. **调整 `AiProviderMiddlewareContext`:**
|
||||
- 确保包含 `_apiClientInstance: ApiClient<any,any,any>`。
|
||||
- 确保包含 `_coreRequest: CoreRequestType`。
|
||||
- 考虑添加 `resolvePromise: (value: AggregatedResultType) => void` 和 `rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。
|
||||
|
||||
**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)**
|
||||
|
||||
1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。
|
||||
2. **迁移SDK实例和配置。**
|
||||
3. **实现 `getRequestTransformer()`:** 将 `CoreCompletionsRequest` 转换为 OpenAI SDK 参数。
|
||||
4. **实现 `getResponseChunkTransformer()`:** 将 `OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 `
|
||||
@@ -1,515 +0,0 @@
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import type { generateImageResult } from '@cherrystudio/ai-core/core/runtime/types'
|
||||
import { cacheService } from '@data/CacheService'
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import { loggerService } from '@logger'
|
||||
import { normalizeGatewayModels } from '@renderer/services/models/ModelAdapter'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import {
|
||||
type Assistant,
|
||||
type EditImageParams,
|
||||
type GenerateImageParams,
|
||||
type Model,
|
||||
type Provider,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
|
||||
import { gateway } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
|
||||
import { listModels } from './services/listModels'
|
||||
import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
|
||||
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
|
||||
export type AiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
|
||||
export default class AiProvider {
|
||||
private config?: ProviderConfig
|
||||
private actualProvider: Provider
|
||||
private model?: Model
|
||||
|
||||
/**
|
||||
* Constructor for AiProvider
|
||||
*
|
||||
* @param modelOrProvider - Model or Provider object
|
||||
* @param provider - Optional Provider object (only used when first param is Model)
|
||||
*
|
||||
* @remarks
|
||||
* **Important behavior notes**:
|
||||
*
|
||||
* 1. When called with `(model)`:
|
||||
* - Calls `getActualProvider(model)` to retrieve and format the provider
|
||||
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
|
||||
*
|
||||
* 2. When called with `(model, provider)`:
|
||||
* - The provided provider will be adapted via `adaptProvider`
|
||||
* - URL formatting behavior depends on the adapted result
|
||||
*
|
||||
* 3. When called with `(provider)`:
|
||||
* - The provider will be adapted via `adaptProvider`
|
||||
* - Used for operations that don't need a model (e.g., fetchModels)
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // Recommended: Auto-format URL
|
||||
* const ai = new AiProvider(model)
|
||||
*
|
||||
* // Provider will be adapted
|
||||
* const ai = new AiProvider(model, customProvider)
|
||||
*
|
||||
* // For operations that don't need a model
|
||||
* const ai = new AiProvider(provider)
|
||||
* ```
|
||||
*/
|
||||
constructor(model: Model, provider?: Provider)
|
||||
constructor(provider: Provider)
|
||||
constructor(modelOrProvider: Model | Provider, provider?: Provider)
|
||||
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
|
||||
if (this.isModel(modelOrProvider)) {
|
||||
// 传入的是 Model
|
||||
this.model = modelOrProvider
|
||||
this.actualProvider = provider
|
||||
? adaptProvider({ provider, model: modelOrProvider })
|
||||
: getActualProvider(modelOrProvider)
|
||||
// 注意:config 可能是同步值或 Promise,在 completions() 中会统一处理
|
||||
const configOrPromise = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
|
||||
this.config = configOrPromise instanceof Promise ? undefined : configOrPromise
|
||||
} else {
|
||||
// 传入的是 Provider
|
||||
this.actualProvider = adaptProvider({ provider: modelOrProvider })
|
||||
// model为可选,某些操作(如fetchModels)不需要model
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
|
||||
*/
|
||||
private isModel(obj: Model | Provider): obj is Model {
|
||||
return 'provider' in obj && typeof obj.provider === 'string'
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
return this.actualProvider
|
||||
}
|
||||
|
||||
public async completions(modelId: string, params: StreamTextParams, middlewareConfig: AiProviderConfig) {
|
||||
// 检查model是否存在
|
||||
if (!this.model) {
|
||||
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||
}
|
||||
|
||||
// Config is now set in constructor, ApiService handles key rotation before passing provider
|
||||
if (!this.config) {
|
||||
// If config wasn't set in constructor (when provider only), generate it now
|
||||
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, this.model))
|
||||
}
|
||||
logger.debug('Using provider config for completions', this.config)
|
||||
|
||||
// 注意:模型对象将由 createExecutor 内部处理,不再需要预先创建
|
||||
|
||||
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
|
||||
// 类型守卫:确保 system 是 string、Array 或 undefined
|
||||
const system = params.system
|
||||
let systemParam: string | Array<any> | undefined
|
||||
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
|
||||
systemParam = system
|
||||
} else {
|
||||
// SystemModelMessage 类型,转换为 string
|
||||
systemParam = undefined
|
||||
}
|
||||
|
||||
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
|
||||
params.system = undefined // 清除原有system,避免重复
|
||||
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
|
||||
}
|
||||
|
||||
if (middlewareConfig.topicId && (await preferenceService.get('app.developer_mode.enabled'))) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
...middlewareConfig,
|
||||
topicId: middlewareConfig.topicId
|
||||
}
|
||||
return await this._completionsForTrace(modelId, params, traceConfig, this.config)
|
||||
} else {
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig, this.config)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 带trace支持的completions方法
|
||||
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||
*/
|
||||
private async _completionsForTrace(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiProviderConfig & { topicId: string },
|
||||
providerConfig: ProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const traceName = `${this.actualProvider.name}.${modelId}.${middlewareConfig.callType}`
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelName: middlewareConfig.assistant.model?.name, // 使用modelId而不是provider名称
|
||||
inputs: params
|
||||
}
|
||||
|
||||
logger.info('Starting AI SDK trace span', {
|
||||
traceName,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolNames: params.tools ? Object.keys(params.tools) : []
|
||||
})
|
||||
|
||||
const span = await addSpan(traceParams)
|
||||
if (!span) {
|
||||
logger.warn('Failed to create span, falling back to regular completions', {
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
traceName
|
||||
})
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info('Created parent span, now calling completions', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
parentSpanCreated: true
|
||||
})
|
||||
|
||||
const result = await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
|
||||
|
||||
logger.info('Completions finished, ending parent span', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
resultLength: result.getText().length
|
||||
})
|
||||
|
||||
// 标记span完成
|
||||
endSpan({
|
||||
topicId: middlewareConfig.topicId,
|
||||
outputs: result,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId
|
||||
})
|
||||
|
||||
// 标记span出错
|
||||
endSpan({
|
||||
topicId: middlewareConfig.topicId,
|
||||
error: error as Error,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
*/
|
||||
/**
|
||||
* Note: This implementation always uses `executor.streamText` and never
|
||||
* calls `generateText`, even when `onChunk` is not provided.
|
||||
*/
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiProviderConfig,
|
||||
providerConfig: ProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const plugins = await buildPlugins({
|
||||
provider: this.actualProvider,
|
||||
model: this.model!,
|
||||
config: middlewareConfig
|
||||
})
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
providerConfig.providerId,
|
||||
providerConfig.providerSettings,
|
||||
plugins
|
||||
)
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (middlewareConfig.onChunk) {
|
||||
const accumulate = this.model!.supported_text_delta !== false // true and undefined
|
||||
const adapter = new AiSdkToChunkAdapter(
|
||||
middlewareConfig.onChunk,
|
||||
middlewareConfig.mcpTools,
|
||||
accumulate,
|
||||
middlewareConfig.enableWebSearch,
|
||||
undefined,
|
||||
undefined,
|
||||
providerConfig.providerId
|
||||
)
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model: modelId,
|
||||
experimental_context: { onChunk: middlewareConfig.onChunk }
|
||||
})
|
||||
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// Since no onChunk is provided, the external consumer would not handle error chunk.
|
||||
// So we need to capture the actual stream error so we can throw it instead of the
|
||||
// generic NoTextGeneratedError ("No output generated. Check the stream
|
||||
// for errors.") that AI SDK raises when streamResult.text is accessed
|
||||
// after a failed stream.
|
||||
let streamError: unknown = undefined
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model: modelId,
|
||||
onError({ error }) {
|
||||
streamError = error
|
||||
}
|
||||
})
|
||||
|
||||
// 强制消费流,不然await streamResult.text会阻塞
|
||||
await streamResult?.consumeStream({
|
||||
onError(error) {
|
||||
if (!streamError) {
|
||||
streamError = error
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
try {
|
||||
const finalText = await streamResult.text
|
||||
const usage = await streamResult.totalUsage
|
||||
|
||||
return {
|
||||
getText: () => finalText,
|
||||
usage
|
||||
}
|
||||
} catch (error) {
|
||||
// If we captured the real stream error, throw that instead of the
|
||||
// generic NoTextGeneratedError so callers get actionable diagnostics.
|
||||
if (streamError) {
|
||||
throw streamError
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型列表
|
||||
* 使用 ModelListService 统一处理各 Provider 的模型列表获取
|
||||
*/
|
||||
public async models(): Promise<Model[]> {
|
||||
// Gateway provider 使用 AI SDK 的 gateway API
|
||||
if (this.actualProvider.id === SystemProviderIds.gateway) {
|
||||
const gatewayModels = (await gateway.getAvailableModels()).models
|
||||
return normalizeGatewayModels(this.actualProvider, gatewayModels)
|
||||
}
|
||||
|
||||
// 使用新的 ModelListService
|
||||
return await listModels(this.actualProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取嵌入模型的维度
|
||||
* 使用 AI SDK embedMany 测试获取维度
|
||||
*/
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
// 确保 config 已定义
|
||||
if (!this.config) {
|
||||
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
|
||||
}
|
||||
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
this.config.providerId,
|
||||
this.config.providerSettings,
|
||||
[]
|
||||
)
|
||||
|
||||
// 使用 AI SDK embedMany 测试获取维度
|
||||
const result = await executor.embedMany({
|
||||
model: model.id,
|
||||
values: ['test']
|
||||
})
|
||||
|
||||
return result.embeddings[0].length
|
||||
}
|
||||
|
||||
/**
|
||||
* 懒加载初始化 config
|
||||
* 当 constructor 只传入 provider 时,config 不会被初始化
|
||||
* 此方法根据 modelId 从 provider 的 models 中查找真实 Model 并生成 config
|
||||
*/
|
||||
private async ensureConfig(modelId: string): Promise<void> {
|
||||
if (this.config) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从 provider 的 models 中查找真实的 model
|
||||
const model = this.actualProvider.models.find((m) => getLowerBaseModelName(m.id) === getLowerBaseModelName(modelId))
|
||||
if (!model) {
|
||||
throw new Error(`Model "${modelId}" not found in provider "${this.actualProvider.id}"`)
|
||||
}
|
||||
|
||||
this.actualProvider = adaptProvider({ provider: this.actualProvider, model })
|
||||
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像
|
||||
* 使用现代化 AI SDK 实现,不再 fallback 到 legacy
|
||||
*/
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
await this.ensureConfig(params.model)
|
||||
return await this.modernGenerateImage(params, this.config!)
|
||||
}
|
||||
|
||||
/**
|
||||
* 编辑图像 - 基于输入图像和文本提示生成新图像
|
||||
* 内部使用 AI SDK 的 generateImage,通过 prompt.images 参数实现编辑功能
|
||||
*/
|
||||
public async editImage(params: EditImageParams): Promise<string[]> {
|
||||
await this.ensureConfig(params.model)
|
||||
return await this.modernEditImage(params, this.config!)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams, providerConfig: ProviderConfig): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
providerConfig.providerId,
|
||||
providerConfig.providerSettings,
|
||||
[]
|
||||
)
|
||||
const result = await executor.generateImage({
|
||||
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
|
||||
...aiSdkParams
|
||||
})
|
||||
|
||||
return this.convertImageResult(result)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像编辑实现
|
||||
* 通过 AI SDK 的 generateImage 并传入 prompt.images 参数实现编辑功能
|
||||
*/
|
||||
private async modernEditImage(params: EditImageParams, providerConfig: ProviderConfig): Promise<string[]> {
|
||||
const { model, prompt, inputImages, mask, imageSize, signal } = params
|
||||
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
providerConfig.providerId,
|
||||
providerConfig.providerSettings,
|
||||
[]
|
||||
)
|
||||
|
||||
// 使用 AI SDK 的 generateImage,通过 prompt.images 实现编辑
|
||||
const result = await executor.generateImage({
|
||||
model: model,
|
||||
prompt: {
|
||||
text: prompt,
|
||||
images: inputImages, // 输入图像(必需)
|
||||
...(mask && { mask }) // 可选的 mask(用于 inpainting)
|
||||
},
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
...(signal && { abortSignal: signal })
|
||||
})
|
||||
|
||||
return this.convertImageResult(result)
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换图像生成结果格式
|
||||
*/
|
||||
private convertImageResult(result: generateImageResult): string[] {
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if (image.base64) {
|
||||
images.push(`data:${image.mediaType || 'image/png'};base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
return images
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.actualProvider.apiHost || ''
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
const apiKey = this.actualProvider.apiKey
|
||||
if (!apiKey || apiKey.trim() === '') {
|
||||
return ''
|
||||
}
|
||||
|
||||
const keys = apiKey
|
||||
.split(',')
|
||||
.map((key) => key.trim())
|
||||
.filter(Boolean)
|
||||
|
||||
if (keys.length === 0) {
|
||||
return ''
|
||||
}
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
// Multi-key rotation
|
||||
const keyName = `provider:${this.actualProvider.id}:last_used_key`
|
||||
const lastUsedKey = cacheService.getCasual<string>(keyName)
|
||||
|
||||
if (!lastUsedKey) {
|
||||
cacheService.setCasual(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
cacheService.setCasual(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
}
|
||||
@@ -1,475 +0,0 @@
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { AISDKWebSearchResult, MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { WEB_SEARCH_SOURCE } from '@renderer/types'
|
||||
import type { Chunk, ProviderMetadata } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { ProviderSpecificError } from '@renderer/types/provider-specific-error'
|
||||
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
|
||||
import type { IdleTimeoutHandle } from '@renderer/utils/IdleTimeoutController'
|
||||
import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter'
|
||||
import type { ClaudeCodeRawValue } from '@shared/agents/claudecode/types'
|
||||
import { AISDKError, type TextStreamPart, type ToolSet } from 'ai'
|
||||
|
||||
import { ToolCallChunkHandler } from './handleToolCallChunk'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkToChunkAdapter')
|
||||
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器类
|
||||
* 处理 fullStream 到 Cherry Studio chunk 的转换
|
||||
*/
|
||||
export class AiSdkToChunkAdapter {
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
private accumulate: boolean | undefined
|
||||
private isFirstChunk = true
|
||||
private enableWebSearch: boolean = false
|
||||
private onSessionUpdate?: (sessionId: string) => void
|
||||
private responseStartTimestamp: number | null = null
|
||||
private firstTokenTimestamp: number | null = null
|
||||
private hasTextContent = false
|
||||
private getSessionWasCleared?: () => boolean
|
||||
private providerId?: string
|
||||
private idleTimeout?: IdleTimeoutHandle
|
||||
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
mcpTools: MCPTool[] = [],
|
||||
accumulate?: boolean,
|
||||
enableWebSearch?: boolean,
|
||||
onSessionUpdate?: (sessionId: string) => void,
|
||||
getSessionWasCleared?: () => boolean,
|
||||
providerId?: string,
|
||||
idleTimeout?: IdleTimeoutHandle
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
this.accumulate = accumulate
|
||||
this.enableWebSearch = enableWebSearch || false
|
||||
this.onSessionUpdate = onSessionUpdate
|
||||
this.getSessionWasCleared = getSessionWasCleared
|
||||
this.providerId = providerId
|
||||
this.idleTimeout = idleTimeout
|
||||
}
|
||||
|
||||
private markFirstTokenIfNeeded() {
|
||||
if (this.firstTokenTimestamp === null && this.responseStartTimestamp !== null) {
|
||||
this.firstTokenTimestamp = Date.now()
|
||||
}
|
||||
}
|
||||
|
||||
private resetTimingState() {
|
||||
this.responseStartTimestamp = null
|
||||
this.firstTokenTimestamp = null
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 AI SDK 流结果
|
||||
* @param aiSdkResult AI SDK 的流结果对象
|
||||
* @returns 最终的文本内容
|
||||
*/
|
||||
async processStream(aiSdkResult: any): Promise<string> {
|
||||
// The stream is the single source of truth for abort handling.
|
||||
// Both AI SDK (resilient stream) and the agent pipeline (withAbortStreamPart)
|
||||
// guarantee: abort → enqueue { type: 'abort' } → close gracefully.
|
||||
// convertAndEmitChunk processes the abort part and emits ChunkType.ERROR → onError.
|
||||
if (aiSdkResult.fullStream) {
|
||||
await this.readFullStream(aiSdkResult.fullStream)
|
||||
}
|
||||
|
||||
try {
|
||||
return await aiSdkResult.text
|
||||
} catch (error: any) {
|
||||
// The text promise rejects when no steps completed (e.g. abort during thinking).
|
||||
// The abort was already handled via the 'abort' stream part above.
|
||||
if (isAbortError(error)) {
|
||||
return ''
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 fullStream 并转换为 Cherry Studio chunks
|
||||
* @param fullStream AI SDK 的 fullStream (ReadableStream)
|
||||
*/
|
||||
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
|
||||
const reader = fullStream.getReader()
|
||||
const final = {
|
||||
text: '',
|
||||
reasoningContent: '',
|
||||
webSearchResults: [],
|
||||
reasoningId: '',
|
||||
providerMetadata: undefined as ProviderMetadata | undefined
|
||||
}
|
||||
this.resetTimingState()
|
||||
this.responseStartTimestamp = Date.now()
|
||||
// Reset state at the start of stream
|
||||
this.isFirstChunk = true
|
||||
this.hasTextContent = false
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
// Reset idle timeout on every chunk received from the stream
|
||||
this.idleTimeout?.reset()
|
||||
|
||||
if (done) {
|
||||
// Flush any remaining content from link converter buffer if web search is enabled
|
||||
if (this.enableWebSearch) {
|
||||
const remainingText = flushLinkConverterBuffer()
|
||||
if (remainingText) {
|
||||
this.markFirstTokenIfNeeded()
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: remainingText
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 转换并发送 chunk
|
||||
this.convertAndEmitChunk(value, final)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
this.resetTimingState()
|
||||
// Clean up the idle timeout timer when the stream ends
|
||||
this.idleTimeout?.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 如果有累积的思考内容,发送 THINKING_COMPLETE chunk 并清空
|
||||
* @param final 包含 reasoningContent 的状态对象
|
||||
* @returns 是否发送了 THINKING_COMPLETE chunk
|
||||
*/
|
||||
private emitThinkingCompleteIfNeeded(final: { reasoningContent: string; [key: string]: any }) {
|
||||
if (final.reasoningContent) {
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: final.reasoningContent
|
||||
})
|
||||
final.reasoningContent = ''
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
|
||||
* @param chunk AI SDK 的 chunk 数据
|
||||
*/
|
||||
private convertAndEmitChunk(
|
||||
chunk: TextStreamPart<any>,
|
||||
final: {
|
||||
text: string
|
||||
reasoningContent: string
|
||||
webSearchResults: AISDKWebSearchResult[]
|
||||
reasoningId: string
|
||||
providerMetadata: ProviderMetadata | undefined
|
||||
}
|
||||
) {
|
||||
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||
switch (chunk.type) {
|
||||
case 'raw': {
|
||||
const agentRawMessage = chunk.rawValue as ClaudeCodeRawValue
|
||||
if (agentRawMessage.type === 'init' && agentRawMessage.session_id) {
|
||||
this.onSessionUpdate?.(agentRawMessage.session_id)
|
||||
} else if (agentRawMessage.type === 'compact' && agentRawMessage.session_id) {
|
||||
this.onSessionUpdate?.(agentRawMessage.session_id)
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.RAW,
|
||||
content: agentRawMessage
|
||||
})
|
||||
break
|
||||
}
|
||||
// === 文本相关事件 ===
|
||||
case 'text-start':
|
||||
// 如果有未完成的思考内容,先生成 THINKING_COMPLETE
|
||||
// 这处理了某些提供商不发送 reasoning-end 事件的情况
|
||||
this.emitThinkingCompleteIfNeeded(final)
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
break
|
||||
case 'text-delta': {
|
||||
this.hasTextContent = true
|
||||
const processedText = chunk.text || ''
|
||||
let finalText: string
|
||||
|
||||
// Only apply link conversion if web search is enabled
|
||||
if (this.enableWebSearch) {
|
||||
const result = convertLinks(processedText, this.isFirstChunk)
|
||||
|
||||
if (this.isFirstChunk) {
|
||||
this.isFirstChunk = false
|
||||
}
|
||||
|
||||
// Handle buffered content
|
||||
if (result.hasBufferedContent) {
|
||||
finalText = result.text
|
||||
} else {
|
||||
finalText = result.text || processedText
|
||||
}
|
||||
} else {
|
||||
// Without web search, just use the original text
|
||||
finalText = processedText
|
||||
}
|
||||
|
||||
if (this.accumulate) {
|
||||
final.text += finalText
|
||||
} else {
|
||||
final.text = finalText
|
||||
}
|
||||
|
||||
// Extract thoughtSignature from providerMetadata.google and preserve it
|
||||
const newSignature = chunk.providerMetadata?.google?.thoughtSignature as string | undefined
|
||||
if (newSignature) {
|
||||
final.providerMetadata = {
|
||||
...final.providerMetadata,
|
||||
google: {
|
||||
...final.providerMetadata?.google,
|
||||
thoughtSignature: newSignature
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit chunk if there's text to send
|
||||
if (finalText) {
|
||||
this.markFirstTokenIfNeeded()
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: this.accumulate ? final.text : finalText,
|
||||
providerMetadata: final.providerMetadata
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? '',
|
||||
providerMetadata: final.providerMetadata
|
||||
})
|
||||
final.text = ''
|
||||
// Clear providerMetadata for next text block
|
||||
final.providerMetadata = undefined
|
||||
break
|
||||
case 'reasoning-start':
|
||||
// if (final.reasoningId !== chunk.id) {
|
||||
final.reasoningId = chunk.id
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_START
|
||||
})
|
||||
// }
|
||||
break
|
||||
case 'reasoning-delta':
|
||||
final.reasoningContent += chunk.text || ''
|
||||
if (chunk.text) {
|
||||
this.markFirstTokenIfNeeded()
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: final.reasoningContent || ''
|
||||
})
|
||||
break
|
||||
case 'reasoning-end':
|
||||
this.emitThinkingCompleteIfNeeded(final)
|
||||
break
|
||||
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
|
||||
case 'tool-input-start':
|
||||
this.toolCallHandler.handleToolInputStart(chunk)
|
||||
break
|
||||
case 'tool-input-delta':
|
||||
this.toolCallHandler.handleToolInputDelta(chunk)
|
||||
break
|
||||
case 'tool-input-end':
|
||||
this.toolCallHandler.handleToolInputEnd(chunk)
|
||||
break
|
||||
|
||||
case 'tool-call':
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
break
|
||||
|
||||
case 'tool-error':
|
||||
this.toolCallHandler.handleToolError(chunk)
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
case 'finish-step': {
|
||||
const { providerMetadata, finishReason } = chunk
|
||||
// googel web search
|
||||
if (providerMetadata?.google?.groundingMetadata) {
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: providerMetadata.google?.groundingMetadata as WebSearchResults,
|
||||
source: WEB_SEARCH_SOURCE.GEMINI
|
||||
}
|
||||
})
|
||||
} else if (final.webSearchResults.length) {
|
||||
const providerName: string | undefined = Object.keys(providerMetadata || {})[0] || this.providerId
|
||||
const sourceMap: Record<string, WebSearchSource> = {
|
||||
[WEB_SEARCH_SOURCE.OPENAI]: WEB_SEARCH_SOURCE.OPENAI_RESPONSE,
|
||||
[WEB_SEARCH_SOURCE.ANTHROPIC]: WEB_SEARCH_SOURCE.ANTHROPIC,
|
||||
[WEB_SEARCH_SOURCE.OPENROUTER]: WEB_SEARCH_SOURCE.OPENROUTER,
|
||||
[WEB_SEARCH_SOURCE.GEMINI]: WEB_SEARCH_SOURCE.GEMINI,
|
||||
// [WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
|
||||
[WEB_SEARCH_SOURCE.QWEN]: WEB_SEARCH_SOURCE.QWEN,
|
||||
[WEB_SEARCH_SOURCE.HUNYUAN]: WEB_SEARCH_SOURCE.HUNYUAN,
|
||||
[WEB_SEARCH_SOURCE.ZHIPU]: WEB_SEARCH_SOURCE.ZHIPU,
|
||||
[WEB_SEARCH_SOURCE.GROK]: WEB_SEARCH_SOURCE.GROK,
|
||||
xai: WEB_SEARCH_SOURCE.GROK,
|
||||
[WEB_SEARCH_SOURCE.WEBSEARCH]: WEB_SEARCH_SOURCE.WEBSEARCH
|
||||
}
|
||||
const source = (providerName && sourceMap[providerName]) || WEB_SEARCH_SOURCE.AISDK
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: final.webSearchResults,
|
||||
source
|
||||
}
|
||||
})
|
||||
}
|
||||
if (finishReason === 'tool-calls') {
|
||||
this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
}
|
||||
|
||||
final.webSearchResults = []
|
||||
// final.reasoningId = ''
|
||||
break
|
||||
}
|
||||
|
||||
case 'finish': {
|
||||
// Check if session was cleared (e.g., /clear command) and no text was output
|
||||
const sessionCleared = this.getSessionWasCleared?.() ?? false
|
||||
if (sessionCleared && !this.hasTextContent) {
|
||||
// Inject a "context cleared" message for the user
|
||||
const clearMessage = '✨ Context cleared. Starting fresh conversation.'
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: clearMessage
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: clearMessage
|
||||
})
|
||||
final.text = clearMessage
|
||||
}
|
||||
|
||||
const usage = {
|
||||
completion_tokens: chunk.totalUsage?.outputTokens || 0,
|
||||
prompt_tokens: chunk.totalUsage?.inputTokens || 0,
|
||||
total_tokens: chunk.totalUsage?.totalTokens || 0
|
||||
}
|
||||
const metrics = this.buildMetrics(chunk.totalUsage)
|
||||
const baseResponse = {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoningContent || ''
|
||||
}
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
...baseResponse,
|
||||
usage: { ...usage },
|
||||
metrics: metrics ? { ...metrics } : undefined
|
||||
}
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
...baseResponse,
|
||||
usage: { ...usage },
|
||||
metrics: metrics ? { ...metrics } : undefined
|
||||
}
|
||||
})
|
||||
this.resetTimingState()
|
||||
break
|
||||
}
|
||||
|
||||
// === 源和文件相关事件 ===
|
||||
case 'source':
|
||||
if (chunk.sourceType === 'url') {
|
||||
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { sourceType: _, ...rest } = chunk
|
||||
final.webSearchResults.push(rest)
|
||||
}
|
||||
break
|
||||
case 'file':
|
||||
// 文件相关事件,可能是图片生成
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'abort':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: new DOMException('Request was aborted', 'AbortError')
|
||||
})
|
||||
break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: AISDKError.isInstance(chunk.error)
|
||||
? chunk.error
|
||||
: new ProviderSpecificError({
|
||||
message: formatErrorMessage(chunk.error),
|
||||
provider: 'unknown',
|
||||
cause: chunk.error
|
||||
})
|
||||
})
|
||||
break
|
||||
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
private buildMetrics(totalUsage?: {
|
||||
inputTokens?: number | null
|
||||
outputTokens?: number | null
|
||||
totalTokens?: number | null
|
||||
}) {
|
||||
if (!totalUsage) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const completionTokens = totalUsage.outputTokens ?? 0
|
||||
const now = Date.now()
|
||||
const start = this.responseStartTimestamp ?? now
|
||||
const firstToken = this.firstTokenTimestamp
|
||||
const timeFirstToken = Math.max(firstToken != null ? firstToken - start : 0, 0)
|
||||
const baseForCompletion = firstToken ?? start
|
||||
let timeCompletion = Math.max(now - baseForCompletion, 0)
|
||||
|
||||
if (timeCompletion === 0 && completionTokens > 0) {
|
||||
timeCompletion = 1
|
||||
}
|
||||
|
||||
return {
|
||||
completion_tokens: completionTokens,
|
||||
time_first_token_millsec: timeFirstToken,
|
||||
time_completion_millsec: timeCompletion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default AiSdkToChunkAdapter
|
||||
@@ -1,435 +0,0 @@
|
||||
/**
|
||||
* 工具调用 Chunk 处理模块
|
||||
* TODO: Tool包含了providerTool和普通的Tool还有MCPTool,后面需要重构
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
||||
import type { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||
|
||||
export type ToolcallsMap = {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
// Streaming arguments buffer
|
||||
streamingArgs?: string
|
||||
}
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
export class ToolCallChunkHandler {
|
||||
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
|
||||
|
||||
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
) {}
|
||||
|
||||
/**
|
||||
* 内部静态方法:添加活跃工具调用的核心逻辑
|
||||
*/
|
||||
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
|
||||
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* 实例方法:添加活跃工具调用
|
||||
*/
|
||||
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取全局活跃的工具调用
|
||||
*/
|
||||
public static getActiveToolCalls() {
|
||||
return ToolCallChunkHandler.globalActiveToolCalls
|
||||
}
|
||||
|
||||
/**
|
||||
* 静态方法:添加活跃工具调用(外部访问)
|
||||
*/
|
||||
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据工具名称确定工具类型
|
||||
*/
|
||||
private determineToolType(toolName: string, toolCallId: string): BaseTool {
|
||||
let mcpTool: MCPTool | undefined
|
||||
if (toolName.startsWith('builtin_')) {
|
||||
return {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'builtin'
|
||||
} as BaseTool
|
||||
} else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
|
||||
return mcpTool
|
||||
} else {
|
||||
return {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具输入开始事件 - 流式参数开始
|
||||
*/
|
||||
public handleToolInputStart(chunk: {
|
||||
type: 'tool-input-start'
|
||||
id: string
|
||||
toolName: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
providerExecuted?: boolean
|
||||
}): void {
|
||||
const { id: toolCallId, toolName, providerExecuted } = chunk
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool-input-start chunk: missing id or toolName`)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果已存在,跳过
|
||||
if (this.activeToolCalls.has(toolCallId)) {
|
||||
return
|
||||
}
|
||||
|
||||
let tool: BaseTool
|
||||
if (providerExecuted) {
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
} as BaseTool
|
||||
} else {
|
||||
tool = this.determineToolType(toolName, toolCallId)
|
||||
}
|
||||
|
||||
// 初始化流式工具调用
|
||||
this.addActiveToolCall(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args: undefined,
|
||||
tool,
|
||||
streamingArgs: ''
|
||||
})
|
||||
|
||||
logger.info(`🔧 [ToolCallChunkHandler] Tool input streaming started: ${toolName} (${toolCallId})`)
|
||||
|
||||
// 发送初始 streaming chunk
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: tool,
|
||||
arguments: undefined,
|
||||
status: 'streaming',
|
||||
toolCallId: toolCallId,
|
||||
partialArguments: ''
|
||||
}
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_STREAMING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具输入增量事件 - 流式参数片段
|
||||
*/
|
||||
public handleToolInputDelta(chunk: {
|
||||
type: 'tool-input-delta'
|
||||
id: string
|
||||
delta: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}): void {
|
||||
const { id: toolCallId, delta } = chunk
|
||||
|
||||
const toolCall = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found for delta: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// 累积流式参数
|
||||
toolCall.streamingArgs = (toolCall.streamingArgs || '') + delta
|
||||
|
||||
// 发送 streaming chunk 更新
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: toolCall.tool,
|
||||
arguments: undefined,
|
||||
status: 'streaming',
|
||||
toolCallId: toolCallId,
|
||||
partialArguments: toolCall.streamingArgs
|
||||
}
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_STREAMING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具输入结束事件 - 流式参数完成
|
||||
*/
|
||||
public handleToolInputEnd(chunk: { type: 'tool-input-end'; id: string; providerMetadata?: ProviderMetadata }): void {
|
||||
const { id: toolCallId } = chunk
|
||||
|
||||
const toolCall = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found for end: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析完整的 JSON 参数
|
||||
let parsedArgs: any = undefined
|
||||
if (toolCall.streamingArgs) {
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolCall.streamingArgs)
|
||||
toolCall.args = parsedArgs
|
||||
} catch (e) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Failed to parse streaming args for ${toolCallId}:`, e as Error)
|
||||
// 保留原始字符串
|
||||
toolCall.args = toolCall.streamingArgs
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`🔧 [ToolCallChunkHandler] Tool input streaming completed: ${toolCall.toolName} (${toolCallId})`)
|
||||
|
||||
// 发送 streaming 完成 chunk
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: toolCall.tool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending',
|
||||
toolCallId: toolCallId,
|
||||
partialArguments: toolCall.streamingArgs
|
||||
}
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_STREAMING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
*/
|
||||
public handleToolCall(
|
||||
chunk: {
|
||||
type: 'tool-call'
|
||||
} & TypedToolCall<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, toolName, input: args, providerExecuted } = chunk
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this tool call was already processed via streaming events
|
||||
const existingToolCall = this.activeToolCalls.get(toolCallId)
|
||||
if (existingToolCall?.streamingArgs !== undefined) {
|
||||
// Tool call was already processed via streaming events (tool-input-start/delta/end)
|
||||
// Update args if needed, but don't emit duplicate pending chunk
|
||||
existingToolCall.args = args
|
||||
return
|
||||
}
|
||||
|
||||
let tool: BaseTool
|
||||
let mcpTool: MCPTool | undefined
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
} as BaseTool
|
||||
} else if (toolName.startsWith('builtin_')) {
|
||||
// 如果是内置工具,沿用现有逻辑
|
||||
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'builtin'
|
||||
} as BaseTool
|
||||
} else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
|
||||
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
|
||||
// toolName is mcpTool.id (registered with id as key in convertMcpToolsToAiSdkTools)
|
||||
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||
tool = mcpTool
|
||||
} else {
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
}
|
||||
}
|
||||
|
||||
this.addActiveToolCall(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
tool
|
||||
})
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: tool,
|
||||
arguments: args,
|
||||
status: 'pending', // 统一使用 pending 状态
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING, // 统一发送 pending 状态
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用结果事件
|
||||
*/
|
||||
public handleToolResult(
|
||||
chunk: {
|
||||
type: 'tool-result'
|
||||
} & TypedToolResult<ToolSet>
|
||||
): void {
|
||||
// TODO: 基于AI SDK为供应商内置工具做更好的展示和类型安全处理
|
||||
const { toolCallId, output, input } = chunk
|
||||
|
||||
if (!toolCallId) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||
return
|
||||
}
|
||||
|
||||
// 查找对应的工具调用信息
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建工具调用结果的 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallInfo.toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'done',
|
||||
response: output,
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 工具特定的后处理
|
||||
switch (toolResponse.tool.name) {
|
||||
case 'builtin_knowledge_search': {
|
||||
processKnowledgeReferences(toolResponse.response, this.onChunk)
|
||||
break
|
||||
}
|
||||
// 未来可以在这里添加其他工具的后处理逻辑
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
// 从活跃调用中移除(交互结束后整个实例会被丢弃)
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
|
||||
const images = extractImagesFromToolOutput(toolResponse.response)
|
||||
|
||||
if (images.length) {
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: images
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handleToolError(
|
||||
chunk: {
|
||||
type: 'tool-error'
|
||||
} & TypedToolError<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, error, input } = chunk
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'error',
|
||||
response: error,
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
|
||||
|
||||
/**
|
||||
* 从工具输出中提取图片(使用 MCP SDK 类型安全验证)
|
||||
*/
|
||||
function extractImagesFromToolOutput(output: unknown): string[] {
|
||||
if (!output) {
|
||||
return []
|
||||
}
|
||||
|
||||
const result = CallToolResultSchema.safeParse(output)
|
||||
if (result.success) {
|
||||
return result.data.content
|
||||
.filter((c) => c.type === 'image')
|
||||
.map((content) => `data:${content.mimeType ?? 'image/png'};base64,${content.data}`)
|
||||
}
|
||||
|
||||
return []
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
export { default as AiProvider, type AiProviderConfig } from './AiProvider'
|
||||
@@ -1,140 +0,0 @@
|
||||
import type { AiPlugin } from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, providerToolPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import { loggerService } from '@logger'
|
||||
import { isGemini3Model, isQwen35to39Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||
|
||||
import type { AiSdkMiddlewareConfig } from '../types/middlewareConfig'
|
||||
import { getReasoningTagName } from '../utils/reasoning'
|
||||
import { createAnthropicCachePlugin } from './anthropicCachePlugin'
|
||||
import { createNoThinkPlugin } from './noThinkPlugin'
|
||||
import { createOpenrouterReasoningPlugin } from './openrouterReasoningPlugin'
|
||||
import { createPdfCompatibilityPlugin } from './pdfCompatibilityPlugin'
|
||||
import { createQwenThinkingPlugin } from './qwenThinkingPlugin'
|
||||
import { createReasoningExtractionPlugin } from './reasoningExtractionPlugin'
|
||||
import { searchOrchestrationPlugin } from './searchOrchestrationPlugin'
|
||||
import { createSimulateStreamingPlugin } from './simulateStreamingPlugin'
|
||||
import { createSkipGeminiThoughtSignaturePlugin } from './skipGeminiThoughtSignaturePlugin'
|
||||
import { createTelemetryPlugin } from './telemetryPlugin'
|
||||
|
||||
const logger = loggerService.withContext('PluginBuilder')
|
||||
|
||||
/**
|
||||
* 构建插件的上下文参数
|
||||
*
|
||||
* provider 和 model 是必选的 — 由 AiProvider 内部注入,
|
||||
* 不再依赖调用方手动传入,从根本上避免遗漏。
|
||||
*/
|
||||
export interface BuildPluginsContext {
|
||||
provider: Provider
|
||||
model: Model
|
||||
config: AiSdkMiddlewareConfig & { assistant: Assistant; topicId?: string }
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据条件构建插件数组
|
||||
*/
|
||||
export async function buildPlugins({ provider, model, config }: BuildPluginsContext): Promise<AiPlugin[]> {
|
||||
const plugins: AiPlugin<any, any>[] = []
|
||||
|
||||
if (config.topicId && (await preferenceService.get('app.developer_mode.enabled'))) {
|
||||
// 0. 添加 telemetry 插件
|
||||
plugins.push(
|
||||
createTelemetryPlugin({
|
||||
enabled: true,
|
||||
topicId: config.topicId,
|
||||
assistant: config.assistant
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
// === PDF Compatibility ===
|
||||
// Must run before other plugins (e.g., Anthropic cache token estimation)
|
||||
// so that PDF FileParts are converted to TextParts for unsupported providers.
|
||||
plugins.push(createPdfCompatibilityPlugin(provider, model))
|
||||
|
||||
// === AI SDK Middleware Plugins ===
|
||||
// 注意:wrapLanguageModel 会 .reverse() middleware 数组,
|
||||
// 数组中靠前的 middleware 反转后变成最外层包装。
|
||||
// extractReasoning 必须在 simulateStreaming 之前推入,
|
||||
// 这样反转后 extractReasoning 在外层,其 wrapStream(状态机)
|
||||
// 能处理 simulateStreaming 生成的模拟流中的未闭合 <think> 标签。
|
||||
|
||||
// 0.1 Reasoning extraction for OpenAI/Azure providers
|
||||
const providerType = provider.type
|
||||
if (providerType === 'openai' || providerType === 'azure-openai') {
|
||||
const tagName = getReasoningTagName(model.id.toLowerCase())
|
||||
plugins.push(createReasoningExtractionPlugin({ tagName }))
|
||||
}
|
||||
|
||||
// 0.2 Simulate streaming for non-streaming requests (must be AFTER reasoning extraction in array)
|
||||
if (!config.streamOutput) {
|
||||
plugins.push(createSimulateStreamingPlugin())
|
||||
}
|
||||
|
||||
if (provider.anthropicCacheControl?.tokenThreshold) {
|
||||
plugins.push(createAnthropicCachePlugin(provider))
|
||||
}
|
||||
|
||||
// 0.3 OpenRouter reasoning redaction
|
||||
if (provider.id === SystemProviderIds.openrouter) {
|
||||
plugins.push(createOpenrouterReasoningPlugin())
|
||||
}
|
||||
|
||||
// 0.4 OVMS no-think for MCP tools
|
||||
if (provider.id === 'ovms' && config.mcpTools && config.mcpTools.length > 0) {
|
||||
plugins.push(createNoThinkPlugin())
|
||||
}
|
||||
|
||||
// 0.5 Qwen thinking control for providers without enable_thinking support
|
||||
if (
|
||||
!isOllamaProvider(provider) &&
|
||||
isSupportedThinkingTokenQwenModel(model) &&
|
||||
!isQwen35to39Model(model) &&
|
||||
!isSupportEnableThinkingProvider(provider)
|
||||
) {
|
||||
const enableThinking = config.assistant?.settings?.reasoning_effort !== undefined
|
||||
plugins.push(createQwenThinkingPlugin(enableThinking))
|
||||
}
|
||||
|
||||
// 0.6 Skip Gemini3 thought signature for OpenAI-compatible API
|
||||
if (isGemini3Model(model)) {
|
||||
plugins.push(createSkipGeminiThoughtSignaturePlugin())
|
||||
}
|
||||
|
||||
// 1. Provider 工具注入 — providerToolPlugin 自动按 provider 分发工具
|
||||
if (config.enableWebSearch && config.webSearchPluginConfig) {
|
||||
plugins.push(providerToolPlugin('webSearch', config.webSearchPluginConfig))
|
||||
}
|
||||
if (config.enableUrlContext) {
|
||||
plugins.push(providerToolPlugin('urlContext', config.urlContextConfig))
|
||||
}
|
||||
// 2. 支持工具调用时添加搜索插件
|
||||
if (config.isSupportedToolUse || config.isPromptToolUse) {
|
||||
plugins.push(searchOrchestrationPlugin(config.assistant, config.topicId || ''))
|
||||
}
|
||||
|
||||
// 3. 推理模型时添加推理插件
|
||||
// if (config.enableReasoning) {
|
||||
// plugins.push(reasoningTimePlugin)
|
||||
// }
|
||||
|
||||
// 4. 启用Prompt工具调用时添加工具插件
|
||||
if (config.isPromptToolUse) {
|
||||
plugins.push(
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
mcpMode: config.mcpMode
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
'Final plugin list:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
return plugins
|
||||
}
|
||||
@@ -1,243 +0,0 @@
|
||||
import type { LanguageModelV3CallOptions } from '@ai-sdk/provider'
|
||||
import type { Model, Provider, ProviderType } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('i18next', () => ({
|
||||
default: { t: (key: string, opts?: Record<string, unknown>) => `${key}${opts ? JSON.stringify(opts) : ''}` }
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isAnthropicModel: vi.fn(() => false),
|
||||
isGeminiModel: vi.fn(() => false)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models/openai', () => ({
|
||||
isOpenAILLMModel: vi.fn(() => false)
|
||||
}))
|
||||
|
||||
const mockExtractPdfText = vi.fn()
|
||||
|
||||
vi.mock('@shared/utils/pdf', () => ({
|
||||
extractPdfText: (...args: unknown[]) => mockExtractPdfText(...args)
|
||||
}))
|
||||
|
||||
vi.stubGlobal('window', {
|
||||
...globalThis.window,
|
||||
api: {
|
||||
pdf: {
|
||||
extractText: mockExtractPdfText
|
||||
}
|
||||
},
|
||||
toast: {
|
||||
warning: vi.fn(),
|
||||
error: vi.fn()
|
||||
}
|
||||
})
|
||||
|
||||
import { isAnthropicModel, isGeminiModel } from '@renderer/config/models'
|
||||
import { isOpenAILLMModel } from '@renderer/config/models/openai'
|
||||
|
||||
import { createPdfCompatibilityPlugin } from '../pdfCompatibilityPlugin'
|
||||
|
||||
function makeProvider(id: string, type: ProviderType): Provider {
|
||||
return { id, name: id, type, apiKey: 'test', apiHost: 'https://test.com', isSystem: false, models: [] } as Provider
|
||||
}
|
||||
|
||||
function makeModel(): Model {
|
||||
return { id: 'test-model', provider: 'test', name: 'Test', group: 'test' } as Model
|
||||
}
|
||||
|
||||
function makePdfFilePart(filename = 'test.pdf') {
|
||||
return {
|
||||
type: 'file' as const,
|
||||
data: 'base64pdfdata',
|
||||
mediaType: 'application/pdf',
|
||||
filename
|
||||
}
|
||||
}
|
||||
|
||||
function makeImageFilePart() {
|
||||
return {
|
||||
type: 'file' as const,
|
||||
data: 'base64imgdata',
|
||||
mediaType: 'image/png',
|
||||
filename: 'test.png'
|
||||
}
|
||||
}
|
||||
|
||||
function makeTextPart(text: string) {
|
||||
return { type: 'text' as const, text }
|
||||
}
|
||||
|
||||
async function runMiddleware(provider: Provider, params: LanguageModelV3CallOptions, model: Model = makeModel()) {
|
||||
const plugin = createPdfCompatibilityPlugin(provider, model)
|
||||
const context: {
|
||||
middlewares: Array<{ transformParams: (opts: Record<string, unknown>) => Promise<LanguageModelV3CallOptions> }>
|
||||
} = { middlewares: [] }
|
||||
void plugin.configureContext!(context as never)
|
||||
const middleware = context.middlewares[0]
|
||||
return middleware.transformParams({ params, type: 'generate', model: {} })
|
||||
}
|
||||
|
||||
describe('pdfCompatibilityPlugin', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(isOpenAILLMModel).mockReturnValue(false)
|
||||
vi.mocked(isAnthropicModel).mockReturnValue(false)
|
||||
vi.mocked(isGeminiModel).mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('should pass through for OpenAI model on any provider type', async () => {
|
||||
vi.mocked(isOpenAILLMModel).mockReturnValue(true)
|
||||
const provider = makeProvider('moonshot', 'openai')
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result).toEqual(params)
|
||||
expect(mockExtractPdfText).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass through for Claude model on any provider type', async () => {
|
||||
vi.mocked(isAnthropicModel).mockReturnValue(true)
|
||||
const provider = makeProvider('my-aggregator', 'new-api')
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result).toEqual(params)
|
||||
expect(mockExtractPdfText).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass through for Gemini model on any provider type', async () => {
|
||||
vi.mocked(isGeminiModel).mockReturnValue(true)
|
||||
const provider = makeProvider('my-aggregator', 'new-api')
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result).toEqual(params)
|
||||
expect(mockExtractPdfText).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass through unchanged when provider type supports native PDF (openai-response)', async () => {
|
||||
const provider = makeProvider('openai', 'openai-response')
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart()] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result).toEqual(params)
|
||||
expect(mockExtractPdfText).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should convert PDF for non-native provider types (new-api, gateway, openai)', async () => {
|
||||
const provider = makeProvider('moonshot', 'openai')
|
||||
mockExtractPdfText.mockResolvedValue('Extracted PDF content')
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart('report.pdf')] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(mockExtractPdfText).toHaveBeenCalledWith('base64pdfdata')
|
||||
expect(result.prompt[0]).toMatchObject({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: 'report.pdf\nExtracted PDF content' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('should convert PDF FilePart to TextPart for ollama provider', async () => {
|
||||
const provider = makeProvider('ollama', 'ollama')
|
||||
mockExtractPdfText.mockResolvedValue('Extracted PDF content')
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart('report.pdf')] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(mockExtractPdfText).toHaveBeenCalledWith('base64pdfdata')
|
||||
expect(result.prompt[0]).toMatchObject({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: 'report.pdf\nExtracted PDF content' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('should drop PDF part and warn when text extraction fails', async () => {
|
||||
const provider = makeProvider('ollama', 'ollama')
|
||||
mockExtractPdfText.mockRejectedValue(new Error('parse failed'))
|
||||
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), makePdfFilePart('broken.pdf')] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result.prompt[0]).toMatchObject({
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Hello' }]
|
||||
})
|
||||
expect(window.toast.warning).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not convert non-PDF FileParts', async () => {
|
||||
const provider = makeProvider('ollama', 'ollama')
|
||||
|
||||
const imagePart = makeImageFilePart()
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Hello'), imagePart] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result.prompt[0]).toMatchObject({
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Hello' }, imagePart]
|
||||
})
|
||||
expect(mockExtractPdfText).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle mixed content: text + PDF + image — only PDF converted', async () => {
|
||||
const provider = makeProvider('ollama', 'ollama')
|
||||
mockExtractPdfText.mockResolvedValue('PDF text content')
|
||||
|
||||
const imagePart = makeImageFilePart()
|
||||
const params = {
|
||||
prompt: [{ role: 'user' as const, content: [makeTextPart('Analyze'), makePdfFilePart('doc.pdf'), imagePart] }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result.prompt[0]).toMatchObject({
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Analyze' }, { type: 'text', text: 'doc.pdf\nPDF text content' }, imagePart]
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass through when prompt is empty', async () => {
|
||||
const provider = makeProvider('ollama', 'ollama')
|
||||
const params = { prompt: [] } as unknown as LanguageModelV3CallOptions
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result).toEqual(params)
|
||||
})
|
||||
|
||||
it('should pass through messages with string content (system messages)', async () => {
|
||||
const provider = makeProvider('ollama', 'ollama')
|
||||
const params = {
|
||||
prompt: [{ role: 'system' as const, content: 'You are a helpful assistant' }]
|
||||
} as unknown as LanguageModelV3CallOptions
|
||||
|
||||
const result = await runMiddleware(provider, params)
|
||||
expect(result.prompt[0]).toMatchObject({ role: 'system', content: 'You are a helpful assistant' })
|
||||
})
|
||||
})
|
||||
@@ -1,96 +0,0 @@
|
||||
/**
|
||||
* Anthropic Prompt Caching Middleware
|
||||
* @see https://ai-sdk.dev/providers/ai-sdk-providers/anthropic#cache-control
|
||||
*/
|
||||
import type { LanguageModelV3Message } from '@ai-sdk/provider'
|
||||
import { definePlugin } from '@cherrystudio/ai-core/core/plugins'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
const cacheProviderOptions = {
|
||||
anthropic: { cacheControl: { type: 'ephemeral' } }
|
||||
}
|
||||
|
||||
function estimateContentTokens(content: LanguageModelV3Message['content']): number {
|
||||
if (typeof content === 'string') return estimateTextTokens(content)
|
||||
if (Array.isArray(content)) {
|
||||
return content.reduce((acc, part) => {
|
||||
if (part.type === 'text') {
|
||||
return acc + estimateTextTokens(part.text)
|
||||
}
|
||||
return acc
|
||||
}, 0)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
function anthropicCacheMiddleware(provider: Provider): LanguageModelMiddleware {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
transformParams: async ({ params }) => {
|
||||
const settings = provider.anthropicCacheControl
|
||||
if (!settings?.tokenThreshold || !Array.isArray(params.prompt) || params.prompt.length === 0) {
|
||||
return params
|
||||
}
|
||||
|
||||
const { tokenThreshold, cacheSystemMessage, cacheLastNMessages } = settings
|
||||
const messages = [...params.prompt]
|
||||
let cachedCount = 0
|
||||
|
||||
// Cache system message (providerOptions on message object)
|
||||
if (cacheSystemMessage) {
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i]
|
||||
if (msg.role === 'system' && estimateContentTokens(msg.content) >= tokenThreshold) {
|
||||
messages[i] = { ...msg, providerOptions: cacheProviderOptions }
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache last N non-system messages (providerOptions on content parts)
|
||||
if (cacheLastNMessages > 0) {
|
||||
const cumsumTokens = [] as Array<number>
|
||||
let tokenSum = 0 as number
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i]
|
||||
tokenSum += estimateContentTokens(msg.content)
|
||||
cumsumTokens.push(tokenSum)
|
||||
}
|
||||
|
||||
for (let i = messages.length - 1; i >= 0 && cachedCount < cacheLastNMessages; i--) {
|
||||
const msg = messages[i]
|
||||
if (msg.role === 'system' || cumsumTokens[i] < tokenThreshold || msg.content.length === 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
const newContent = [...msg.content]
|
||||
const lastIndex = newContent.length - 1
|
||||
newContent[lastIndex] = {
|
||||
...newContent[lastIndex],
|
||||
providerOptions: cacheProviderOptions
|
||||
}
|
||||
|
||||
messages[i] = {
|
||||
...msg,
|
||||
content: newContent
|
||||
} as LanguageModelV3Message
|
||||
cachedCount++
|
||||
}
|
||||
}
|
||||
|
||||
return { ...params, prompt: messages }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createAnthropicCachePlugin = (provider: Provider) =>
|
||||
definePlugin({
|
||||
name: 'anthropicCache',
|
||||
enforce: 'pre',
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(anthropicCacheMiddleware(provider))
|
||||
}
|
||||
})
|
||||
@@ -1,64 +0,0 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('noThinkPlugin')
|
||||
|
||||
/**
|
||||
* No Think Middleware
|
||||
* Automatically appends ' /no_think' string to the end of user messages for the provider
|
||||
* This prevents the model from generating unnecessary thinking process and returns results directly
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
function createNoThinkMiddleware(): LanguageModelMiddleware {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
// Process messages in prompt
|
||||
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
|
||||
transformedParams.prompt = transformedParams.prompt.map((message) => {
|
||||
// Only process user messages
|
||||
if (message.role === 'user') {
|
||||
// Process content array
|
||||
if (Array.isArray(message.content)) {
|
||||
const lastContent = message.content[message.content.length - 1]
|
||||
// If the last content is text type, append ' /no_think'
|
||||
if (lastContent && lastContent.type === 'text' && typeof lastContent.text === 'string') {
|
||||
// Avoid duplicate additions
|
||||
if (!lastContent.text.endsWith('/no_think')) {
|
||||
logger.debug('Adding /no_think to user message')
|
||||
return {
|
||||
...message,
|
||||
content: [
|
||||
...message.content.slice(0, -1),
|
||||
{
|
||||
...lastContent,
|
||||
text: lastContent.text + ' /no_think'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return message
|
||||
})
|
||||
}
|
||||
|
||||
return transformedParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createNoThinkPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'noThink',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createNoThinkMiddleware())
|
||||
}
|
||||
})
|
||||
@@ -1,62 +0,0 @@
|
||||
import type { LanguageModelV3StreamPart } from '@ai-sdk/provider'
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
|
||||
*
|
||||
* @returns LanguageModelMiddleware - a middleware filter redacted block
|
||||
*/
|
||||
function createOpenrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||
const REDACTED_BLOCK = '[REDACTED]'
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
wrapGenerate: async ({ doGenerate }) => {
|
||||
const { content, ...rest } = await doGenerate()
|
||||
const modifiedContent = content.map((part) => {
|
||||
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
|
||||
return {
|
||||
...part,
|
||||
text: part.text.replace(REDACTED_BLOCK, '')
|
||||
}
|
||||
}
|
||||
return part
|
||||
})
|
||||
return { content: modifiedContent, ...rest }
|
||||
},
|
||||
wrapStream: async ({ doStream }) => {
|
||||
const { stream, ...rest } = await doStream()
|
||||
return {
|
||||
stream: stream.pipeThrough(
|
||||
new TransformStream<LanguageModelV3StreamPart, LanguageModelV3StreamPart>({
|
||||
transform(
|
||||
chunk: LanguageModelV3StreamPart,
|
||||
controller: TransformStreamDefaultController<LanguageModelV3StreamPart>
|
||||
) {
|
||||
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
delta: chunk.delta.replace(REDACTED_BLOCK, '')
|
||||
})
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
),
|
||||
...rest
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createOpenrouterReasoningPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'openrouterReasoning',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createOpenrouterReasoningMiddleware())
|
||||
}
|
||||
})
|
||||
@@ -1,114 +0,0 @@
|
||||
/**
|
||||
* PDF Compatibility Plugin
|
||||
*
|
||||
* Converts PDF FileParts to TextParts for providers that don't support native PDF input.
|
||||
* Extracts text directly from the FilePart's base64 data using pdf-parse.
|
||||
*/
|
||||
import type { LanguageModelV3FilePart, LanguageModelV3Message } from '@ai-sdk/provider'
|
||||
import { definePlugin } from '@cherrystudio/ai-core/core/plugins'
|
||||
import { loggerService } from '@logger'
|
||||
import { isAnthropicModel, isGeminiModel } from '@renderer/config/models'
|
||||
import { isOpenAILLMModel } from '@renderer/config/models/openai'
|
||||
import type { Model, Provider, ProviderType } from '@renderer/types'
|
||||
import { extractPdfText } from '@shared/utils/pdf'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
import i18n from 'i18next'
|
||||
|
||||
const logger = loggerService.withContext('pdfCompatibilityPlugin')
|
||||
|
||||
type ContentPart = Exclude<LanguageModelV3Message['content'], string>[number]
|
||||
|
||||
/**
|
||||
* Provider types whose API natively supports PDF file input.
|
||||
* Only first-party provider protocols (OpenAI, Anthropic, Google) are included.
|
||||
* Aggregators (new-api, gateway) and generic 'openai' type are excluded
|
||||
* because they may route to backends that don't support the 'file' part type.
|
||||
*/
|
||||
const PDF_NATIVE_PROVIDER_TYPES = new Set<ProviderType>([
|
||||
'openai-response', // OpenAI Responses API
|
||||
'anthropic', // Anthropic API
|
||||
'gemini', // Google Gemini API
|
||||
'azure-openai', // Azure OpenAI
|
||||
'vertexai', // Google Vertex AI
|
||||
'aws-bedrock', // AWS Bedrock
|
||||
'vertex-anthropic' // Vertex AI with Anthropic models
|
||||
])
|
||||
|
||||
function isPdfFilePart(part: ContentPart): part is LanguageModelV3FilePart & { mediaType: 'application/pdf' } {
|
||||
return part.type === 'file' && part.mediaType === 'application/pdf'
|
||||
}
|
||||
|
||||
function supportsNativePdf(provider: Provider, model: Model): boolean {
|
||||
// OpenAI, Claude, and Gemini models always support native PDF regardless of provider
|
||||
if (isOpenAILLMModel(model) || isAnthropicModel(model) || isGeminiModel(model)) {
|
||||
return true
|
||||
}
|
||||
if (PDF_NATIVE_PROVIDER_TYPES.has(provider.type)) {
|
||||
return true
|
||||
}
|
||||
// TODO: allow user to configure native pdf compatibility for provider/model
|
||||
return false
|
||||
}
|
||||
|
||||
function pdfCompatibilityMiddleware(provider: Provider, model: Model): LanguageModelMiddleware {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
transformParams: async ({ params }) => {
|
||||
if (supportsNativePdf(provider, model)) {
|
||||
return params
|
||||
}
|
||||
|
||||
if (!Array.isArray(params.prompt) || params.prompt.length === 0) {
|
||||
return params
|
||||
}
|
||||
|
||||
const messages: LanguageModelV3Message[] = []
|
||||
for (const message of params.prompt) {
|
||||
if (!Array.isArray(message.content)) {
|
||||
messages.push(message)
|
||||
continue
|
||||
}
|
||||
|
||||
const hasPdf = message.content.some((part: (typeof message.content)[number]) => isPdfFilePart(part))
|
||||
if (!hasPdf) {
|
||||
messages.push(message)
|
||||
continue
|
||||
}
|
||||
|
||||
const newContent: ContentPart[] = []
|
||||
for (const part of message.content) {
|
||||
if (!isPdfFilePart(part)) {
|
||||
newContent.push(part)
|
||||
continue
|
||||
}
|
||||
|
||||
const fileName = part.filename || 'PDF'
|
||||
|
||||
try {
|
||||
const textContent =
|
||||
part.data instanceof URL ? await extractPdfText(part.data) : await window.api.pdf.extractText(part.data)
|
||||
logger.debug(`Converting PDF FilePart to TextPart for provider ${provider.id} (type: ${provider.type})`)
|
||||
newContent.push({ type: 'text', text: `${fileName}\n${textContent.trim()}` })
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to extract text from PDF ${fileName}:`, error instanceof Error ? error : undefined)
|
||||
window.toast.warning(i18n.t('message.warning.file.pdf_text_extraction_failed', { name: fileName }))
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(Object.assign({}, message, { content: newContent }))
|
||||
}
|
||||
|
||||
return { ...params, prompt: messages }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createPdfCompatibilityPlugin = (provider: Provider, model: Model) =>
|
||||
definePlugin({
|
||||
name: 'pdfCompatibility',
|
||||
enforce: 'pre',
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(pdfCompatibilityMiddleware(provider, model))
|
||||
}
|
||||
})
|
||||
@@ -1,54 +0,0 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* Qwen Thinking Middleware
|
||||
* Controls thinking mode for Qwen models on providers that don't support enable_thinking parameter (like Ollama)
|
||||
* Appends '/think' or '/no_think' suffix to user messages based on reasoning_effort setting
|
||||
*
|
||||
* NOTE: Qwen3.5 does not officially support the soft switch of Qwen3, i.e., /think and /nothink.
|
||||
*
|
||||
* @param enableThinking - Whether thinking mode is enabled (based on reasoning_effort !== undefined)
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
function createQwenThinkingMiddleware(enableThinking: boolean): LanguageModelMiddleware {
|
||||
const suffix = enableThinking ? ' /think' : ' /no_think'
|
||||
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
// Process messages in prompt
|
||||
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
|
||||
transformedParams.prompt = transformedParams.prompt.map((message) => {
|
||||
// Only process user messages
|
||||
if (message.role === 'user') {
|
||||
// Process content array
|
||||
if (Array.isArray(message.content)) {
|
||||
for (const part of message.content) {
|
||||
if (part.type === 'text' && !part.text.endsWith('/think') && !part.text.endsWith('/no_think')) {
|
||||
part.text += suffix
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return message
|
||||
})
|
||||
}
|
||||
|
||||
return transformedParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createQwenThinkingPlugin = (enableThinking: boolean) =>
|
||||
definePlugin({
|
||||
name: 'qwenThinking',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createQwenThinkingMiddleware(enableThinking))
|
||||
}
|
||||
})
|
||||
@@ -1,22 +0,0 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { extractReasoningMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* Reasoning Extraction Plugin
|
||||
* Extracts reasoning/thinking tags from OpenAI/Azure responses
|
||||
* Uses AI SDK's built-in extractReasoningMiddleware
|
||||
*/
|
||||
export const createReasoningExtractionPlugin = (options: { tagName?: string } = {}) =>
|
||||
definePlugin({
|
||||
name: 'reasoningExtraction',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(
|
||||
extractReasoningMiddleware({
|
||||
tagName: options.tagName || 'thinking'
|
||||
})
|
||||
)
|
||||
}
|
||||
})
|
||||
@@ -1,37 +0,0 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
export default definePlugin({
|
||||
name: 'reasoningTimePlugin',
|
||||
|
||||
transformStream: () => () => {
|
||||
// === 时间跟踪状态 ===
|
||||
let thinkingStartTime = 0
|
||||
let accumulatedThinkingContent = ''
|
||||
|
||||
return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
|
||||
transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
|
||||
// === 处理 reasoning 类型 ===
|
||||
if (chunk.type === 'reasoning-start') {
|
||||
controller.enqueue(chunk)
|
||||
thinkingStartTime = performance.now()
|
||||
} else if (chunk.type === 'reasoning-delta') {
|
||||
accumulatedThinkingContent += chunk.text
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
providerMetadata: {
|
||||
...chunk.providerMetadata,
|
||||
metadata: {
|
||||
...chunk.providerMetadata?.metadata,
|
||||
thinking_millsec: performance.now() - thinkingStartTime,
|
||||
thinking_content: accumulatedThinkingContent
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -1,407 +0,0 @@
|
||||
/**
|
||||
* 搜索编排插件
|
||||
*
|
||||
* 功能:
|
||||
* 1. onRequestStart: 智能意图识别 - 分析是否需要网络搜索、知识库搜索、记忆搜索
|
||||
* 2. transformParams: 根据意图分析结果动态添加对应的工具
|
||||
* 3. onRequestEnd: 自动记忆存储
|
||||
*/
|
||||
import {
|
||||
type AiPlugin,
|
||||
type AiRequestContext,
|
||||
definePlugin,
|
||||
type StreamTextParams,
|
||||
type StreamTextResult
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import { loggerService } from '@logger'
|
||||
import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import { selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import type { ExtractResults } from '@renderer/utils/extract'
|
||||
import { extractInfoFromXML } from '@renderer/utils/extract'
|
||||
// import { generateObject } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
SEARCH_SUMMARY_PROMPT,
|
||||
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} from '@shared/config/prompts'
|
||||
import type { LanguageModel, ModelMessage } from 'ai'
|
||||
import { generateText } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
|
||||
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||
import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
|
||||
|
||||
const logger = loggerService.withContext('SearchOrchestrationPlugin')
|
||||
|
||||
export const getMessageContent = (message: ModelMessage) => {
|
||||
if (typeof message.content === 'string') return message.content
|
||||
return message.content.reduce((acc, part) => {
|
||||
if (part.type === 'text') {
|
||||
return acc + part.text + '\n'
|
||||
}
|
||||
return acc
|
||||
}, '')
|
||||
}
|
||||
|
||||
// === Schema Definitions ===
|
||||
|
||||
// const WebSearchSchema = z.object({
|
||||
// question: z
|
||||
// .array(z.string())
|
||||
// .describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
||||
// links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
||||
// })
|
||||
|
||||
// const KnowledgeSearchSchema = z.object({
|
||||
// question: z
|
||||
// .array(z.string())
|
||||
// .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
||||
// rewrite: z
|
||||
// .string()
|
||||
// .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
||||
// })
|
||||
|
||||
// const SearchIntentAnalysisSchema = z.object({
|
||||
// websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
||||
// knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
||||
// })
|
||||
|
||||
// type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||
|
||||
// let isAnalyzing = false
|
||||
/**
|
||||
* 🧠 意图分析函数 - 使用 XML 解析
|
||||
*/
|
||||
async function analyzeSearchIntent(
|
||||
lastUserMessage: ModelMessage,
|
||||
assistant: Assistant,
|
||||
options: {
|
||||
shouldWebSearch?: boolean
|
||||
shouldKnowledgeSearch?: boolean
|
||||
shouldMemorySearch?: boolean
|
||||
lastAnswer?: ModelMessage
|
||||
context: AiRequestContext
|
||||
topicId: string
|
||||
}
|
||||
): Promise<ExtractResults | undefined> {
|
||||
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
|
||||
|
||||
if (!lastUserMessage) return undefined
|
||||
|
||||
// 根据配置决定是否需要提取
|
||||
const needWebExtract = shouldWebSearch
|
||||
const needKnowledgeExtract = shouldKnowledgeSearch
|
||||
|
||||
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
|
||||
// 选择合适的提示词
|
||||
let prompt: string
|
||||
// let schema: z.Schema
|
||||
|
||||
if (needWebExtract && !needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
// schema = z.object({ websearch: WebSearchSchema })
|
||||
} else if (!needWebExtract && needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
// schema = z.object({ knowledge: KnowledgeSearchSchema })
|
||||
} else {
|
||||
prompt = SEARCH_SUMMARY_PROMPT
|
||||
// schema = SearchIntentAnalysisSchema
|
||||
}
|
||||
|
||||
// 构建消息上下文 - 简化逻辑
|
||||
const chatHistory = lastAnswer ? `assistant: ${getMessageContent(lastAnswer)}` : ''
|
||||
const question = getMessageContent(lastUserMessage) || ''
|
||||
|
||||
// 使用模板替换变量
|
||||
const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question)
|
||||
|
||||
// 获取模型和provider信息
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
if (!provider || isEmpty(provider.apiKey)) {
|
||||
logger.error('Provider not found or missing API key')
|
||||
return getFallbackResult()
|
||||
}
|
||||
try {
|
||||
logger.info('Starting intent analysis generateText call', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
requestId: context.requestId,
|
||||
hasWebSearch: needWebExtract,
|
||||
hasKnowledgeSearch: needKnowledgeExtract
|
||||
})
|
||||
|
||||
const { text: result } = await generateText({
|
||||
model: context.model as LanguageModel,
|
||||
prompt: formattedPrompt
|
||||
}).finally(() => {
|
||||
logger.info('Intent analysis generateText call completed', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
requestId: context.requestId
|
||||
})
|
||||
})
|
||||
const parsedResult = extractInfoFromXML(result)
|
||||
logger.debug('Intent analysis result', { parsedResult })
|
||||
|
||||
// 根据需求过滤结果
|
||||
return {
|
||||
websearch: needWebExtract ? parsedResult?.websearch : undefined,
|
||||
knowledge: needKnowledgeExtract ? parsedResult?.knowledge : undefined
|
||||
}
|
||||
} catch (e: any) {
|
||||
logger.error('Intent analysis failed', e as Error)
|
||||
return getFallbackResult()
|
||||
}
|
||||
|
||||
function getFallbackResult(): ExtractResults {
|
||||
const fallbackContent = getMessageContent(lastUserMessage)
|
||||
return {
|
||||
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
knowledge: shouldKnowledgeSearch
|
||||
? {
|
||||
question: [fallbackContent || 'search'],
|
||||
rewrite: fallbackContent || 'search'
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
|
||||
*/
|
||||
async function storeConversationMemory(
|
||||
messages: ModelMessage[],
|
||||
assistant: Assistant,
|
||||
context: AiRequestContext
|
||||
): Promise<void> {
|
||||
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
|
||||
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
|
||||
// 转换消息为记忆处理器期望的格式
|
||||
const conversationMessages = messages
|
||||
.filter((msg) => msg.role === 'user' || msg.role === 'assistant')
|
||||
.map((msg) => ({
|
||||
role: msg.role,
|
||||
content: getMessageContent(msg) || ''
|
||||
}))
|
||||
.filter((msg) => msg.content.trim().length > 0)
|
||||
logger.debug('conversationMessages', conversationMessages)
|
||||
if (conversationMessages.length < 2) {
|
||||
logger.info('Need at least a user message and assistant response for memory processing')
|
||||
return
|
||||
}
|
||||
|
||||
const currentUserId = await preferenceService.get('feature.memory.current_user_id')
|
||||
// const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(
|
||||
memoryConfig,
|
||||
assistant.id,
|
||||
currentUserId,
|
||||
context.requestId
|
||||
)
|
||||
|
||||
logger.info('Processing conversation memory...', { messageCount: conversationMessages.length })
|
||||
|
||||
// 后台处理对话记忆(不阻塞 UI)
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
memoryProcessor
|
||||
.processConversation(conversationMessages, processorConfig)
|
||||
.then((result) => {
|
||||
logger.info('Memory processing completed:', result)
|
||||
if (result.facts?.length > 0) {
|
||||
logger.info('Extracted facts from conversation:', result.facts)
|
||||
logger.info('Memory operations performed:', result.operations)
|
||||
} else {
|
||||
logger.info('No facts extracted from conversation')
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
logger.error('Background memory processing failed:', error as Error)
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in conversation memory processing:', error as Error)
|
||||
// 不抛出错误,避免影响主流程
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 🎯 搜索编排插件
|
||||
*/
|
||||
export const searchOrchestrationPlugin = (
|
||||
assistant: Assistant,
|
||||
topicId: string
|
||||
): AiPlugin<StreamTextParams, StreamTextResult> => {
|
||||
// 存储意图分析结果
|
||||
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||
|
||||
return definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'search-orchestration',
|
||||
enforce: 'pre', // 确保在其他插件之前执行
|
||||
/**
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context) => {
|
||||
// 没开启任何搜索则不进行意图分析
|
||||
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
|
||||
|
||||
try {
|
||||
const messages = context.originalParams.messages
|
||||
if (!messages || messages.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const lastUserMessage = messages[messages.length - 1]
|
||||
const lastAssistantMessage = messages.length >= 2 ? messages[messages.length - 2] : undefined
|
||||
|
||||
// 存储用户消息用于后续记忆存储
|
||||
userMessages[context.requestId] = lastUserMessage
|
||||
|
||||
// 判断是否需要各种搜索
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
|
||||
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
|
||||
const shouldWebSearch = !!assistant.webSearchProviderId
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
|
||||
|
||||
// 执行意图分析
|
||||
if (shouldWebSearch || shouldKnowledgeSearch) {
|
||||
const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
|
||||
shouldWebSearch,
|
||||
shouldKnowledgeSearch,
|
||||
shouldMemorySearch,
|
||||
lastAnswer: lastAssistantMessage,
|
||||
context,
|
||||
topicId
|
||||
})
|
||||
|
||||
if (analysisResult) {
|
||||
intentAnalysisResults[context.requestId] = analysisResult
|
||||
// logger.info('🧠 Intent analysis completed:', analysisResult)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('🧠 Intent analysis failed:', error as Error)
|
||||
// 不抛出错误,让流程继续
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params, context) => {
|
||||
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
const analysisResult = intentAnalysisResults[context.requestId]
|
||||
// if (!analysisResult || !assistant) {
|
||||
// logger.info('🔧 No analysis result or assistant, skipping tool configuration')
|
||||
// return params
|
||||
// }
|
||||
|
||||
// 确保 tools 对象存在
|
||||
if (!params.tools) {
|
||||
params.tools = {}
|
||||
}
|
||||
|
||||
// 🌐 网络搜索工具配置
|
||||
if (analysisResult?.websearch && assistant.webSearchProviderId) {
|
||||
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
||||
|
||||
if (needsSearch) {
|
||||
// onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
|
||||
// logger.info('🌐 Adding web search tool with pre-extracted keywords')
|
||||
params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords(
|
||||
assistant.webSearchProviderId,
|
||||
analysisResult.websearch,
|
||||
context.requestId
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 📚 知识库搜索工具配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
|
||||
if (shouldKnowledgeSearch) {
|
||||
// on 模式:根据意图识别结果决定是否添加工具
|
||||
const needsKnowledgeSearch =
|
||||
analysisResult?.knowledge &&
|
||||
analysisResult.knowledge.question &&
|
||||
analysisResult.knowledge.question[0] !== 'not_needed'
|
||||
|
||||
if (needsKnowledgeSearch && analysisResult.knowledge) {
|
||||
// logger.info('📚 Adding knowledge search tool (intent-based)')
|
||||
const userMessage = userMessages[context.requestId]
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
|
||||
assistant,
|
||||
analysisResult.knowledge,
|
||||
topicId,
|
||||
getMessageContent(userMessage)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 🧠 记忆搜索工具配置
|
||||
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
|
||||
if (globalMemoryEnabled && assistant.enableMemory) {
|
||||
// logger.info('🧠 Adding memory search tool')
|
||||
params.tools['builtin_memory_search'] = memorySearchTool(assistant.id)
|
||||
}
|
||||
|
||||
// logger.info('🔧 Tools configured:', Object.keys(params.tools))
|
||||
return params
|
||||
} catch (error) {
|
||||
logger.error('🔧 Tool configuration failed:', error as Error)
|
||||
return params
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 💾 Step 3: 记忆存储阶段
|
||||
*/
|
||||
|
||||
onRequestEnd: async (context) => {
|
||||
// context.isAnalyzing = false
|
||||
// logger.info('context.isAnalyzing', context, result)
|
||||
// logger.info('💾 Starting memory storage...', context.requestId)
|
||||
try {
|
||||
// ✅ 类型安全访问:context.originalParams 已通过泛型正确类型化
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
if (messages && assistant) {
|
||||
await storeConversationMemory(messages, assistant, context)
|
||||
}
|
||||
|
||||
// 清理缓存
|
||||
delete intentAnalysisResults[context.requestId]
|
||||
delete userMessages[context.requestId]
|
||||
} catch (error) {
|
||||
logger.error('💾 Memory storage failed:', error as Error)
|
||||
// 不抛出错误,避免影响主流程
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export default searchOrchestrationPlugin
|
||||
@@ -1,18 +0,0 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { simulateStreamingMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* Simulate Streaming Plugin
|
||||
* Converts non-streaming responses to streaming format
|
||||
* Uses AI SDK's built-in simulateStreamingMiddleware
|
||||
*/
|
||||
export const createSimulateStreamingPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'simulateStreaming',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(simulateStreamingMiddleware())
|
||||
}
|
||||
})
|
||||
@@ -1,73 +0,0 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('skipGeminiThoughtSignaturePlugin')
|
||||
|
||||
/**
|
||||
* skip Gemini Thought Signature Middleware
|
||||
*
|
||||
* Handles:
|
||||
* - Tool-call parts need thought_signature for OpenAI-compatible API
|
||||
* -> Add providerOptions.openaiCompatible.extra_content.google.thought_signature
|
||||
*
|
||||
* Note: Thought signature for text/reasoning parts is now handled in messageConverter.
|
||||
*
|
||||
* @returns LanguageModelMiddleware
|
||||
*/
|
||||
function createSkipGeminiThoughtSignatureMiddleware(): LanguageModelMiddleware {
|
||||
const MAGIC_STRING = 'skip_thought_signature_validator'
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
|
||||
transformParams: async ({ params }) => {
|
||||
const transformedParams = { ...params }
|
||||
logger.debug('transformedParams', transformedParams)
|
||||
// Process messages in prompt
|
||||
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
|
||||
transformedParams.prompt = transformedParams.prompt.map((message) => {
|
||||
if (typeof message.content !== 'string') {
|
||||
for (const part of message.content) {
|
||||
const isToolCallPart = part.type === 'tool-call'
|
||||
|
||||
// Note: text part and reasoning part do not require thought signature validation
|
||||
// They are handled by messageConverter now
|
||||
|
||||
// Case: OpenAI-compatible path - add extra_content for tool-call parts
|
||||
// All tool-calls need the signature for Gemini OpenAI-compatible API
|
||||
if (isToolCallPart) {
|
||||
if (!part.providerOptions) {
|
||||
part.providerOptions = {}
|
||||
}
|
||||
if (!part.providerOptions.openaiCompatible) {
|
||||
part.providerOptions.openaiCompatible = {}
|
||||
}
|
||||
// Google OpenAI-compatible API expects extra_content.google.thought_signature
|
||||
// See: https://ai.google.dev/gemini-api/docs/thought-signatures#openai
|
||||
part.providerOptions.openaiCompatible.extra_content = {
|
||||
google: {
|
||||
thought_signature: MAGIC_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return message
|
||||
})
|
||||
}
|
||||
|
||||
return transformedParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const createSkipGeminiThoughtSignaturePlugin = () =>
|
||||
definePlugin({
|
||||
name: 'skipGeminiThoughtSignature',
|
||||
enforce: 'pre',
|
||||
|
||||
configureContext: (context) => {
|
||||
context.middlewares = context.middlewares || []
|
||||
context.middlewares.push(createSkipGeminiThoughtSignatureMiddleware())
|
||||
}
|
||||
})
|
||||
@@ -1,442 +0,0 @@
|
||||
/**
|
||||
* Telemetry Plugin for AI SDK Integration
|
||||
*
|
||||
* 在 transformParams 钩子中注入 experimental_telemetry 参数,
|
||||
* 实现 AI SDK trace 与现有手动 trace 系统的统一
|
||||
* 集成 AiSdkSpanAdapter 将 AI SDK trace 数据转换为现有格式
|
||||
*/
|
||||
|
||||
import type { AiPlugin } from '@cherrystudio/ai-core'
|
||||
import { definePlugin, type StreamTextParams, type StreamTextResult } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import type { Context, Span, SpanContext, Tracer } from '@opentelemetry/api'
|
||||
import { context as otelContext, trace } from '@opentelemetry/api'
|
||||
import { currentSpan } from '@renderer/services/SpanManagerService'
|
||||
import { webTraceService } from '@renderer/services/WebTraceService'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import type { TelemetrySettings } from 'ai'
|
||||
|
||||
import { AiSdkSpanAdapter } from '../trace/AiSdkSpanAdapter'
|
||||
|
||||
const logger = loggerService.withContext('TelemetryPlugin')
|
||||
|
||||
export interface TelemetryPluginConfig {
|
||||
enabled?: boolean
|
||||
recordInputs?: boolean
|
||||
recordOutputs?: boolean
|
||||
topicId: string
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
/**
|
||||
* 自定义 Tracer,集成适配器转换逻辑
|
||||
*/
|
||||
class AdapterTracer {
|
||||
private originalTracer: Tracer
|
||||
private topicId?: string
|
||||
private modelName?: string
|
||||
private parentSpanContext?: SpanContext
|
||||
private cachedParentContext?: Context
|
||||
|
||||
constructor(originalTracer: Tracer, topicId?: string, modelName?: string, parentSpanContext?: SpanContext) {
|
||||
this.originalTracer = originalTracer
|
||||
this.topicId = topicId
|
||||
this.modelName = modelName
|
||||
this.parentSpanContext = parentSpanContext
|
||||
// 预构建一个包含父 SpanContext 的 Context,便于复用
|
||||
try {
|
||||
this.cachedParentContext = this.parentSpanContext
|
||||
? trace.setSpanContext(otelContext.active(), this.parentSpanContext)
|
||||
: undefined
|
||||
} catch {
|
||||
this.cachedParentContext = undefined
|
||||
}
|
||||
|
||||
logger.debug('AdapterTracer created with parent context info', {
|
||||
topicId,
|
||||
modelName,
|
||||
parentTraceId: this.parentSpanContext?.traceId,
|
||||
parentSpanId: this.parentSpanContext?.spanId,
|
||||
hasOriginalTracer: !!originalTracer
|
||||
})
|
||||
}
|
||||
|
||||
startSpan(name: string, options?: any, context?: any): Span {
|
||||
logger.debug('AdapterTracer.startSpan called', {
|
||||
spanName: name,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
|
||||
// 创建包含父 SpanContext 的上下文(如果有的话)
|
||||
const createContextWithParent = () => {
|
||||
if (this.cachedParentContext) {
|
||||
return this.cachedParentContext
|
||||
}
|
||||
if (this.parentSpanContext) {
|
||||
try {
|
||||
const ctx = trace.setSpanContext(otelContext.active(), this.parentSpanContext)
|
||||
logger.debug('Created active context with parent SpanContext for startSpan', {
|
||||
spanName: name,
|
||||
parentTraceId: this.parentSpanContext.traceId,
|
||||
parentSpanId: this.parentSpanContext.spanId,
|
||||
topicId: this.topicId
|
||||
})
|
||||
return ctx
|
||||
} catch (error) {
|
||||
logger.warn('Failed to create context with parent SpanContext in startSpan', error as Error)
|
||||
}
|
||||
}
|
||||
return otelContext.active()
|
||||
}
|
||||
|
||||
const ctx = context ?? createContextWithParent()
|
||||
const span = this.originalTracer.startSpan(name, options, ctx)
|
||||
|
||||
// 注入父子关系属性(兜底重建层级用)
|
||||
try {
|
||||
if (this.parentSpanContext) {
|
||||
span.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
|
||||
span.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
|
||||
}
|
||||
if (this.topicId) {
|
||||
span.setAttribute('trace.topicId', this.topicId)
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug('Failed to set trace parent attributes in startSpan', e as Error)
|
||||
}
|
||||
|
||||
// 包装span的end方法
|
||||
const originalEnd = span.end.bind(span)
|
||||
span.end = (endTime?: any) => {
|
||||
logger.debug('AI SDK span.end() called in startSpan - about to convert span', {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
|
||||
// 调用原始 end 方法
|
||||
originalEnd(endTime)
|
||||
|
||||
// 转换并保存 span 数据
|
||||
try {
|
||||
logger.debug('Converting AI SDK span to SpanEntity (from startSpan)', {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
logger.silly('span', span)
|
||||
const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
|
||||
span,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
|
||||
// 保存转换后的数据
|
||||
void window.api.trace.saveEntity(spanEntity)
|
||||
|
||||
logger.debug('AI SDK span converted and saved successfully (from startSpan)', {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName,
|
||||
hasUsage: !!spanEntity.usage,
|
||||
usage: spanEntity.usage
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to convert AI SDK span (from startSpan)', error as Error, {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return span
|
||||
}
|
||||
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, context: any, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, arg2?: any, arg3?: any, arg4?: any): ReturnType<F> {
|
||||
logger.debug('AdapterTracer.startActiveSpan called', {
|
||||
spanName: name,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName,
|
||||
// oxlint-disable-next-line no-undef False alarm. see https://github.com/oxc-project/oxc/issues/4232
|
||||
argCount: arguments.length
|
||||
})
|
||||
|
||||
// 包装函数来添加span转换逻辑
|
||||
const wrapFunction = (originalFn: F, span: Span): F => {
|
||||
const wrappedFn = ((passedSpan: Span) => {
|
||||
// 注入父子关系属性(兜底重建层级用)
|
||||
try {
|
||||
if (this.parentSpanContext) {
|
||||
passedSpan.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
|
||||
passedSpan.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
|
||||
}
|
||||
if (this.topicId) {
|
||||
passedSpan.setAttribute('trace.topicId', this.topicId)
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug('Failed to set trace parent attributes in startActiveSpan', e as Error)
|
||||
}
|
||||
// 包装span的end方法
|
||||
const originalEnd = span.end.bind(span)
|
||||
span.end = (endTime?: any) => {
|
||||
logger.debug('AI SDK span.end() called in startActiveSpan - about to convert span', {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
|
||||
// 调用原始 end 方法
|
||||
originalEnd(endTime)
|
||||
|
||||
// 转换并保存 span 数据
|
||||
try {
|
||||
logger.debug('Converting AI SDK span to SpanEntity (from startActiveSpan)', {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
logger.silly('span', span)
|
||||
const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
|
||||
span,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
|
||||
// 保存转换后的数据
|
||||
void window.api.trace.saveEntity(spanEntity)
|
||||
|
||||
logger.debug('AI SDK span converted and saved successfully (from startActiveSpan)', {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName,
|
||||
hasUsage: !!spanEntity.usage,
|
||||
usage: spanEntity.usage
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to convert AI SDK span (from startActiveSpan)', error as Error, {
|
||||
spanName: name,
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: this.topicId,
|
||||
modelName: this.modelName
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return originalFn(passedSpan)
|
||||
}) as F
|
||||
return wrappedFn
|
||||
}
|
||||
|
||||
// 创建包含父 SpanContext 的上下文(如果有的话)
|
||||
const createContextWithParent = () => {
|
||||
if (this.cachedParentContext) {
|
||||
return this.cachedParentContext
|
||||
}
|
||||
if (this.parentSpanContext) {
|
||||
try {
|
||||
const ctx = trace.setSpanContext(otelContext.active(), this.parentSpanContext)
|
||||
logger.debug('Created active context with parent SpanContext for startActiveSpan', {
|
||||
spanName: name,
|
||||
parentTraceId: this.parentSpanContext.traceId,
|
||||
parentSpanId: this.parentSpanContext.spanId,
|
||||
topicId: this.topicId
|
||||
})
|
||||
return ctx
|
||||
} catch (error) {
|
||||
logger.warn('Failed to create context with parent SpanContext in startActiveSpan', error as Error)
|
||||
}
|
||||
}
|
||||
return otelContext.active()
|
||||
}
|
||||
|
||||
// 根据参数数量确定调用方式,注入包含mainTraceId的上下文
|
||||
if (typeof arg2 === 'function') {
|
||||
return this.originalTracer.startActiveSpan(name, {}, createContextWithParent(), (span: Span) => {
|
||||
return wrapFunction(arg2, span)(span)
|
||||
})
|
||||
} else if (typeof arg3 === 'function') {
|
||||
return this.originalTracer.startActiveSpan(name, arg2, createContextWithParent(), (span: Span) => {
|
||||
return wrapFunction(arg3, span)(span)
|
||||
})
|
||||
} else if (typeof arg4 === 'function') {
|
||||
// 如果调用方提供了 context,则保留以维护嵌套关系;否则回退到父上下文
|
||||
const ctx = arg3 ?? createContextWithParent()
|
||||
return this.originalTracer.startActiveSpan(name, arg2, ctx, (span: Span) => {
|
||||
return wrapFunction(arg4, span)(span)
|
||||
})
|
||||
} else {
|
||||
throw new Error('Invalid arguments for startActiveSpan')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function createTelemetryPlugin(config: TelemetryPluginConfig): AiPlugin<StreamTextParams, StreamTextResult> {
|
||||
const { enabled = true, recordInputs = true, recordOutputs = true, topicId } = config
|
||||
|
||||
return definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'telemetryPlugin',
|
||||
enforce: 'pre', // 在其他插件之前执行,确保 telemetry 配置被正确注入
|
||||
|
||||
transformParams: (params, context) => {
|
||||
if (!enabled) {
|
||||
return params
|
||||
}
|
||||
|
||||
// 获取共享的 tracer
|
||||
const originalTracer = webTraceService.getTracer()
|
||||
if (!originalTracer) {
|
||||
logger.warn('No tracer available from WebTraceService')
|
||||
return params
|
||||
}
|
||||
|
||||
// 获取topicId和modelName
|
||||
const effectiveTopicId = context.topicId || topicId
|
||||
// 使用与父span创建时一致的modelName - 应该是完整的modelId
|
||||
const modelName = config.assistant.model?.name || context.modelId
|
||||
|
||||
// 获取当前活跃的 span,确保 AI SDK spans 与手动 spans 在同一个 trace 中
|
||||
let parentSpan: Span | undefined = undefined
|
||||
let parentSpanContext: SpanContext | undefined = undefined
|
||||
|
||||
// 只有在有topicId时才尝试查找父span
|
||||
if (effectiveTopicId) {
|
||||
try {
|
||||
// 从 SpanManagerService 获取当前的 span
|
||||
logger.debug('Attempting to find parent span', {
|
||||
topicId: effectiveTopicId,
|
||||
requestId: context.requestId,
|
||||
modelName: modelName,
|
||||
contextModelId: context.modelId,
|
||||
providerId: context.providerId
|
||||
})
|
||||
|
||||
parentSpan = currentSpan(effectiveTopicId, modelName)
|
||||
if (parentSpan) {
|
||||
// 直接使用父 span 的 SpanContext,避免手动拼装字段遗漏
|
||||
parentSpanContext = parentSpan.spanContext()
|
||||
logger.debug('Found active parent span for AI SDK', {
|
||||
parentSpanId: parentSpanContext.spanId,
|
||||
parentTraceId: parentSpanContext.traceId,
|
||||
topicId: effectiveTopicId,
|
||||
requestId: context.requestId,
|
||||
modelName: modelName
|
||||
})
|
||||
} else {
|
||||
logger.warn('No active parent span found in SpanManagerService', {
|
||||
topicId: effectiveTopicId,
|
||||
requestId: context.requestId,
|
||||
modelId: context.modelId,
|
||||
modelName: modelName,
|
||||
providerId: context.providerId,
|
||||
// 更详细的调试信息
|
||||
searchedModelName: modelName,
|
||||
contextModelId: context.modelId,
|
||||
isAnalyzing: context.isAnalyzing
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error getting current span from SpanManagerService', error as Error, {
|
||||
topicId: effectiveTopicId,
|
||||
requestId: context.requestId,
|
||||
modelName: modelName
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.debug('No topicId provided, skipping parent span lookup', {
|
||||
requestId: context.requestId,
|
||||
contextTopicId: context.topicId,
|
||||
configTopicId: topicId,
|
||||
modelName: modelName
|
||||
})
|
||||
}
|
||||
|
||||
// 创建适配器包装的 tracer,传入获取到的父 SpanContext
|
||||
const adapterTracer = new AdapterTracer(originalTracer, effectiveTopicId, modelName, parentSpanContext)
|
||||
|
||||
// 注入 AI SDK telemetry 配置
|
||||
const telemetryConfig = {
|
||||
isEnabled: true,
|
||||
recordInputs,
|
||||
recordOutputs,
|
||||
tracer: adapterTracer,
|
||||
functionId: `ai-request-${context.requestId}`,
|
||||
metadata: {
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
topicId: effectiveTopicId,
|
||||
requestId: context.requestId,
|
||||
modelName: modelName,
|
||||
// 确保topicId也作为标准属性传递
|
||||
'trace.topicId': effectiveTopicId,
|
||||
'trace.modelName': modelName,
|
||||
// 添加父span信息用于调试(只在有值时添加)
|
||||
...(parentSpanContext?.spanId && { parentSpanId: parentSpanContext.spanId }),
|
||||
...(parentSpanContext?.traceId && { parentTraceId: parentSpanContext.traceId })
|
||||
}
|
||||
} satisfies TelemetrySettings
|
||||
|
||||
// 如果有父span,尝试在telemetry配置中设置父上下文
|
||||
if (parentSpan) {
|
||||
try {
|
||||
// 设置活跃上下文,确保 AI SDK spans 在正确的 trace 上下文中创建
|
||||
const activeContext = trace.setSpan(otelContext.active(), parentSpan)
|
||||
|
||||
// 更新全局上下文
|
||||
otelContext.with(activeContext, () => {
|
||||
logger.debug('Updated active context with parent span')
|
||||
})
|
||||
|
||||
logger.debug('Set parent context for AI SDK spans', {
|
||||
parentSpanId: parentSpanContext?.spanId,
|
||||
parentTraceId: parentSpanContext?.traceId,
|
||||
hasActiveContext: !!activeContext,
|
||||
hasParentSpan: !!parentSpan
|
||||
})
|
||||
} catch (error) {
|
||||
logger.warn('Failed to set parent context in telemetry config', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('Injecting AI SDK telemetry config with adapter', {
|
||||
requestId: context.requestId,
|
||||
topicId: effectiveTopicId,
|
||||
modelId: context.modelId,
|
||||
modelName: modelName,
|
||||
hasParentSpan: !!parentSpan,
|
||||
parentSpanId: parentSpanContext?.spanId,
|
||||
parentTraceId: parentSpanContext?.traceId,
|
||||
functionId: telemetryConfig.functionId,
|
||||
hasTracer: !!telemetryConfig.tracer,
|
||||
tracerType: telemetryConfig.tracer?.constructor?.name || 'unknown'
|
||||
})
|
||||
|
||||
return {
|
||||
...params,
|
||||
experimental_telemetry: telemetryConfig
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 默认导出便于使用
|
||||
export default createTelemetryPlugin
|
||||
@@ -1,838 +0,0 @@
|
||||
import type { Message, Model } from '@renderer/types'
|
||||
import type { FileMetadata } from '@renderer/types/file'
|
||||
import { FILE_TYPE } from '@renderer/types/file'
|
||||
import {
|
||||
AssistantMessageStatus,
|
||||
type FileMessageBlock,
|
||||
type ImageMessageBlock,
|
||||
type MainTextMessageBlock,
|
||||
MessageBlockStatus,
|
||||
MessageBlockType,
|
||||
type ThinkingMessageBlock,
|
||||
UserMessageStatus
|
||||
} from '@renderer/types/newMessage'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
|
||||
convertFileBlockToFilePartMock: vi.fn(),
|
||||
convertFileBlockToTextPartMock: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../fileProcessor', () => ({
|
||||
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
|
||||
convertFileBlockToTextPart: convertFileBlockToTextPartMock
|
||||
}))
|
||||
|
||||
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
|
||||
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
|
||||
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isVisionModel: (model: Model) => visionModelIds.has(model.id),
|
||||
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
|
||||
}))
|
||||
|
||||
type MockableMessage = Message & {
|
||||
__mockContent?: string
|
||||
__mockFileBlocks?: FileMessageBlock[]
|
||||
__mockImageBlocks?: ImageMessageBlock[]
|
||||
__mockThinkingBlocks?: ThinkingMessageBlock[]
|
||||
__mockMainTextBlocks?: MainTextMessageBlock[]
|
||||
}
|
||||
|
||||
vi.mock('@renderer/utils/messageUtils/find', () => ({
|
||||
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
|
||||
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
|
||||
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
|
||||
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? [],
|
||||
findMainTextBlocks: (message: Message) => (message as MockableMessage).__mockMainTextBlocks ?? []
|
||||
}))
|
||||
|
||||
import { convertMessagesToSdkMessages, convertMessageToSdkParam, stripMarkdownBase64Images } from '../messageConverter'
|
||||
|
||||
let messageCounter = 0
|
||||
let blockCounter = 0
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o-mini',
|
||||
name: 'GPT-4o mini',
|
||||
provider: 'openai',
|
||||
group: 'openai',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const createMessage = (role: Message['role']): MockableMessage =>
|
||||
({
|
||||
id: `message-${++messageCounter}`,
|
||||
role,
|
||||
assistantId: 'assistant-1',
|
||||
topicId: 'topic-1',
|
||||
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
|
||||
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}) as MockableMessage
|
||||
|
||||
const createFileBlock = (
|
||||
messageId: string,
|
||||
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
|
||||
): FileMessageBlock => {
|
||||
const { file, ...blockOverrides } = overrides
|
||||
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
|
||||
return {
|
||||
id: blockOverrides.id ?? `file-block-${blockCounter}`,
|
||||
messageId,
|
||||
type: MessageBlockType.FILE,
|
||||
createdAt: blockOverrides.createdAt ?? timestamp,
|
||||
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
|
||||
file: {
|
||||
id: file?.id ?? `file-${blockCounter}`,
|
||||
name: file?.name ?? 'document.txt',
|
||||
origin_name: file?.origin_name ?? 'document.txt',
|
||||
path: file?.path ?? '/tmp/document.txt',
|
||||
size: file?.size ?? 1024,
|
||||
ext: file?.ext ?? '.txt',
|
||||
type: file?.type ?? FILE_TYPE.TEXT,
|
||||
created_at: file?.created_at ?? timestamp,
|
||||
count: file?.count ?? 1,
|
||||
...file
|
||||
},
|
||||
...blockOverrides
|
||||
}
|
||||
}
|
||||
|
||||
const createImageBlock = (
|
||||
messageId: string,
|
||||
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
|
||||
): ImageMessageBlock => ({
|
||||
id: overrides.id ?? `image-block-${++blockCounter}`,
|
||||
messageId,
|
||||
type: MessageBlockType.IMAGE,
|
||||
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
|
||||
status: overrides.status ?? MessageBlockStatus.SUCCESS,
|
||||
url: overrides.url ?? 'https://example.com/image.png',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const createThinkingBlock = (
|
||||
messageId: string,
|
||||
overrides: Partial<Omit<ThinkingMessageBlock, 'type' | 'messageId'>> = {}
|
||||
): ThinkingMessageBlock => ({
|
||||
id: overrides.id ?? `thinking-block-${++blockCounter}`,
|
||||
messageId,
|
||||
type: MessageBlockType.THINKING,
|
||||
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
|
||||
status: overrides.status ?? MessageBlockStatus.SUCCESS,
|
||||
content: overrides.content ?? 'Let me think...',
|
||||
thinking_millsec: overrides.thinking_millsec ?? 1000,
|
||||
...overrides
|
||||
})
|
||||
|
||||
const createMainTextBlock = (
|
||||
messageId: string,
|
||||
overrides: Partial<Omit<MainTextMessageBlock, 'type' | 'messageId'>> = {}
|
||||
): MainTextMessageBlock => ({
|
||||
id: overrides.id ?? `main-text-block-${++blockCounter}`,
|
||||
messageId,
|
||||
type: MessageBlockType.MAIN_TEXT,
|
||||
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
|
||||
status: overrides.status ?? MessageBlockStatus.SUCCESS,
|
||||
content: overrides.content ?? '',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('messageConverter', () => {
|
||||
beforeEach(() => {
|
||||
convertFileBlockToFilePartMock.mockReset()
|
||||
convertFileBlockToTextPartMock.mockReset()
|
||||
convertFileBlockToFilePartMock.mockResolvedValue(null)
|
||||
convertFileBlockToTextPartMock.mockResolvedValue(null)
|
||||
messageCounter = 0
|
||||
blockCounter = 0
|
||||
})
|
||||
|
||||
describe('convertMessageToSdkParam', () => {
|
||||
it('includes text and image parts for user messages on vision models', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Describe this picture'
|
||||
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, true, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Describe this picture' },
|
||||
{ type: 'image', image: 'https://example.com/cat.png' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('extracts base64 data from data URLs and preserves mediaType', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Check this image'
|
||||
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:image/png;base64,iVBORw0KGgoAAAANS' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, true, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Check this image' },
|
||||
{ type: 'image', image: 'iVBORw0KGgoAAAANS', mediaType: 'image/png' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('handles data URLs without mediaType gracefully', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Check this'
|
||||
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:;base64,AAABBBCCC' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, true, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Check this' },
|
||||
{ type: 'image', image: 'AAABBBCCC' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('skips malformed data URLs without comma separator', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Malformed data url'
|
||||
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'data:image/pngAAABBB' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, true, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Malformed data url' }
|
||||
// Malformed data URL is excluded from the content
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('handles multiple large base64 images without stack overflow', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
// Create large base64 strings (~500KB each) to simulate real-world large images
|
||||
const largeBase64 = 'A'.repeat(500_000)
|
||||
message.__mockContent = 'Check these images'
|
||||
message.__mockImageBlocks = [
|
||||
createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }),
|
||||
createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` }),
|
||||
createImageBlock(message.id, { url: `data:image/png;base64,${largeBase64}` })
|
||||
]
|
||||
|
||||
// Should not throw RangeError: Maximum call stack size exceeded
|
||||
await expect(convertMessageToSdkParam(message, true, model)).resolves.toBeDefined()
|
||||
})
|
||||
|
||||
it('returns file instructions as a system message when native uploads succeed', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Summarize the PDF'
|
||||
message.__mockFileBlocks = [createFileBlock(message.id)]
|
||||
convertFileBlockToFilePartMock.mockResolvedValueOnce({
|
||||
type: 'file',
|
||||
filename: 'document.pdf',
|
||||
mediaType: 'application/pdf',
|
||||
data: 'fileid://remote-file'
|
||||
})
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'system',
|
||||
content: 'fileid://remote-file'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Summarize the PDF' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('includes reasoning parts for assistant messages with thinking blocks', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = 'Here is my answer'
|
||||
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Let me think...' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
// Reasoning blocks must come before text blocks (required by AWS Bedrock for Claude extended thinking)
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{ type: 'reasoning', text: 'Let me think...' },
|
||||
{ type: 'text', text: 'Here is my answer' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('excludes empty content from assistant messages', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = ''
|
||||
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Thinking only' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
// Empty content should not create a text block
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [{ type: 'reasoning', text: 'Thinking only' }]
|
||||
})
|
||||
})
|
||||
|
||||
it('excludes whitespace-only content from assistant messages', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = ' \n\t '
|
||||
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Thinking only' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
// Whitespace-only content should not create a text block
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [{ type: 'reasoning', text: 'Thinking only' }]
|
||||
})
|
||||
})
|
||||
|
||||
it('trims content in assistant messages', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = ' Trimmed answer \n'
|
||||
message.__mockThinkingBlocks = []
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Trimmed answer' }]
|
||||
})
|
||||
})
|
||||
|
||||
it('includes thoughtSignature in providerOptions for Gemini thought signature persistence', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = 'Here is my answer'
|
||||
message.__mockMainTextBlocks = [
|
||||
createMainTextBlock(message.id, {
|
||||
content: 'Here is my answer',
|
||||
metadata: { thoughtSignature: 'test-thought-signature-token' }
|
||||
})
|
||||
]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Here is my answer',
|
||||
providerOptions: {
|
||||
google: {
|
||||
thoughtSignature: 'test-thought-signature-token'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('does not include providerOptions when no thoughtSignature is present', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = 'Plain answer'
|
||||
message.__mockMainTextBlocks = [createMainTextBlock(message.id, { content: 'Plain answer' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Plain answer' }]
|
||||
})
|
||||
})
|
||||
|
||||
it('uses thoughtSignature from the first matching MainTextBlock when multiple exist', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = 'Answer text'
|
||||
message.__mockMainTextBlocks = [
|
||||
createMainTextBlock(message.id, { content: 'Answer text', metadata: { thoughtSignature: 'first-signature' } }),
|
||||
createMainTextBlock(message.id, {
|
||||
content: 'Another block',
|
||||
metadata: { thoughtSignature: 'second-signature' }
|
||||
})
|
||||
]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Answer text',
|
||||
providerOptions: {
|
||||
google: {
|
||||
thoughtSignature: 'first-signature'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('combines reasoning blocks with thoughtSignature text part', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('assistant')
|
||||
message.__mockContent = 'Final answer'
|
||||
message.__mockThinkingBlocks = [createThinkingBlock(message.id, { content: 'Thinking step' })]
|
||||
message.__mockMainTextBlocks = [
|
||||
createMainTextBlock(message.id, {
|
||||
content: 'Final answer',
|
||||
metadata: { thoughtSignature: 'sig-with-reasoning' }
|
||||
})
|
||||
]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{ type: 'reasoning', text: 'Thinking step' },
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Final answer',
|
||||
providerOptions: {
|
||||
google: {
|
||||
thoughtSignature: 'sig-with-reasoning'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertMessagesToSdkMessages', () => {
|
||||
it('preserves conversation history and merges images for image enhancement models', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const initialUser = createMessage('user')
|
||||
initialUser.__mockContent = 'Start editing'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Here is the current preview'
|
||||
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
|
||||
|
||||
const finalUser = createMessage('user')
|
||||
finalUser.__mockContent = 'Increase the brightness'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
|
||||
|
||||
// Preserves all conversation history, only merges images into the last user message
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start editing' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Here is the current preview' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Increase the brightness' },
|
||||
{ type: 'image', image: 'https://example.com/preview.png' }
|
||||
]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('preserves system messages and conversation history for enhancement payloads', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const fileUser = createMessage('user')
|
||||
fileUser.__mockContent = 'Use this document as inspiration'
|
||||
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FILE_TYPE.DOCUMENT } })]
|
||||
convertFileBlockToFilePartMock.mockResolvedValueOnce({
|
||||
type: 'file',
|
||||
filename: 'reference.pdf',
|
||||
mediaType: 'application/pdf',
|
||||
data: 'fileid://reference'
|
||||
})
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Generated previews ready'
|
||||
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
|
||||
|
||||
const finalUser = createMessage('user')
|
||||
finalUser.__mockContent = 'Apply the edits'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
|
||||
|
||||
// Preserves system message, conversation history, and merges images into the last user message
|
||||
expect(result).toEqual([
|
||||
{ role: 'system', content: 'fileid://reference' },
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Use this document as inspiration' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Generated previews ready' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Apply the edits' },
|
||||
{ type: 'image', image: 'https://example.com/reference.png' }
|
||||
]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('returns messages as-is when no previous assistant message with images', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const user1 = createMessage('user')
|
||||
user1.__mockContent = 'Start'
|
||||
|
||||
const user2 = createMessage('user')
|
||||
user2.__mockContent = 'Continue without images'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user1, user2], model)
|
||||
|
||||
// No images to merge, returns all messages as-is
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Continue without images' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('returns messages as-is when assistant message has no images', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const user1 = createMessage('user')
|
||||
user1.__mockContent = 'Start'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Text only response'
|
||||
assistant.__mockImageBlocks = []
|
||||
|
||||
const user2 = createMessage('user')
|
||||
user2.__mockContent = 'Follow up'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user1, assistant, user2], model)
|
||||
|
||||
// No images to merge, returns all messages as-is
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Text only response' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Follow up' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('merges images from the most recent assistant message', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const user1 = createMessage('user')
|
||||
user1.__mockContent = 'Start'
|
||||
|
||||
const assistant1 = createMessage('assistant')
|
||||
assistant1.__mockContent = 'First response'
|
||||
assistant1.__mockImageBlocks = [createImageBlock(assistant1.id, { url: 'https://example.com/old.png' })]
|
||||
|
||||
const user2 = createMessage('user')
|
||||
user2.__mockContent = 'Continue'
|
||||
|
||||
const assistant2 = createMessage('assistant')
|
||||
assistant2.__mockContent = 'Second response'
|
||||
assistant2.__mockImageBlocks = [createImageBlock(assistant2.id, { url: 'https://example.com/new.png' })]
|
||||
|
||||
const user3 = createMessage('user')
|
||||
user3.__mockContent = 'Final request'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user1, assistant1, user2, assistant2, user3], model)
|
||||
|
||||
// Preserves all history, merges only the most recent assistant's images
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'First response' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Continue' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Second response' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Final request' },
|
||||
{ type: 'image', image: 'https://example.com/new.png' }
|
||||
]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('returns messages as-is when conversation ends with assistant message', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const user = createMessage('user')
|
||||
user.__mockContent = 'Start'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Response with image'
|
||||
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/image.png' })]
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user, assistant], model)
|
||||
|
||||
// The user message is the last user message, but since the assistant comes after,
|
||||
// there's no "previous" assistant message (search starts from messages.length-2 backwards)
|
||||
// So no images to merge, returns all messages as-is
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Response with image' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('merges images even when last user message has empty content', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const user1 = createMessage('user')
|
||||
user1.__mockContent = 'Start'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Here is the preview'
|
||||
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
|
||||
|
||||
const user2 = createMessage('user')
|
||||
user2.__mockContent = ''
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user1, assistant, user2], model)
|
||||
|
||||
// Preserves history, merges images into last user message (even if empty)
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Here is the preview' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'image', image: 'https://example.com/preview.png' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('strips inline base64 data URIs from assistant text to prevent HTTP 413 (#12602)', async () => {
|
||||
const model = createModel({ id: 'gpt-4o-mini' })
|
||||
const user1 = createMessage('user')
|
||||
user1.__mockContent = 'Generate an image of a cat'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent =
|
||||
'Here is the image you requested:\n\nHope you like it!'
|
||||
|
||||
const user2 = createMessage('user')
|
||||
user2.__mockContent = 'Now describe what you see'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user1, assistant, user2], model)
|
||||
|
||||
const assistantMsg = result.find((m) => m.role === 'assistant')!
|
||||
const textPart = (assistantMsg.content as Array<{ type: string; text: string }>).find((p) => p.type === 'text')!
|
||||
// The base64 data URI should be replaced with a placeholder
|
||||
expect(textPart.text).not.toContain('data:image/')
|
||||
expect(textPart.text).toContain('')
|
||||
expect(textPart.text).toContain('Hope you like it!')
|
||||
})
|
||||
|
||||
it('strips multiple inline base64 images from assistant text', async () => {
|
||||
const model = createModel({ id: 'gpt-4o-mini' })
|
||||
const user = createMessage('user')
|
||||
user.__mockContent = 'Generate two images'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = ' and '
|
||||
|
||||
const result = await convertMessageToSdkParam(assistant, false, model)
|
||||
const textPart = ((result as any).content as Array<{ type: string; text: string }>).find(
|
||||
(p) => p.type === 'text'
|
||||
)!
|
||||
expect(textPart.text).toBe(' and ')
|
||||
})
|
||||
|
||||
it('preserves regular markdown images (non-base64) in assistant text', async () => {
|
||||
const model = createModel({ id: 'gpt-4o-mini' })
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Check this out: '
|
||||
|
||||
const result = await convertMessageToSdkParam(assistant, false, model)
|
||||
const textPart = ((result as any).content as Array<{ type: string; text: string }>).find(
|
||||
(p) => p.type === 'text'
|
||||
)!
|
||||
expect(textPart.text).toBe('Check this out: ')
|
||||
})
|
||||
|
||||
it('allows using LLM conversation context for image generation', async () => {
|
||||
// This test verifies the key use case: switching from LLM to image enhancement model
|
||||
// and using the previous conversation as context for image generation
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
|
||||
// Simulate a conversation that started with a regular LLM
|
||||
const user1 = createMessage('user')
|
||||
user1.__mockContent = 'Help me design a futuristic robot with blue lights'
|
||||
|
||||
const assistant1 = createMessage('assistant')
|
||||
assistant1.__mockContent =
|
||||
'Great idea! The robot could have a sleek metallic body with glowing blue LED strips...'
|
||||
assistant1.__mockImageBlocks = [] // LLM response, no images
|
||||
|
||||
const user2 = createMessage('user')
|
||||
user2.__mockContent = 'Yes, and add some chrome accents'
|
||||
|
||||
const assistant2 = createMessage('assistant')
|
||||
assistant2.__mockContent = 'Perfect! Chrome accents would complement the blue lights beautifully...'
|
||||
assistant2.__mockImageBlocks = [] // Still LLM response, no images
|
||||
|
||||
// User switches to image enhancement model and asks for image generation
|
||||
const user3 = createMessage('user')
|
||||
user3.__mockContent = 'Now generate an image based on our discussion'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([user1, assistant1, user2, assistant2, user3], model)
|
||||
|
||||
// All conversation history should be preserved for context
|
||||
// No images to merge since previous assistant had no images
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Help me design a futuristic robot with blue lights' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Great idea! The robot could have a sleek metallic body with glowing blue LED strips...'
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Yes, and add some chrome accents' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Perfect! Chrome accents would complement the blue lights beautifully...' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Now generate an image based on our discussion' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('stripMarkdownBase64Images', () => {
|
||||
it('replaces a single base64 image with placeholder', () => {
|
||||
const input = 'Here is the image:\n\nDone.'
|
||||
expect(stripMarkdownBase64Images(input)).toBe('Here is the image:\n\nDone.')
|
||||
})
|
||||
|
||||
it('replaces multiple base64 images', () => {
|
||||
const input = ' text '
|
||||
expect(stripMarkdownBase64Images(input)).toBe(' text ')
|
||||
})
|
||||
|
||||
it('preserves regular markdown images with http URLs', () => {
|
||||
const input = ''
|
||||
expect(stripMarkdownBase64Images(input)).toBe(input)
|
||||
})
|
||||
|
||||
it('preserves file:// URLs in markdown images', () => {
|
||||
const input = ''
|
||||
expect(stripMarkdownBase64Images(input)).toBe(input)
|
||||
})
|
||||
|
||||
it('handles empty alt text', () => {
|
||||
const input = ''
|
||||
expect(stripMarkdownBase64Images(input)).toBe('')
|
||||
})
|
||||
|
||||
it('handles text with no markdown images', () => {
|
||||
expect(stripMarkdownBase64Images('Just plain text.')).toBe('Just plain text.')
|
||||
})
|
||||
|
||||
it('returns empty string for empty input', () => {
|
||||
expect(stripMarkdownBase64Images('')).toBe('')
|
||||
})
|
||||
|
||||
it('handles mixed base64 and regular images', () => {
|
||||
const input =
|
||||
' then  then '
|
||||
expect(stripMarkdownBase64Images(input)).toBe(
|
||||
' then  then '
|
||||
)
|
||||
})
|
||||
|
||||
it('handles data URI without base64 encoding', () => {
|
||||
const input = ''
|
||||
expect(stripMarkdownBase64Images(input)).toBe('')
|
||||
})
|
||||
|
||||
it('does not treat bare ](data: without  more text'
|
||||
expect(stripMarkdownBase64Images(input)).toBe(input)
|
||||
})
|
||||
|
||||
it('handles large base64 payload without OOM', () => {
|
||||
const largeBase64 = 'A'.repeat(5_000_000)
|
||||
const input = ``
|
||||
expect(stripMarkdownBase64Images(input)).toBe('')
|
||||
})
|
||||
|
||||
it('handles unclosed parenthesis gracefully', () => {
|
||||
const input = ').toBe(input)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,317 +0,0 @@
|
||||
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
|
||||
import { TopicType } from '@renderer/types'
|
||||
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getMaxTokens, getTemperature, getTimeout, getTopP } from '../modelParameters'
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
DEFAULT_ASSISTANT_SETTINGS: {
|
||||
maxTokens: 4096,
|
||||
enableMaxTokens: false,
|
||||
temperature: 0.7,
|
||||
enableTemperature: true,
|
||||
topP: 1,
|
||||
enableTopP: false,
|
||||
contextCount: 4096,
|
||||
streamOutput: true,
|
||||
defaultModel: undefined,
|
||||
customParameters: [],
|
||||
reasoning_effort: 'default',
|
||||
qwenThinkMode: undefined,
|
||||
toolUseMode: 'function',
|
||||
maxToolCalls: 20,
|
||||
enableMaxToolCalls: true
|
||||
},
|
||||
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
|
||||
contextCount: assistant.settings?.contextCount ?? 4096,
|
||||
temperature: assistant.settings?.temperature ?? 0.7,
|
||||
enableTemperature: assistant.settings?.enableTemperature ?? true,
|
||||
topP: assistant.settings?.topP ?? 1,
|
||||
enableTopP: assistant.settings?.enableTopP ?? false,
|
||||
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
|
||||
maxTokens: assistant.settings?.maxTokens,
|
||||
streamOutput: assistant.settings?.streamOutput ?? true,
|
||||
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
|
||||
defaultModel: assistant.defaultModel,
|
||||
customParameters: assistant.settings?.customParameters ?? [],
|
||||
reasoning_effort: assistant.settings?.reasoning_effort ?? 'default',
|
||||
qwenThinkMode: assistant.settings?.qwenThinkMode
|
||||
}),
|
||||
getProviderByModel: (model: Model) => ({ id: model.provider, type: model.provider, models: [] })
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn(),
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: (state = { settings: {} }) => state
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/assistants', () => ({
|
||||
default: (state = { assistants: [] }) => state
|
||||
}))
|
||||
|
||||
const createTopic = (assistantId: string): Topic => ({
|
||||
id: `topic-${assistantId}`,
|
||||
assistantId,
|
||||
name: 'topic',
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
messages: [],
|
||||
type: TopicType.Chat
|
||||
})
|
||||
|
||||
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
|
||||
const assistantId = 'assistant-1'
|
||||
return {
|
||||
id: assistantId,
|
||||
name: 'Test Assistant',
|
||||
prompt: 'prompt',
|
||||
topics: [createTopic(assistantId)],
|
||||
type: 'assistant',
|
||||
settings
|
||||
}
|
||||
}
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
name: 'GPT-4o',
|
||||
group: 'openai',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('modelParameters', () => {
|
||||
describe('getTemperature', () => {
|
||||
it('returns undefined when reasoning effort is enabled for Claude models', () => {
|
||||
const assistant = createAssistant({ reasoning_effort: 'medium', enableTemperature: true })
|
||||
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns temperature when reasoning effort is default for Claude models', () => {
|
||||
const assistant = createAssistant({ reasoning_effort: 'default', enableTemperature: true, temperature: 0.7 })
|
||||
const model = createModel({ id: 'claude-sonnet-4.5', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(0.7)
|
||||
})
|
||||
|
||||
it('returns temperature when reasoning effort is none for Claude models', () => {
|
||||
const assistant = createAssistant({ reasoning_effort: 'none', enableTemperature: true, temperature: 0.5 })
|
||||
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(0.5)
|
||||
})
|
||||
|
||||
it('returns undefined for models without temperature/topP support', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true })
|
||||
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
|
||||
const model = createModel({
|
||||
id: 'claude-sonnet-4.5',
|
||||
name: 'Claude Sonnet 4.5',
|
||||
provider: 'anthropic',
|
||||
group: 'claude'
|
||||
})
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns configured temperature when enabled', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(0.42)
|
||||
})
|
||||
|
||||
it('returns undefined when temperature is disabled', () => {
|
||||
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('clamps temperature to max 1.0 for Zhipu models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||
})
|
||||
|
||||
it('clamps temperature to max 1.0 for Anthropic models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
|
||||
const model = createModel({
|
||||
id: 'claude-sonnet-3.5',
|
||||
name: 'Claude 3.5 Sonnet',
|
||||
provider: 'anthropic',
|
||||
group: 'claude'
|
||||
})
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||
})
|
||||
|
||||
it('clamps temperature to max 1.0 for Moonshot models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||
const model = createModel({
|
||||
id: 'moonshot-v1-8k',
|
||||
name: 'Moonshot v1 8k',
|
||||
provider: 'moonshot',
|
||||
group: 'moonshot'
|
||||
})
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||
})
|
||||
|
||||
it('does not clamp temperature for OpenAI models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(2.0)
|
||||
})
|
||||
|
||||
it('does not clamp temperature when it is already within limits', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
|
||||
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(0.8)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getTopP', () => {
|
||||
it('returns undefined when reasoning effort is enabled for Claude models', () => {
|
||||
const assistant = createAssistant({ reasoning_effort: 'high' })
|
||||
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for models without TopP support', () => {
|
||||
const assistant = createAssistant({ enableTopP: true })
|
||||
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true })
|
||||
const model = createModel({
|
||||
id: 'claude-opus-4.5',
|
||||
name: 'Claude Opus 4.5',
|
||||
provider: 'anthropic',
|
||||
group: 'claude'
|
||||
})
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns configured TopP when enabled', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBe(0.73)
|
||||
})
|
||||
|
||||
it('returns undefined when TopP is disabled', () => {
|
||||
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('clamps topP to [0.95, 1] for Claude reasoning models with reasoning effort', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, topP: 0.5, reasoning_effort: 'high' })
|
||||
const model = createModel({ id: 'claude-sonnet-4.5', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBe(0.95)
|
||||
})
|
||||
|
||||
it('does not clamp topP when reasoning effort is default for Claude models', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, topP: 0.5, reasoning_effort: 'default' })
|
||||
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBe(0.5)
|
||||
})
|
||||
|
||||
it('does not clamp topP when reasoning effort is none for Claude models', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, topP: 0.5, reasoning_effort: 'none' })
|
||||
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBe(0.5)
|
||||
})
|
||||
|
||||
it('keeps topP unchanged when already in [0.95, 1] range for Claude reasoning models', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, topP: 0.97, reasoning_effort: 'medium' })
|
||||
const model = createModel({ id: 'claude-sonnet-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBe(0.97)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getTimeout', () => {
|
||||
it('uses an extended timeout for flex service tier models', () => {
|
||||
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTimeout(model)).toBe(15 * 1000 * 60)
|
||||
})
|
||||
|
||||
it('falls back to the default timeout otherwise', () => {
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTimeout(model)).toBe(DEFAULT_TIMEOUT)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getMaxTokens', () => {
|
||||
it('returns undefined when maxTokens is not enabled', () => {
|
||||
const assistant = createAssistant({ enableMaxTokens: false, maxTokens: 128000 })
|
||||
const model = createModel({ id: 'claude-opus-4-6', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getMaxTokens(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns user-configured maxTokens for Claude 4.6 without subtraction', () => {
|
||||
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 128000 })
|
||||
const model = createModel({ id: 'claude-opus-4-6', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getMaxTokens(assistant, model)).toBe(128000)
|
||||
})
|
||||
|
||||
it('returns user-configured maxTokens for Claude Sonnet 4.6 without subtraction', () => {
|
||||
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 64000 })
|
||||
const model = createModel({ id: 'claude-sonnet-4-6', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getMaxTokens(assistant, model)).toBe(64000)
|
||||
})
|
||||
|
||||
it('subtracts thinking budget for non-4.6 Claude models with anthropic provider', () => {
|
||||
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 16384 })
|
||||
const model = createModel({ id: 'claude-sonnet-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
const result = getMaxTokens(assistant, model)
|
||||
// Non-4.6 Claude thinking models should have budget subtracted
|
||||
expect(result).toBeDefined()
|
||||
expect(result!).toBeLessThan(16384)
|
||||
})
|
||||
|
||||
it('returns maxTokens as-is for non-Claude models', () => {
|
||||
const assistant = createAssistant({ enableMaxTokens: true, maxTokens: 4096 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getMaxTokens(assistant, model)).toBe(4096)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,225 +0,0 @@
|
||||
/**
|
||||
* Tests for parameterBuilder maxToolCalls functionality
|
||||
* These tests verify the maxToolCalls calculation and validation logic in isolation
|
||||
*/
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
// Mirror the constants from parameterBuilder.ts
|
||||
const MIN_TOOL_CALLS = 1
|
||||
const MAX_TOOL_CALLS = 100
|
||||
const DEFAULT_MAX_TOOL_CALLS = 20
|
||||
const DEFAULT_ENABLE_MAX_TOOL_CALLS = true
|
||||
|
||||
/**
|
||||
* Validates and clamps maxToolCalls to valid range
|
||||
* Mirrors the logic in parameterBuilder.ts
|
||||
*/
|
||||
function validateMaxToolCalls(value: number | undefined): number {
|
||||
if (value === undefined || value < MIN_TOOL_CALLS || value > MAX_TOOL_CALLS) {
|
||||
return DEFAULT_MAX_TOOL_CALLS
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the effective max tool calls based on assistant settings
|
||||
* Mirrors the logic in parameterBuilder.ts
|
||||
*/
|
||||
function calculateEffectiveMaxToolCalls(settings?: { maxToolCalls?: number; enableMaxToolCalls?: boolean }): {
|
||||
stopWhen: number | null
|
||||
maxToolCalls: number
|
||||
} {
|
||||
const enableMaxToolCalls = settings?.enableMaxToolCalls ?? DEFAULT_ENABLE_MAX_TOOL_CALLS
|
||||
|
||||
if (!enableMaxToolCalls) {
|
||||
// When disabled, don't pass stopWhen (return null to indicate no stopWhen)
|
||||
return { stopWhen: null, maxToolCalls: DEFAULT_MAX_TOOL_CALLS }
|
||||
}
|
||||
|
||||
// When enabled, validate and use user-defined value
|
||||
const maxToolCalls = validateMaxToolCalls(settings?.maxToolCalls)
|
||||
return { stopWhen: maxToolCalls, maxToolCalls }
|
||||
}
|
||||
|
||||
describe('validateMaxToolCalls', () => {
|
||||
it('returns valid values as-is', () => {
|
||||
expect(validateMaxToolCalls(1)).toBe(1)
|
||||
expect(validateMaxToolCalls(50)).toBe(50)
|
||||
expect(validateMaxToolCalls(100)).toBe(100)
|
||||
})
|
||||
|
||||
it('clamps values above 100 to default', () => {
|
||||
const result = validateMaxToolCalls(999)
|
||||
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('clamps zero to default', () => {
|
||||
const result = validateMaxToolCalls(0)
|
||||
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('clamps negative values to default', () => {
|
||||
const result = validateMaxToolCalls(-5)
|
||||
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('returns default when value is undefined', () => {
|
||||
const result = validateMaxToolCalls(undefined)
|
||||
expect(result).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
})
|
||||
|
||||
describe('maxToolCalls calculation logic', () => {
|
||||
describe('default behavior', () => {
|
||||
it('uses default value 20 when settings are undefined', () => {
|
||||
const result = calculateEffectiveMaxToolCalls(undefined)
|
||||
expect(result.maxToolCalls).toBe(20)
|
||||
expect(result.stopWhen).toBe(20)
|
||||
})
|
||||
|
||||
it('uses default value 20 when settings is empty object', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({})
|
||||
expect(result.maxToolCalls).toBe(20)
|
||||
expect(result.stopWhen).toBe(20)
|
||||
})
|
||||
|
||||
it('uses default value 20 when maxToolCalls is undefined', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true
|
||||
// maxToolCalls is undefined
|
||||
})
|
||||
expect(result.maxToolCalls).toBe(20)
|
||||
expect(result.stopWhen).toBe(20)
|
||||
})
|
||||
|
||||
it('uses custom value when maxToolCalls is set and enabled', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 50
|
||||
})
|
||||
expect(result.maxToolCalls).toBe(50)
|
||||
expect(result.stopWhen).toBe(50)
|
||||
})
|
||||
})
|
||||
|
||||
describe('custom values when enabled', () => {
|
||||
it('uses custom value when enableMaxToolCalls is true and maxToolCalls is set', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 50
|
||||
})
|
||||
expect(result.stopWhen).toBe(50)
|
||||
})
|
||||
|
||||
it('uses custom value at minimum boundary (1)', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 1
|
||||
})
|
||||
expect(result.stopWhen).toBe(1)
|
||||
})
|
||||
|
||||
it('uses custom value at maximum boundary (100)', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 100
|
||||
})
|
||||
expect(result.stopWhen).toBe(100)
|
||||
})
|
||||
|
||||
it('clamps values above 100 to default when enabled', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 999
|
||||
})
|
||||
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('clamps zero to default when enabled', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 0
|
||||
})
|
||||
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('clamps negative values to default when enabled', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: -5
|
||||
})
|
||||
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
})
|
||||
|
||||
describe('disabled behavior', () => {
|
||||
it('does not pass stopWhen when enableMaxToolCalls is false', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: false,
|
||||
maxToolCalls: 50
|
||||
})
|
||||
// When disabled, stopWhen should be null (indicating no stopWhen passed)
|
||||
expect(result.stopWhen).toBeNull()
|
||||
})
|
||||
|
||||
it('does not pass stopWhen when both enableMaxToolCalls is false and maxToolCalls is undefined', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: false
|
||||
})
|
||||
expect(result.stopWhen).toBeNull()
|
||||
})
|
||||
|
||||
it('falls back to default when disabled with invalid maxToolCalls', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: false,
|
||||
maxToolCalls: 999
|
||||
})
|
||||
// When disabled, maxToolCalls should still be default (for reference)
|
||||
expect(result.maxToolCalls).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
expect(result.stopWhen).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('backward compatibility', () => {
|
||||
it('maintains backward compatibility - existing assistants without new fields use default', () => {
|
||||
// Simulate an old assistant without the new fields
|
||||
const oldSettings = {
|
||||
// Old assistants don't have enableMaxToolCalls or maxToolCalls
|
||||
temperature: 0.7,
|
||||
contextCount: 10
|
||||
}
|
||||
const result = calculateEffectiveMaxToolCalls(
|
||||
oldSettings as { maxToolCalls?: number; enableMaxToolCalls?: boolean }
|
||||
)
|
||||
// Should default to enabled with 20 for backward compatibility
|
||||
expect(result.maxToolCalls).toBe(20)
|
||||
expect(result.stopWhen).toBe(20)
|
||||
})
|
||||
})
|
||||
|
||||
describe('security - invalid values from imported/migrated settings', () => {
|
||||
it('validates extremely large values from imported settings', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 999999
|
||||
})
|
||||
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('validates negative values from imported settings', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: -100
|
||||
})
|
||||
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
|
||||
it('validates zero from imported settings', () => {
|
||||
const result = calculateEffectiveMaxToolCalls({
|
||||
enableMaxToolCalls: true,
|
||||
maxToolCalls: 0
|
||||
})
|
||||
expect(result.stopWhen).toBe(DEFAULT_MAX_TOOL_CALLS)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,283 +0,0 @@
|
||||
/**
|
||||
* 文件处理模块
|
||||
* 处理文件内容提取、文件格式转换、文件上传等逻辑
|
||||
*/
|
||||
|
||||
import type OpenAI from '@cherrystudio/openai'
|
||||
import { loggerService } from '@logger'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { FileMetadata, Message, Model } from '@renderer/types'
|
||||
import { FILE_TYPE } from '@renderer/types'
|
||||
import type { FileMessageBlock } from '@renderer/types/newMessage'
|
||||
import { findFileBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import type { FilePart, TextPart } from 'ai'
|
||||
import i18n from 'i18next'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload } from './modelCapabilities'
|
||||
|
||||
const logger = loggerService.withContext('fileProcessor')
|
||||
|
||||
/**
|
||||
* 提取文件内容
|
||||
*/
|
||||
export async function extractFileContent(message: Message): Promise<string> {
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
if (fileBlocks.length > 0) {
|
||||
const textFileBlocks = fileBlocks.filter(
|
||||
(fb) => fb.file && [FILE_TYPE.TEXT, FILE_TYPE.DOCUMENT].some((type) => fb.file.type === type)
|
||||
)
|
||||
|
||||
if (textFileBlocks.length > 0) {
|
||||
let text = ''
|
||||
const divider = '\n\n---\n\n'
|
||||
|
||||
for (const fileBlock of textFileBlocks) {
|
||||
const file = fileBlock.file
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
|
||||
text = text + fileNameRow + fileContent + divider
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
/**
|
||||
* 将文件块转换为文本部分
|
||||
*/
|
||||
export async function convertFileBlockToTextPart(fileBlock: FileMessageBlock): Promise<TextPart | null> {
|
||||
const file = fileBlock.file
|
||||
|
||||
// 处理文本文件
|
||||
if (file.type === FILE_TYPE.TEXT) {
|
||||
try {
|
||||
const fileContent = await window.api.file.read(file.id + file.ext)
|
||||
return {
|
||||
type: 'text',
|
||||
text: `${file.origin_name}\n${fileContent.trim()}`
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to read text file:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理文档文件(PDF、Word、Excel等)- 提取为文本内容
|
||||
if (file.type === FILE_TYPE.DOCUMENT) {
|
||||
try {
|
||||
const fileContent = await window.api.file.read(file.id + file.ext, true) // true表示强制文本提取
|
||||
return {
|
||||
type: 'text',
|
||||
text: `${file.origin_name}\n${fileContent.trim()}`
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to extract text from document ${file.origin_name}:`, error as Error)
|
||||
window.toast.error(i18n.t('message.error.file.text_extraction_failed', { name: file.origin_name }))
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理Gemini大文件上传
|
||||
*/
|
||||
export async function handleGeminiFileUpload(file: FileMetadata, model: Model): Promise<FilePart | null> {
|
||||
try {
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
// 检查文件是否已经上传过
|
||||
const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
|
||||
|
||||
if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
|
||||
const remoteFile = fileMetadata.originalFile.file as any // 临时类型断言,因为File类型定义可能不完整
|
||||
// 注意:AI SDK的FilePart格式和Gemini原生格式不同,这里需要适配
|
||||
// 暂时返回null让它回退到文本处理,或者需要扩展FilePart支持uri
|
||||
logger.info(`File ${file.origin_name} already uploaded to Gemini with URI: ${remoteFile.uri || 'unknown'}`)
|
||||
return null
|
||||
}
|
||||
|
||||
// 如果文件未上传,执行上传
|
||||
const uploadResult = await window.api.fileService.upload(provider, file)
|
||||
if (uploadResult.originalFile?.file) {
|
||||
const remoteFile = uploadResult.originalFile.file as any // 临时类型断言
|
||||
logger.info(`File ${file.origin_name} uploaded to Gemini with URI: ${remoteFile.uri || 'unknown'}`)
|
||||
// 同样,这里需要处理URI格式的文件引用
|
||||
return null
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to upload file ${file.origin_name} to Gemini:`, error as Error)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理OpenAI兼容大文件上传
|
||||
*/
|
||||
export async function handleOpenAILargeFileUpload(
|
||||
file: FileMetadata,
|
||||
model: Model
|
||||
): Promise<(FilePart & { id?: string }) | null> {
|
||||
const provider = getProviderByModel(model)
|
||||
// 如果模型为qwen-long系列,文档中要求purpose需要为'file-extract'
|
||||
if (['qwen-long', 'qwen-doc'].some((modelName) => model.name.includes(modelName))) {
|
||||
file = {
|
||||
...file,
|
||||
// 该类型并不在OpenAI定义中,但符合sdk规范,强制断言
|
||||
purpose: 'file-extract' as OpenAI.FilePurpose
|
||||
}
|
||||
}
|
||||
try {
|
||||
// 检查文件是否已经上传过
|
||||
const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
|
||||
if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
|
||||
// 断言OpenAIFile对象
|
||||
const remoteFile = fileMetadata.originalFile.file as OpenAI.Files.FileObject
|
||||
// 判断用途是否一致
|
||||
if (remoteFile.purpose !== file.purpose) {
|
||||
logger.warn(`File ${file.origin_name} purpose mismatch: ${remoteFile.purpose} vs ${file.purpose}`)
|
||||
throw new Error('File purpose mismatch')
|
||||
}
|
||||
return {
|
||||
type: 'file',
|
||||
filename: file.origin_name,
|
||||
mediaType: '',
|
||||
data: `fileid://${remoteFile.id}`
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to retrieve file ${file.origin_name}:`, error as Error)
|
||||
return null
|
||||
}
|
||||
try {
|
||||
// 如果文件未上传,执行上传
|
||||
const uploadResult = await window.api.fileService.upload(provider, file)
|
||||
if (uploadResult.originalFile?.file) {
|
||||
// 断言OpenAIFile对象
|
||||
const remoteFile = uploadResult.originalFile.file as OpenAI.Files.FileObject
|
||||
logger.info(`File ${file.origin_name} uploaded.`)
|
||||
return {
|
||||
type: 'file',
|
||||
filename: remoteFile.filename,
|
||||
mediaType: '',
|
||||
data: `fileid://${remoteFile.id}`
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to upload file ${file.origin_name}:`, error as Error)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 大文件上传路由函数
|
||||
*/
|
||||
export async function handleLargeFileUpload(
|
||||
file: FileMetadata,
|
||||
model: Model
|
||||
): Promise<(FilePart & { id?: string }) | null> {
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
if (['google', 'google-vertex'].includes(aiSdkId)) {
|
||||
return await handleGeminiFileUpload(file, model)
|
||||
}
|
||||
|
||||
if (provider.type === 'openai') {
|
||||
return await handleOpenAILargeFileUpload(file, model)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 将文件块转换为FilePart(用于原生文件支持)
|
||||
*/
|
||||
export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model: Model): Promise<FilePart | null> {
|
||||
const file = fileBlock.file
|
||||
const fileSizeLimit = getFileSizeLimit(model, file.type)
|
||||
|
||||
try {
|
||||
// 处理PDF文档(始终生成 FilePart,由下游插件处理兼容性)
|
||||
if (file.type === FILE_TYPE.DOCUMENT && file.ext === '.pdf') {
|
||||
// 检查文件大小限制
|
||||
if (file.size > fileSizeLimit) {
|
||||
// 如果支持大文件上传(如Gemini File API),尝试上传
|
||||
if (supportsLargeFileUpload(model)) {
|
||||
logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
|
||||
const uploadResult = await handleLargeFileUpload(file, model)
|
||||
if (uploadResult) {
|
||||
return uploadResult
|
||||
}
|
||||
// 如果上传失败,回退到文本处理
|
||||
logger.warn(`Failed to upload large PDF ${file.origin_name}, falling back to text extraction`)
|
||||
window.toast.warning(i18n.t('message.warning.file.pdf_upload_failed', { name: file.origin_name }))
|
||||
return null
|
||||
} else {
|
||||
logger.warn(`PDF file ${file.origin_name} exceeds size limit (${file.size} > ${fileSizeLimit})`)
|
||||
window.toast.warning(
|
||||
i18n.t('message.warning.file.pdf_exceeds_limit', {
|
||||
name: file.origin_name,
|
||||
limit: `${Math.round(fileSizeLimit / 1024 / 1024)}MB`
|
||||
})
|
||||
)
|
||||
return null // 文件过大,回退到文本处理
|
||||
}
|
||||
}
|
||||
|
||||
const base64Data = await window.api.file.base64File(file.id + file.ext)
|
||||
|
||||
return {
|
||||
type: 'file',
|
||||
data: base64Data.data,
|
||||
mediaType: base64Data.mime,
|
||||
filename: file.origin_name
|
||||
}
|
||||
}
|
||||
|
||||
// 处理图片文件
|
||||
if (file.type === FILE_TYPE.IMAGE && supportsImageInput(model)) {
|
||||
// 检查文件大小
|
||||
if (file.size > fileSizeLimit) {
|
||||
logger.warn(`Image file ${file.origin_name} exceeds size limit (${file.size} > ${fileSizeLimit})`)
|
||||
return null
|
||||
}
|
||||
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
|
||||
// 处理MIME类型,特别是jpg->jpeg的转换(Anthropic要求)
|
||||
let mediaType = base64Data.mime
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
if (aiSdkId === 'anthropic' && mediaType === 'image/jpg') {
|
||||
mediaType = 'image/jpeg'
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'file',
|
||||
data: base64Data.base64,
|
||||
mediaType: mediaType,
|
||||
filename: file.origin_name
|
||||
}
|
||||
}
|
||||
|
||||
// 处理其他文档类型(Word、Excel等)
|
||||
if (file.type === FILE_TYPE.DOCUMENT && file.ext !== '.pdf') {
|
||||
// 目前大多数提供商不支持Word等格式的原生处理
|
||||
// 返回null会触发上层调用convertFileBlockToTextPart进行文本提取
|
||||
// 这与Legacy架构中的处理方式一致
|
||||
logger.debug(`Document file ${file.origin_name} with extension ${file.ext} will use text extraction fallback`)
|
||||
return null
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to process file ${file.origin_name}:`, error as Error)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, Model } from '@renderer/types'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
|
||||
|
||||
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
||||
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
|
||||
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
|
||||
// const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
|
||||
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
|
||||
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
|
||||
|
||||
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
|
||||
const anthropicHeaders: string[] = []
|
||||
const provider = getProviderByModel(model)
|
||||
if (
|
||||
isClaude45ReasoningModel(model) &&
|
||||
isToolUseModeFunction(assistant) &&
|
||||
!(isVertexProvider(provider) || isAwsBedrockProvider(provider))
|
||||
) {
|
||||
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
||||
}
|
||||
if (isClaude4SeriesModel(model)) {
|
||||
if (isVertexProvider(provider) && assistant.enableWebSearch) {
|
||||
anthropicHeaders.push(WEBSEARCH_HEADER)
|
||||
}
|
||||
// We may add it by user preference in assistant.settings instead of always adding it.
|
||||
// See #11540, #11397
|
||||
// anthropicHeaders.push(CONTEXT_100M_HEADER)
|
||||
}
|
||||
return anthropicHeaders
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
/**
|
||||
* AI SDK 参数转换模块 - 统一入口
|
||||
*
|
||||
* 此模块已重构,功能分拆到以下子模块:
|
||||
* - modelParameters.ts: 基础参数处理 (温度、TopP、超时)
|
||||
* - modelCapabilities.ts: 模型能力检查 (PDF、图片、文件支持)
|
||||
* - fileProcessor.ts: 文件处理逻辑 (转换、上传)
|
||||
* - messageConverter.ts: 消息转换核心 (单个消息转换)
|
||||
* - parameterBuilder.ts: 参数构建器 (最终参数组装)
|
||||
*/
|
||||
|
||||
// 基础参数处理
|
||||
export { getTimeout } from './modelParameters'
|
||||
|
||||
// 文件处理
|
||||
export { extractFileContent } from './fileProcessor'
|
||||
|
||||
// 消息转换
|
||||
export { convertMessagesToSdkMessages, convertMessageToSdkParam } from './messageConverter'
|
||||
|
||||
// 参数构建 (主要API)
|
||||
export { buildGenerateTextParams, buildStreamTextParams } from './parameterBuilder'
|
||||
@@ -1,387 +0,0 @@
|
||||
/**
|
||||
* 消息转换模块
|
||||
* 将 Cherry Studio 消息格式转换为 AI SDK 消息格式
|
||||
*/
|
||||
|
||||
import type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
import { loggerService } from '@logger'
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import type { Message, Model } from '@renderer/types'
|
||||
import type {
|
||||
FileMessageBlock,
|
||||
ImageMessageBlock,
|
||||
MainTextMessageBlock,
|
||||
ThinkingMessageBlock
|
||||
} from '@renderer/types/newMessage'
|
||||
import {
|
||||
findFileBlocks,
|
||||
findImageBlocks,
|
||||
findMainTextBlocks,
|
||||
findThinkingBlocks,
|
||||
getMainTextContent
|
||||
} from '@renderer/utils/messageUtils/find'
|
||||
import { parseDataUrl } from '@shared/utils'
|
||||
import type {
|
||||
AssistantModelMessage,
|
||||
FilePart,
|
||||
ImagePart,
|
||||
ModelMessage,
|
||||
SystemModelMessage,
|
||||
TextPart,
|
||||
UserModelMessage
|
||||
} from 'ai'
|
||||
import i18n from 'i18next'
|
||||
|
||||
import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor'
|
||||
|
||||
const logger = loggerService.withContext('messageConverter')
|
||||
|
||||
/**
|
||||
* 转换消息为 AI SDK 参数格式
|
||||
* 基于 OpenAI 格式的通用转换,支持文本、图片和文件
|
||||
*/
|
||||
export async function convertMessageToSdkParam(
|
||||
message: Message,
|
||||
isVisionModel = false,
|
||||
model?: Model
|
||||
): Promise<ModelMessage | ModelMessage[]> {
|
||||
const content = getMainTextContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
const reasoningBlocks = findThinkingBlocks(message)
|
||||
const mainTextBlocks = findMainTextBlocks(message)
|
||||
if (message.role === 'user' || message.role === 'system') {
|
||||
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model)
|
||||
} else {
|
||||
return convertMessageToAssistantModelMessage(
|
||||
content,
|
||||
fileBlocks,
|
||||
imageBlocks,
|
||||
reasoningBlocks,
|
||||
mainTextBlocks,
|
||||
model
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async function convertImageBlockToImagePart(imageBlocks: ImageMessageBlock[]): Promise<Array<ImagePart>> {
|
||||
const parts: Array<ImagePart> = []
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
try {
|
||||
const ext = imageBlock.file.ext.startsWith('.') ? imageBlock.file.ext : `.${imageBlock.file.ext}`
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + ext)
|
||||
parts.push({
|
||||
type: 'image',
|
||||
image: image.base64,
|
||||
mediaType: image.mime
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to load image file, image will be excluded from message:', {
|
||||
fileId: imageBlock.file.id,
|
||||
fileName: imageBlock.file.origin_name,
|
||||
error: error as Error
|
||||
})
|
||||
}
|
||||
} else if (imageBlock.url) {
|
||||
const url = imageBlock.url
|
||||
const parseResult = parseDataUrl(url)
|
||||
if (parseResult?.isBase64) {
|
||||
const { mediaType, data } = parseResult
|
||||
parts.push({ type: 'image', image: data, ...(mediaType ? { mediaType } : {}) })
|
||||
} else if (url.startsWith('data:')) {
|
||||
// Malformed data URL or non-base64 data URL
|
||||
logger.error('Malformed or non-base64 data URL detected, image will be excluded:', {
|
||||
urlPrefix: url.slice(0, 50) + '...'
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
// For remote URLs we keep payload minimal to match existing expectations.
|
||||
parts.push({ type: 'image', image: url })
|
||||
}
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换为用户模型消息
|
||||
*/
|
||||
async function convertMessageToUserModelMessage(
|
||||
content: string,
|
||||
fileBlocks: FileMessageBlock[],
|
||||
imageBlocks: ImageMessageBlock[],
|
||||
isVisionModel = false,
|
||||
model?: Model
|
||||
): Promise<UserModelMessage | (UserModelMessage | SystemModelMessage)[]> {
|
||||
const parts: Array<TextPart | FilePart | ImagePart> = []
|
||||
if (content) {
|
||||
parts.push({ type: 'text', text: content })
|
||||
}
|
||||
|
||||
// 处理图片(仅在支持视觉的模型中)
|
||||
if (isVisionModel) {
|
||||
parts.push(...(await convertImageBlockToImagePart(imageBlocks)))
|
||||
}
|
||||
// 处理文件
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
let processed = false
|
||||
|
||||
// 优先尝试原生文件支持(PDF、图片等)
|
||||
if (model) {
|
||||
const filePart = await convertFileBlockToFilePart(fileBlock, model)
|
||||
if (filePart) {
|
||||
// 判断filePart是否为string
|
||||
if (typeof filePart.data === 'string' && filePart.data.startsWith('fileid://')) {
|
||||
return [
|
||||
{
|
||||
role: 'system',
|
||||
content: filePart.data
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: parts.length > 0 ? parts : ''
|
||||
}
|
||||
]
|
||||
}
|
||||
parts.push(filePart)
|
||||
logger.debug(`File ${file.origin_name} processed as native file format`)
|
||||
processed = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果原生处理失败,回退到文本提取
|
||||
if (!processed) {
|
||||
const textPart = await convertFileBlockToTextPart(fileBlock)
|
||||
if (textPart) {
|
||||
parts.push(textPart)
|
||||
logger.debug(`File ${file.origin_name} processed as text content`)
|
||||
} else {
|
||||
logger.warn(`File ${file.origin_name} could not be processed in any format`)
|
||||
window.toast.error(i18n.t('message.error.file.process_failed', { name: file.origin_name }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: 'user',
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces markdown images with data URI sources (e.g. ``)
|
||||
* with a placeholder `` to avoid sending huge base64 payloads to the API.
|
||||
*
|
||||
* Uses string scanning (indexOf) instead of regex to avoid OOM on multi-MB base64 strings.
|
||||
*/
|
||||
export function stripMarkdownBase64Images(text: string): string {
|
||||
const marker = '](data:'
|
||||
let result = ''
|
||||
let searchFrom = 0
|
||||
|
||||
while (searchFrom < text.length) {
|
||||
const markerIdx = text.indexOf(marker, searchFrom)
|
||||
if (markerIdx === -1) {
|
||||
result += text.slice(searchFrom)
|
||||
break
|
||||
}
|
||||
|
||||
// Find the `
|
||||
if (bangIdx === -1 || text.indexOf(']', bangIdx + 2) !== markerIdx) {
|
||||
// Not a valid markdown image — skip past this marker
|
||||
result += text.slice(searchFrom, markerIdx + marker.length)
|
||||
searchFrom = markerIdx + marker.length
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the closing `)` — the URL part starts after `](`
|
||||
const urlStart = markerIdx + 2 // position right after `](`
|
||||
const closeIdx = text.indexOf(')', urlStart)
|
||||
if (closeIdx === -1) {
|
||||
result += text.slice(searchFrom)
|
||||
break
|
||||
}
|
||||
|
||||
// Extract alt text between `![` and `]`
|
||||
const altText = text.slice(bangIdx + 2, markerIdx)
|
||||
|
||||
// Append everything before ``
|
||||
searchFrom = closeIdx + 1
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换为助手模型消息
|
||||
* 注意:当助手消息只包含图片(如图片生成模型的响应)而没有文本时,
|
||||
* 需要添加占位文本,因为某些 API(如 Gemini)不接受空的 assistant 消息
|
||||
*/
|
||||
async function convertMessageToAssistantModelMessage(
|
||||
content: string,
|
||||
fileBlocks: FileMessageBlock[],
|
||||
imageBlocks: ImageMessageBlock[],
|
||||
thinkingBlocks: ThinkingMessageBlock[],
|
||||
mainTextBlocks: MainTextMessageBlock[],
|
||||
model?: Model
|
||||
): Promise<AssistantModelMessage> {
|
||||
const parts: Array<TextPart | ReasoningPart | FilePart> = []
|
||||
|
||||
// Add reasoning blocks first (required by AWS Bedrock for Claude extended thinking)
|
||||
for (const thinkingBlock of thinkingBlocks) {
|
||||
parts.push({ type: 'reasoning', text: thinkingBlock.content })
|
||||
}
|
||||
|
||||
// Add text content after reasoning blocks, only if non-empty after trimming
|
||||
// Also add thoughtSignature from MainTextBlock metadata for Gemini thought signature persistence
|
||||
// Strip inline base64 data URIs from markdown images to prevent HTTP 413 errors (#12602)
|
||||
// Uses string scanning instead of regex to avoid OOM on large base64 payloads
|
||||
const trimmedContent = stripMarkdownBase64Images(content?.trim() ?? '')
|
||||
if (trimmedContent) {
|
||||
// Find the first MainTextBlock with thoughtSignature
|
||||
const thoughtSignature = mainTextBlocks.find((block) => block.metadata?.thoughtSignature)?.metadata
|
||||
?.thoughtSignature
|
||||
|
||||
const textPart: TextPart = { type: 'text', text: trimmedContent }
|
||||
|
||||
// Add providerOptions with thoughtSignature if available (for Gemini)
|
||||
if (thoughtSignature) {
|
||||
textPart.providerOptions = {
|
||||
google: {
|
||||
thoughtSignature
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parts.push(textPart)
|
||||
}
|
||||
|
||||
for (const fileBlock of fileBlocks) {
|
||||
// 优先尝试原生文件支持(PDF等)
|
||||
if (model) {
|
||||
const filePart = await convertFileBlockToFilePart(fileBlock, model)
|
||||
if (filePart) {
|
||||
parts.push(filePart)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到文本处理
|
||||
const textPart = await convertFileBlockToTextPart(fileBlock)
|
||||
if (textPart) {
|
||||
parts.push(textPart)
|
||||
}
|
||||
}
|
||||
|
||||
// 当 parts 为空但有图片时,添加占位文本
|
||||
// 这对于图片生成模型的继续对话很重要,因为助手消息可能只包含生成的图片
|
||||
if (parts.length === 0 && imageBlocks.length > 0) {
|
||||
parts.push({ type: 'text', text: '[Image]' })
|
||||
}
|
||||
|
||||
return {
|
||||
role: 'assistant',
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an array of messages to SDK-compatible model messages.
|
||||
*
|
||||
* This function processes messages and transforms them into the format required by the SDK.
|
||||
* It handles special cases for vision models and image enhancement models.
|
||||
*
|
||||
* @param messages - Array of messages to convert.
|
||||
* @param model - The model configuration that determines conversion behavior
|
||||
*
|
||||
* @returns A promise that resolves to an array of SDK-compatible model messages
|
||||
*
|
||||
* @remarks
|
||||
* For image enhancement models:
|
||||
* - Collapses the conversation into [system?, user(image)] format
|
||||
* - Searches backwards through all messages to find the most recent assistant message with images
|
||||
* - Preserves all system messages (including ones generated from file uploads like 'fileid://...')
|
||||
* - Extracts the last user message content and merges images from the previous assistant message
|
||||
* - Returns only the collapsed messages: system messages (if any) followed by a single user message
|
||||
* - If no user message is found, returns only system messages
|
||||
* - Typical pattern: [system?, user, assistant(image), user] -> [system?, user(image)]
|
||||
*
|
||||
* For other models:
|
||||
* - Returns all converted messages in order without special image handling
|
||||
*
|
||||
* The function automatically detects vision model capabilities and adjusts conversion accordingly.
|
||||
*/
|
||||
export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise<ModelMessage[]> {
|
||||
const sdkMessages: ModelMessage[] = []
|
||||
const isVision = isVisionModel(model)
|
||||
|
||||
for (const message of messages) {
|
||||
const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
|
||||
sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
|
||||
}
|
||||
// Special handling for vison models
|
||||
// These models support multi-turn conversations but need images from previous assistant messages
|
||||
// to be merged into the current user message for editing/enhancement operations.
|
||||
//
|
||||
// Key behaviors:
|
||||
// 1. Preserve all conversation history for context
|
||||
// 2. Find images from the previous assistant message and merge them into the last user message
|
||||
// 3. This allows users to switch from LLM conversations and use that context for image generation
|
||||
if (isVision) {
|
||||
// Find the last user SDK message index
|
||||
const lastUserSdkIndex = (() => {
|
||||
for (let i = sdkMessages.length - 1; i >= 0; i--) {
|
||||
if (sdkMessages[i].role === 'user') return i
|
||||
}
|
||||
return -1
|
||||
})()
|
||||
|
||||
// If no user message found, return messages as-is
|
||||
if (lastUserSdkIndex < 0) {
|
||||
return sdkMessages
|
||||
}
|
||||
|
||||
// Find the nearest preceding assistant message in original messages
|
||||
let prevAssistant: Message | null = null
|
||||
for (let i = messages.length - 2; i >= 0; i--) {
|
||||
if (messages[i].role === 'assistant') {
|
||||
prevAssistant = messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there are images from the previous assistant message
|
||||
const imageBlocks = prevAssistant ? findImageBlocks(prevAssistant) : []
|
||||
const imageParts = await convertImageBlockToImagePart(imageBlocks)
|
||||
|
||||
// If no images to merge, return messages as-is
|
||||
if (imageParts.length === 0) {
|
||||
return sdkMessages
|
||||
}
|
||||
|
||||
// Build the new last user message with merged images
|
||||
const lastUserSdk = sdkMessages[lastUserSdkIndex] as UserModelMessage
|
||||
let finalUserParts: Array<TextPart | FilePart | ImagePart> = []
|
||||
|
||||
if (typeof lastUserSdk.content === 'string') {
|
||||
finalUserParts.push({ type: 'text', text: lastUserSdk.content })
|
||||
} else if (Array.isArray(lastUserSdk.content)) {
|
||||
finalUserParts = [...lastUserSdk.content]
|
||||
}
|
||||
|
||||
// Append images from the previous assistant message
|
||||
finalUserParts.push(...imageParts)
|
||||
|
||||
// Replace the last user message with the merged version
|
||||
const result = [...sdkMessages]
|
||||
result[lastUserSdkIndex] = { role: 'user', content: finalUserParts }
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
return sdkMessages
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
/**
|
||||
* 模型能力检查模块
|
||||
* 检查不同模型支持的功能(PDF输入、图片输入、大文件上传等)
|
||||
*/
|
||||
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { FileType, Model } from '@renderer/types'
|
||||
import { FILE_TYPE } from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
|
||||
// 工具函数:基于模型名和提供商判断是否支持某特性
|
||||
function modelSupportValidator(
|
||||
model: Model,
|
||||
{
|
||||
supportedModels = [],
|
||||
unsupportedModels = [],
|
||||
supportedProviders = [],
|
||||
unsupportedProviders = []
|
||||
}: {
|
||||
supportedModels?: string[]
|
||||
unsupportedModels?: string[]
|
||||
supportedProviders?: string[]
|
||||
unsupportedProviders?: string[]
|
||||
}
|
||||
): boolean {
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
// 黑名单:命中不支持的模型直接拒绝
|
||||
if (unsupportedModels.some((name) => model.name.includes(name))) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 黑名单:命中不支持的提供商直接拒绝,常用于某些提供商的同名模型并不具备原模型的某些特性
|
||||
if (unsupportedProviders.includes(aiSdkId)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 白名单:命中支持的模型名
|
||||
if (supportedModels.some((name) => model.name.includes(name))) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 回退到提供商判断
|
||||
return supportedProviders.includes(aiSdkId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否支持原生图片输入
|
||||
*/
|
||||
export function supportsImageInput(model: Model): boolean {
|
||||
return isVisionModel(model)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查提供商是否支持大文件上传(如Gemini File API)
|
||||
*/
|
||||
export function supportsLargeFileUpload(model: Model): boolean {
|
||||
// 基于AI SDK文档,以下模型或提供商支持大文件上传
|
||||
return modelSupportValidator(model, {
|
||||
supportedModels: ['qwen-long', 'qwen-doc'],
|
||||
supportedProviders: ['google', 'google-generative-ai', 'google-vertex']
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取提供商特定的文件大小限制
|
||||
*/
|
||||
export function getFileSizeLimit(model: Model, fileType: FileType): number {
|
||||
const provider = getProviderByModel(model)
|
||||
const aiSdkId = getAiSdkProviderId(provider)
|
||||
|
||||
// Anthropic PDF限制32MB
|
||||
if (aiSdkId === 'anthropic' && fileType === FILE_TYPE.DOCUMENT) {
|
||||
return 32 * 1024 * 1024 // 32MB
|
||||
}
|
||||
|
||||
// Gemini小文件限制20MB(超过此限制会使用File API上传)
|
||||
if (['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)) {
|
||||
return 20 * 1024 * 1024 // 20MB
|
||||
}
|
||||
|
||||
// Dashscope如果模型支持大文件上传优先使用File API上传
|
||||
if (aiSdkId === 'dashscope' && supportsLargeFileUpload(model)) {
|
||||
return 0 // 使用较小的默认值
|
||||
}
|
||||
|
||||
// 其他提供商没有明确限制,使用较大的默认值
|
||||
// 这与Legacy架构中的实现一致,让提供商自行处理文件大小
|
||||
return Infinity
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
/**
|
||||
* 模型基础参数处理模块
|
||||
* 处理温度、TopP、超时等基础参数的获取逻辑
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
isClaude46SeriesModel,
|
||||
isClaudeReasoningModel,
|
||||
isMaxTemperatureOneModel,
|
||||
isSupportedFlexServiceTier,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportTemperatureModel,
|
||||
isSupportTopPModel,
|
||||
isTemperatureTopPMutuallyExclusiveModel
|
||||
} from '@renderer/config/models'
|
||||
import {
|
||||
DEFAULT_ASSISTANT_SETTINGS,
|
||||
getAssistantSettings,
|
||||
getProviderByModel
|
||||
} from '@renderer/services/AssistantService'
|
||||
import { type Assistant, type Model } from '@renderer/types'
|
||||
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
|
||||
|
||||
import { getThinkingBudget } from '../utils/reasoning'
|
||||
|
||||
const logger = loggerService.withContext('modelParameters')
|
||||
|
||||
/**
|
||||
* Retrieves the temperature parameter, adapting it based on assistant.settings and model capabilities.
|
||||
* - Disabled when enableTemperature is off.
|
||||
* - Disabled for Claude reasoning models when reasoning effort is set (excluding 'default' and 'none').
|
||||
* - Disabled for models that do not support temperature.
|
||||
* - Clamped to 1 for models with max temperature of 1.
|
||||
* Otherwise, returns the temperature value.
|
||||
*/
|
||||
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
const enableTemperature = assistant.settings?.enableTemperature ?? DEFAULT_ASSISTANT_SETTINGS.enableTemperature
|
||||
if (!enableTemperature) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Thinking isn't compatible with temperature or top_k modifications as well as forced tool use.
|
||||
// See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking#feature-compatibility
|
||||
if (
|
||||
isClaudeReasoningModel(model) &&
|
||||
assistant.settings?.reasoning_effort &&
|
||||
assistant.settings.reasoning_effort !== 'default' &&
|
||||
assistant.settings.reasoning_effort !== 'none'
|
||||
) {
|
||||
logger.info(`Model ${model.id} does not support reasoning with temperature, disabling temperature`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (!isSupportTemperatureModel(model, assistant)) {
|
||||
logger.info(`Model ${model.id} does not support temperature, disabling temperature`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
let temperature = assistant.settings?.temperature ?? DEFAULT_ASSISTANT_SETTINGS.temperature
|
||||
|
||||
if (isMaxTemperatureOneModel(model) && temperature > 1) {
|
||||
logger.info(`Model ${model.id} has max temperature of 1, clamping temperature from ${temperature} to 1`)
|
||||
temperature = 1
|
||||
}
|
||||
|
||||
if (isTemperatureTopPMutuallyExclusiveModel(model) && assistant.settings?.enableTopP) {
|
||||
logger.info(`Model ${model.id} only accepts one of temperature and topP, both enabled; keeping temperature`)
|
||||
}
|
||||
|
||||
return temperature
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the TopP parameter, adapting it based on assistant.settings and model capabilities.
|
||||
* - Disabled when enableTopP is off.
|
||||
* - Disabled for models that do not support TopP.
|
||||
* - Disabled for mutually exclusive models when temperature is enabled.
|
||||
* - Clamped to [0.95, 1] for Claude reasoning models with reasoning effort set (excluding 'default' and 'none').
|
||||
* Otherwise, returns the TopP value.
|
||||
*/
|
||||
export function getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
const enableTopP = assistant.settings?.enableTopP ?? DEFAULT_ASSISTANT_SETTINGS.enableTopP
|
||||
if (!enableTopP) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (!isSupportTopPModel(model, assistant)) {
|
||||
logger.info(`Model ${model.id} does not support topP, disabling topP.`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (isTemperatureTopPMutuallyExclusiveModel(model) && assistant.settings?.enableTemperature) {
|
||||
logger.info(`Model ${model.id} only accepts one of temperature and topP, disabling topP.`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
let topP = assistant.settings?.topP ?? DEFAULT_ASSISTANT_SETTINGS.topP
|
||||
|
||||
// When thinking is enabled, the topP should be between 0.95 and 1
|
||||
// See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking#feature-compatibility
|
||||
// NOTE: It depends on the behavior that extended thinking defaults to off, so we clamp the topP value also when reasoning is not 'default'
|
||||
if (
|
||||
isClaudeReasoningModel(model) &&
|
||||
assistant.settings?.reasoning_effort &&
|
||||
assistant.settings.reasoning_effort !== 'default' &&
|
||||
assistant.settings.reasoning_effort !== 'none'
|
||||
) {
|
||||
const clampedTopP = Math.max(0.95, Math.min(topP, 1))
|
||||
if (clampedTopP !== topP) {
|
||||
logger.info(`Claude Model ${model.id} has reasoning enabled, clamping topP from ${topP} to ${clampedTopP}`)
|
||||
}
|
||||
topP = clampedTopP
|
||||
}
|
||||
|
||||
return topP
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取超时设置
|
||||
*/
|
||||
export function getTimeout(model: Model): number {
|
||||
if (isSupportedFlexServiceTier(model)) {
|
||||
return 15 * 1000 * 60
|
||||
}
|
||||
return DEFAULT_TIMEOUT
|
||||
}
|
||||
|
||||
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
|
||||
// NOTE: ai-sdk会把maxToken和budgetToken加起来
|
||||
const assistantSettings = getAssistantSettings(assistant)
|
||||
const enabledMaxTokens = assistantSettings.enableMaxTokens ?? false
|
||||
let maxTokens = assistantSettings.maxTokens
|
||||
|
||||
// If user hasn't enabled enableMaxTokens, return undefined to let the API use its default value.
|
||||
// Note: Anthropic API requires max_tokens, but that's handled by the Anthropic client with a fallback.
|
||||
if (!enabledMaxTokens || maxTokens === undefined) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
// Claude 4.6 uses adaptive thinking (no budgetTokens), so the AI SDK does not add budget back
|
||||
// to maxOutputTokens. Skip the subtraction to avoid incorrectly reducing max_tokens.
|
||||
if (
|
||||
isSupportedThinkingTokenClaudeModel(model) &&
|
||||
!isClaude46SeriesModel(model) &&
|
||||
['anthropic', 'aws-bedrock'].includes(provider.type)
|
||||
) {
|
||||
const { reasoning_effort: reasoningEffort } = assistantSettings
|
||||
const budget = getThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
if (budget) {
|
||||
maxTokens -= budget
|
||||
}
|
||||
}
|
||||
return maxTokens
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
/**
|
||||
* 参数构建模块
|
||||
* 构建AI SDK的流式和非流式参数
|
||||
*/
|
||||
|
||||
import { combineHeaders } from '@ai-sdk/provider-utils'
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { extensionRegistry } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import type { AppProviderId } from '@renderer/aiCore/types'
|
||||
import { MAX_TOOL_CALLS, MIN_TOOL_CALLS } from '@renderer/config/constant'
|
||||
import {
|
||||
isAnthropicModel,
|
||||
isFixedReasoningModel,
|
||||
isGeminiModel,
|
||||
isGenerateImageModel,
|
||||
isGrokModel,
|
||||
isOpenAIModel,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isPureGenerateImageModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getHubModeSystemPrompt } from '@renderer/config/prompts-code-mode'
|
||||
import { DEFAULT_ASSISTANT_SETTINGS, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { type Assistant, getEffectiveMcpMode, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { IdleTimeoutController, type IdleTimeoutHandle } from '@renderer/utils/IdleTimeoutController'
|
||||
import { replacePromptVariables } from '@renderer/utils/prompt'
|
||||
import { isAIGatewayProvider, isAwsBedrockProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
|
||||
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
|
||||
import type { ModelMessage } from 'ai'
|
||||
import { stepCountIs } from 'ai'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import type { ProviderCapabilities } from '../types'
|
||||
import { setupToolsConfig } from '../utils/mcp'
|
||||
import { buildProviderOptions } from '../utils/options'
|
||||
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
|
||||
import { addAnthropicHeaders } from './header'
|
||||
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
|
||||
|
||||
const logger = loggerService.withContext('parameterBuilder')
|
||||
|
||||
/**
|
||||
* Validates and clamps maxToolCalls to valid range
|
||||
* Falls back to DEFAULT_ASSISTANT_SETTINGS.maxToolCalls if invalid
|
||||
* @param value - The maxToolCalls value from settings
|
||||
* @returns Validated maxToolCalls value
|
||||
*/
|
||||
function validateMaxToolCalls(value: number | undefined): number {
|
||||
if (value === undefined || value < MIN_TOOL_CALLS || value > MAX_TOOL_CALLS) {
|
||||
return DEFAULT_ASSISTANT_SETTINGS.maxToolCalls
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
function mapVertexAIGatewayModelToProviderId(model: Model): AppProviderId | undefined {
|
||||
if (isAnthropicModel(model)) {
|
||||
return 'anthropic'
|
||||
}
|
||||
if (isGeminiModel(model)) {
|
||||
return 'google'
|
||||
}
|
||||
if (isGrokModel(model)) {
|
||||
return 'xai'
|
||||
}
|
||||
if (isOpenAIModel(model)) {
|
||||
return 'openai'
|
||||
}
|
||||
logger.warn(`Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 流式参数
|
||||
* 这是主要的参数构建函数,整合所有转换逻辑
|
||||
*/
|
||||
export async function buildStreamTextParams(
|
||||
sdkMessages: StreamTextParams['messages'] = [],
|
||||
assistant: Assistant,
|
||||
provider: Provider,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
allowedTools?: string[]
|
||||
webSearchProviderId?: string
|
||||
webSearchConfig?: CherryWebSearchConfig
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
timeout?: number
|
||||
headers?: Record<string, string | undefined>
|
||||
}
|
||||
}
|
||||
): Promise<{
|
||||
params: StreamTextParams
|
||||
modelId: string
|
||||
capabilities: ProviderCapabilities
|
||||
webSearchPluginConfig?: WebSearchPluginConfig
|
||||
idleTimeout: IdleTimeoutHandle
|
||||
}> {
|
||||
const { mcpTools, requestOptions = {} } = options
|
||||
// No caller currently provides a custom timeout; defaultTimeout (10 min) is the fallback.
|
||||
const { signal: externalSignal, timeout = DEFAULT_TIMEOUT, headers: inputHeaders = {} } = requestOptions
|
||||
|
||||
// Use an idle timeout that resets every time a stream chunk is received,
|
||||
// instead of a fixed total timeout that starts from the initial request.
|
||||
const idleTimeout = new IdleTimeoutController(timeout)
|
||||
const signals = [idleTimeout.signal]
|
||||
if (externalSignal) {
|
||||
signals.push(externalSignal)
|
||||
}
|
||||
const finalSignal = AbortSignal.any(signals)
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
// 这三个变量透传出来,交给下面启用插件/中间件
|
||||
// 也可以在外部构建好再传入buildStreamTextParams
|
||||
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
|
||||
const enableReasoning =
|
||||
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
|
||||
assistant.settings?.reasoning_effort !== undefined) ||
|
||||
isFixedReasoningModel(model)
|
||||
|
||||
// 判断是否使用内置搜索
|
||||
// 条件:没有外部搜索提供商 && (用户开启了内置搜索 || 模型强制使用内置搜索)
|
||||
const hasExternalSearch = !!options.webSearchProviderId
|
||||
const enableWebSearch =
|
||||
!hasExternalSearch &&
|
||||
((assistant.enableWebSearch && isWebSearchModel(model)) ||
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar'))
|
||||
|
||||
// Validate provider and model support to prevent stale state from triggering urlContext
|
||||
const enableUrlContext = !!(
|
||||
assistant.enableUrlContext &&
|
||||
isSupportUrlContextProvider(provider) &&
|
||||
!isPureGenerateImageModel(model) &&
|
||||
(isGeminiModel(model) || isAnthropicModel(model))
|
||||
)
|
||||
|
||||
const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)
|
||||
|
||||
const tools = setupToolsConfig(mcpTools, options.allowedTools)
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const webSearchConfig: CherryWebSearchConfig = {
|
||||
maxResults: store.getState().websearch.maxResults,
|
||||
excludeDomains: store.getState().websearch.excludeDomains,
|
||||
searchWithTime: store.getState().websearch.searchWithTime
|
||||
}
|
||||
|
||||
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
})
|
||||
|
||||
// Web search + URL context 的工具注入由 plugin 系统处理:
|
||||
// - webSearchPlugin: 根据 provider 的 toolFactories.webSearch 自动注入
|
||||
// - urlContextPlugin: 根据 provider 的 toolFactories.urlContext 自动注入
|
||||
// parameterBuilder 只构建 config,传给 plugin
|
||||
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
|
||||
if (enableWebSearch) {
|
||||
if (extensionRegistry.has(aiSdkProviderId)) {
|
||||
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
|
||||
} else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) {
|
||||
const gatewayProviderId = mapVertexAIGatewayModelToProviderId(model)
|
||||
if (gatewayProviderId) {
|
||||
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(gatewayProviderId, webSearchConfig, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let headers = inputHeaders
|
||||
|
||||
if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
|
||||
const betaHeaders = addAnthropicHeaders(assistant, model)
|
||||
// Only add the anthropic-beta header if there are actual beta headers to include
|
||||
if (betaHeaders.length > 0) {
|
||||
const newBetaHeaders = { 'anthropic-beta': betaHeaders.join(',') }
|
||||
headers = combineHeaders(headers, newBetaHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建基础参数
|
||||
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
|
||||
// are extracted from custom parameters and passed directly to streamText()
|
||||
// instead of being placed in providerOptions
|
||||
|
||||
// Get max tool calls from assistant settings
|
||||
// When enabled, validate and use user-defined value (1-100)
|
||||
// When disabled, don't pass stopWhen - let AI SDK use its own default
|
||||
const enableMaxToolCalls = assistant.settings?.enableMaxToolCalls ?? DEFAULT_ASSISTANT_SETTINGS.enableMaxToolCalls
|
||||
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxOutputTokens: getMaxTokens(assistant, model),
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
// Include AI SDK standard params extracted from custom parameters
|
||||
...standardParams,
|
||||
abortSignal: finalSignal,
|
||||
headers,
|
||||
providerOptions,
|
||||
maxRetries: 0
|
||||
}
|
||||
|
||||
// Only add stopWhen when explicitly enabled and validated
|
||||
if (enableMaxToolCalls) {
|
||||
const maxToolCalls = validateMaxToolCalls(assistant.settings?.maxToolCalls)
|
||||
params.stopWhen = stepCountIs(maxToolCalls)
|
||||
}
|
||||
// When disabled, don't pass stopWhen - let AI SDK use its own default
|
||||
|
||||
if (tools) {
|
||||
params.tools = tools
|
||||
}
|
||||
|
||||
let systemPrompt = assistant.prompt ? await replacePromptVariables(assistant.prompt, model.name) : ''
|
||||
|
||||
if (getEffectiveMcpMode(assistant) === 'auto') {
|
||||
const autoModePrompt = getHubModeSystemPrompt()
|
||||
if (autoModePrompt) {
|
||||
systemPrompt = systemPrompt ? `${systemPrompt}\n\n${autoModePrompt}` : autoModePrompt
|
||||
}
|
||||
}
|
||||
|
||||
if (systemPrompt) {
|
||||
params.system = systemPrompt
|
||||
}
|
||||
|
||||
logger.debug('params', params)
|
||||
|
||||
return {
|
||||
params,
|
||||
modelId: model.id,
|
||||
capabilities: { enableReasoning, enableWebSearch, enableGenerateImage, enableUrlContext },
|
||||
webSearchPluginConfig,
|
||||
idleTimeout
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建非流式的 generateText 参数
|
||||
*/
|
||||
export async function buildGenerateTextParams(
|
||||
messages: ModelMessage[],
|
||||
assistant: Assistant,
|
||||
provider: Provider,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
allowedTools?: string[]
|
||||
enableTools?: boolean
|
||||
} = {}
|
||||
): Promise<any> {
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, provider, options)
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getAiSdkProviderId } from '../factory'
|
||||
|
||||
// Mock the external dependencies
|
||||
vi.mock('@cherrystudio/ai-core', () => ({
|
||||
registerMultipleProviders: vi.fn(() => 4), // Mock successful registration of 4 providers
|
||||
getProviderMapping: vi.fn((id: string) => {
|
||||
// Mock dynamic mappings
|
||||
const mappings: Record<string, string> = {
|
||||
openrouter: 'openrouter',
|
||||
'google-vertex': 'google-vertex',
|
||||
vertexai: 'google-vertex',
|
||||
bedrock: 'bedrock',
|
||||
'aws-bedrock': 'bedrock',
|
||||
zhipu: 'zhipu'
|
||||
}
|
||||
return mappings[id]
|
||||
}),
|
||||
AiCore: {
|
||||
isSupported: vi.fn(() => true)
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: {},
|
||||
settingsSlice: {
|
||||
name: 'settings',
|
||||
reducer: vi.fn(),
|
||||
actions: {}
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the provider configs
|
||||
vi.mock('../providerConfigs', () => ({
|
||||
initializeNewProviders: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
function createTestProvider(id: string, type: string): Provider {
|
||||
return {
|
||||
id,
|
||||
type,
|
||||
name: `Test ${id}`,
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'test-host'
|
||||
} as Provider
|
||||
}
|
||||
|
||||
function createAzureProvider(id: string, apiVersion?: string, model?: string): Provider {
|
||||
return {
|
||||
id,
|
||||
type: 'azure-openai',
|
||||
name: `Azure Test ${id}`,
|
||||
apiKey: 'azure-test-key',
|
||||
apiHost: 'azure-test-host',
|
||||
apiVersion,
|
||||
models: [{ id: model || 'gpt-4' } as Model]
|
||||
}
|
||||
}
|
||||
|
||||
describe('Integrated Provider Registry', () => {
|
||||
describe('Provider ID Resolution', () => {
|
||||
it('should resolve openrouter provider correctly', () => {
|
||||
const provider = createTestProvider('openrouter', 'openrouter')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('openrouter')
|
||||
})
|
||||
|
||||
it('should resolve google-vertex provider correctly', () => {
|
||||
const provider = createTestProvider('google-vertex', 'vertexai')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('google-vertex')
|
||||
})
|
||||
|
||||
it('should resolve bedrock provider correctly', () => {
|
||||
const provider = createTestProvider('bedrock', 'aws-bedrock')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('bedrock')
|
||||
})
|
||||
|
||||
it('should resolve zhipu provider correctly', () => {
|
||||
const provider = createTestProvider('zhipu', 'zhipu')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('zhipu')
|
||||
})
|
||||
|
||||
it('should resolve provider type mapping correctly', () => {
|
||||
const provider = createTestProvider('vertex-test', 'vertexai')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('google-vertex')
|
||||
})
|
||||
|
||||
it('should handle static provider mappings', () => {
|
||||
const geminiProvider = createTestProvider('gemini', 'gemini')
|
||||
const result = getAiSdkProviderId(geminiProvider)
|
||||
expect(result).toBe('google')
|
||||
})
|
||||
|
||||
it('should fallback to provider.id for unknown providers', () => {
|
||||
const unknownProvider = createTestProvider('unknown-provider', 'unknown-type')
|
||||
const result = getAiSdkProviderId(unknownProvider)
|
||||
expect(result).toBe('unknown-provider')
|
||||
})
|
||||
|
||||
it('should handle Azure OpenAI providers correctly', () => {
|
||||
const azureProvider = createAzureProvider('azure-test', '2024-02-15', 'gpt-4o')
|
||||
const result = getAiSdkProviderId(azureProvider)
|
||||
expect(result).toBe('azure')
|
||||
})
|
||||
|
||||
it('should handle Azure OpenAI providers response endpoint correctly', () => {
|
||||
const azureProvider = createAzureProvider('azure-test', 'v1', 'gpt-4o')
|
||||
const result = getAiSdkProviderId(azureProvider)
|
||||
expect(result).toBe('azure-responses')
|
||||
})
|
||||
|
||||
it('should handle Azure provider Claude Models', () => {
|
||||
const provider = createTestProvider('azure-anthropic', 'anthropic')
|
||||
const result = getAiSdkProviderId(provider)
|
||||
expect(result).toBe('azure-anthropic')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backward Compatibility', () => {
|
||||
it('should maintain compatibility with existing providers', () => {
|
||||
const grokProvider = createTestProvider('grok', 'grok')
|
||||
const result = getAiSdkProviderId(grokProvider)
|
||||
expect(result).toBe('xai-responses')
|
||||
})
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +0,0 @@
|
||||
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
|
||||
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
|
||||
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
|
||||
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
|
||||
|
||||
export const COPILOT_DEFAULT_HEADERS = {
|
||||
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
|
||||
'User-Agent': COPILOT_USER_AGENT,
|
||||
'Editor-Version': COPILOT_EDITOR_VERSION,
|
||||
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
|
||||
'editor-version': COPILOT_EDITOR_VERSION,
|
||||
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
|
||||
'copilot-vision-request': 'true'
|
||||
} as const
|
||||
@@ -1,157 +0,0 @@
|
||||
/**
|
||||
* AiHubMix Provider
|
||||
*
|
||||
* Multi-backend API gateway that routes models by model ID prefix:
|
||||
* - claude* -> Anthropic SDK
|
||||
* - gemini* -> Google SDK
|
||||
* - others -> OpenAI Responses SDK (default)
|
||||
*
|
||||
* All requests include the APP-Code header.
|
||||
*/
|
||||
import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal'
|
||||
import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal'
|
||||
import { OpenAIChatLanguageModel, OpenAIResponsesLanguageModel, OpenAISpeechModel } from '@ai-sdk/openai/internal'
|
||||
import {
|
||||
OpenAICompatibleChatLanguageModel,
|
||||
OpenAICompatibleEmbeddingModel,
|
||||
OpenAICompatibleImageModel
|
||||
} from '@ai-sdk/openai-compatible'
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { FetchFunction } from '@ai-sdk/provider-utils'
|
||||
import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils'
|
||||
import { isOpenAIChatCompletionOnlyModel, isOpenAILLMModel } from '@renderer/config/models/openai'
|
||||
import type { Model } from '@renderer/types'
|
||||
|
||||
export const AIHUBMIX_PROVIDER_NAME = 'aihubmix' as const
|
||||
const APP_CODE_HEADER = { 'APP-Code': 'MLTG2087' }
|
||||
|
||||
export interface AihubmixProviderSettings {
|
||||
apiKey?: string
|
||||
baseURL?: string
|
||||
headers?: Record<string, string>
|
||||
fetch?: FetchFunction
|
||||
}
|
||||
|
||||
export interface AihubmixProvider extends ProviderV3 {
|
||||
(modelId: string): LanguageModelV3
|
||||
languageModel(modelId: string): LanguageModelV3
|
||||
embeddingModel(modelId: string): EmbeddingModelV3
|
||||
imageModel(modelId: string): ImageModelV3
|
||||
}
|
||||
|
||||
export function createAihubmix(options: AihubmixProviderSettings = {}): AihubmixProvider {
|
||||
const { baseURL = 'https://aihubmix.com/v1', fetch: customFetch } = options
|
||||
|
||||
const resolveApiKey = () =>
|
||||
loadApiKey({ apiKey: options.apiKey, environmentVariableName: 'AIHUBMIX_API_KEY', description: 'AiHubMix' })
|
||||
|
||||
const authHeaders = (): Record<string, string> => ({
|
||||
Authorization: `Bearer ${resolveApiKey()}`,
|
||||
'Content-Type': 'application/json',
|
||||
...APP_CODE_HEADER,
|
||||
...options.headers
|
||||
})
|
||||
|
||||
const url = ({ path }: { path: string; modelId: string }) => `${withoutTrailingSlash(baseURL)}${path}`
|
||||
|
||||
const createAnthropicModel = (modelId: string) => {
|
||||
const headers = authHeaders()
|
||||
return new AnthropicMessagesLanguageModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.anthropic`,
|
||||
baseURL,
|
||||
headers: () => ({ ...headers, 'x-api-key': resolveApiKey() }),
|
||||
fetch: customFetch,
|
||||
supportedUrls: () => ({ 'image/*': [/^https?:\/\/.*$/] })
|
||||
})
|
||||
}
|
||||
|
||||
const createGeminiModel = (modelId: string) => {
|
||||
const headers = authHeaders()
|
||||
return new GoogleGenerativeAILanguageModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.google`,
|
||||
baseURL: 'https://aihubmix.com/gemini/v1beta',
|
||||
headers: () => ({ ...headers, 'x-goog-api-key': resolveApiKey() }),
|
||||
fetch: customFetch,
|
||||
generateId: () => `${AIHUBMIX_PROVIDER_NAME}-${Date.now()}`,
|
||||
supportedUrls: () => ({})
|
||||
})
|
||||
}
|
||||
|
||||
const createOpenAICompatibleChatModel = (modelId: string): LanguageModelV3 =>
|
||||
new OpenAICompatibleChatLanguageModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.openai-compatible-chat`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
const createOpenAIChatModel = (modelId: string): LanguageModelV3 =>
|
||||
new OpenAIChatLanguageModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.openai-compatible-chat`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
const createResponsesModel = (modelId: string): LanguageModelV3 =>
|
||||
new OpenAIResponsesLanguageModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.openai-response`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch,
|
||||
fileIdPrefixes: ['file-']
|
||||
})
|
||||
|
||||
const createChatModel = (modelId: string): LanguageModelV3 => {
|
||||
if (modelId.startsWith('claude')) {
|
||||
return createAnthropicModel(modelId)
|
||||
}
|
||||
if (
|
||||
(modelId.startsWith('gemini') || modelId.startsWith('imagen')) &&
|
||||
!modelId.endsWith('no-think') &&
|
||||
!modelId.endsWith('-search') &&
|
||||
!modelId.includes('embedding')
|
||||
) {
|
||||
return createGeminiModel(modelId)
|
||||
}
|
||||
const model = { id: modelId } as Model
|
||||
if (isOpenAILLMModel(model)) {
|
||||
if (isOpenAIChatCompletionOnlyModel(model)) {
|
||||
return createOpenAIChatModel(modelId)
|
||||
}
|
||||
return createResponsesModel(modelId)
|
||||
}
|
||||
return createOpenAICompatibleChatModel(modelId)
|
||||
}
|
||||
|
||||
const provider = (modelId: string) => createChatModel(modelId)
|
||||
provider.specificationVersion = 'v3' as const
|
||||
|
||||
provider.languageModel = createChatModel
|
||||
|
||||
provider.embeddingModel = (modelId: string) =>
|
||||
new OpenAICompatibleEmbeddingModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.embedding`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
provider.imageModel = (modelId: string) =>
|
||||
new OpenAICompatibleImageModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.image`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
provider.speechModel = (modelId: string) =>
|
||||
new OpenAISpeechModel(modelId, {
|
||||
provider: `${AIHUBMIX_PROVIDER_NAME}.speech`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
return provider as AihubmixProvider
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
/**
|
||||
* NewAPI Provider
|
||||
*
|
||||
* Multi-backend API gateway (One API / New API) that routes models by endpoint_type:
|
||||
* - anthropic -> Anthropic SDK
|
||||
* - gemini -> Google SDK
|
||||
* - openai-response -> OpenAI Responses SDK
|
||||
* - openai / image-generation -> OpenAI Chat SDK
|
||||
* - fallback -> OpenAI Compatible SDK
|
||||
*
|
||||
* The endpointType is set per-request via provider settings, based on the model's endpoint_type field.
|
||||
*/
|
||||
import { AnthropicMessagesLanguageModel } from '@ai-sdk/anthropic/internal'
|
||||
import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal'
|
||||
import { OpenAIResponsesLanguageModel } from '@ai-sdk/openai/internal'
|
||||
import {
|
||||
OpenAICompatibleChatLanguageModel,
|
||||
OpenAICompatibleEmbeddingModel,
|
||||
OpenAICompatibleImageModel
|
||||
} from '@ai-sdk/openai-compatible'
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { FetchFunction } from '@ai-sdk/provider-utils'
|
||||
import { loadApiKey, withoutTrailingSlash } from '@ai-sdk/provider-utils'
|
||||
|
||||
export const NEWAPI_PROVIDER_NAME = 'newapi' as const
|
||||
|
||||
export type NewApiEndpointType =
|
||||
| 'openai'
|
||||
| 'openai-response'
|
||||
| 'anthropic'
|
||||
| 'gemini'
|
||||
| 'image-generation'
|
||||
| 'jina-rerank'
|
||||
|
||||
export interface NewApiProviderSettings {
|
||||
apiKey?: string
|
||||
baseURL?: string
|
||||
headers?: Record<string, string>
|
||||
fetch?: FetchFunction
|
||||
endpointType?: NewApiEndpointType
|
||||
}
|
||||
|
||||
export interface NewApiProvider extends ProviderV3 {
|
||||
(modelId: string): LanguageModelV3
|
||||
languageModel(modelId: string): LanguageModelV3
|
||||
embeddingModel(modelId: string): EmbeddingModelV3
|
||||
imageModel(modelId: string): ImageModelV3
|
||||
}
|
||||
|
||||
export function createNewApi(options: NewApiProviderSettings = {}): NewApiProvider {
|
||||
const { baseURL = '', fetch: customFetch, endpointType } = options
|
||||
|
||||
const resolveApiKey = () =>
|
||||
loadApiKey({ apiKey: options.apiKey, environmentVariableName: 'NEWAPI_API_KEY', description: 'NewAPI' })
|
||||
|
||||
const authHeaders = (): Record<string, string> => ({
|
||||
Authorization: `Bearer ${resolveApiKey()}`,
|
||||
'Content-Type': 'application/json',
|
||||
...options.headers
|
||||
})
|
||||
|
||||
const url = ({ path }: { path: string; modelId: string }) => `${withoutTrailingSlash(baseURL)}${path}`
|
||||
|
||||
const createAnthropicModel = (modelId: string) => {
|
||||
const headers = authHeaders()
|
||||
return new AnthropicMessagesLanguageModel(modelId, {
|
||||
provider: `${NEWAPI_PROVIDER_NAME}.anthropic`,
|
||||
baseURL,
|
||||
headers: () => ({ ...headers, 'x-api-key': resolveApiKey() }),
|
||||
fetch: customFetch,
|
||||
supportedUrls: () => ({ 'image/*': [/^https?:\/\/.*$/] })
|
||||
})
|
||||
}
|
||||
|
||||
const createGeminiModel = (modelId: string) => {
|
||||
const headers = authHeaders()
|
||||
return new GoogleGenerativeAILanguageModel(modelId, {
|
||||
provider: `${NEWAPI_PROVIDER_NAME}.google`,
|
||||
baseURL,
|
||||
headers: () => ({ ...headers, 'x-goog-api-key': resolveApiKey() }),
|
||||
fetch: customFetch,
|
||||
generateId: () => `${NEWAPI_PROVIDER_NAME}-${Date.now()}`,
|
||||
supportedUrls: () => ({})
|
||||
})
|
||||
}
|
||||
|
||||
const createResponsesModel = (modelId: string) =>
|
||||
new OpenAIResponsesLanguageModel(modelId, {
|
||||
provider: `${NEWAPI_PROVIDER_NAME}.openai-response`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
const createCompatibleModel = (modelId: string) =>
|
||||
new OpenAICompatibleChatLanguageModel(modelId, {
|
||||
provider: `${NEWAPI_PROVIDER_NAME}.chat`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
const createChatModel = (modelId: string): LanguageModelV3 => {
|
||||
switch (endpointType) {
|
||||
case 'anthropic':
|
||||
return createAnthropicModel(modelId)
|
||||
case 'gemini':
|
||||
return createGeminiModel(modelId)
|
||||
case 'openai-response':
|
||||
return createResponsesModel(modelId)
|
||||
case 'openai':
|
||||
case 'image-generation':
|
||||
return createCompatibleModel(modelId)
|
||||
default:
|
||||
return createCompatibleModel(modelId)
|
||||
}
|
||||
}
|
||||
|
||||
const provider = (modelId: string) => createChatModel(modelId)
|
||||
provider.specificationVersion = 'v3' as const
|
||||
|
||||
provider.languageModel = createChatModel
|
||||
|
||||
provider.embeddingModel = (modelId: string) =>
|
||||
new OpenAICompatibleEmbeddingModel(modelId, {
|
||||
provider: `${NEWAPI_PROVIDER_NAME}.embedding`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
provider.imageModel = (modelId: string) =>
|
||||
new OpenAICompatibleImageModel(modelId, {
|
||||
provider: `${NEWAPI_PROVIDER_NAME}.image`,
|
||||
url,
|
||||
headers: authHeaders,
|
||||
fetch: customFetch
|
||||
})
|
||||
|
||||
return provider as NewApiProvider
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
/**
|
||||
* Type System Tests for Auto-Extracted Provider Types
|
||||
*/
|
||||
|
||||
import type { AppProviderId } from '@renderer/aiCore/types'
|
||||
import { describe, expect, expectTypeOf, it } from 'vitest'
|
||||
|
||||
import { extensions } from '../index'
|
||||
|
||||
describe('Auto-Extracted Type System', () => {
|
||||
describe('Runtime and Type Consistency', () => {
|
||||
it('运行时 IDs 应该自动提取到类型系统', () => {
|
||||
// 从运行时获取所有 IDs(包括主 ID 和别名)
|
||||
const runtimeIds = extensions.flatMap((ext) => ext.getProviderIds())
|
||||
|
||||
// 🎯 Zero maintenance - 不再需要手动声明类型!
|
||||
// 类型系统会自动从 projectExtensions 数组中提取所有 IDs
|
||||
|
||||
// 验证主要的 project provider IDs
|
||||
const expectedMainIds: AppProviderId[] = [
|
||||
'google-vertex',
|
||||
'google-vertex-anthropic',
|
||||
'github-copilot-openai-compatible',
|
||||
'bedrock',
|
||||
'perplexity',
|
||||
'mistral',
|
||||
'huggingface',
|
||||
'gateway',
|
||||
'cerebras',
|
||||
'ollama'
|
||||
]
|
||||
|
||||
// 验证别名
|
||||
const expectedAliases: AppProviderId[] = [
|
||||
'vertexai',
|
||||
'vertexai-anthropic',
|
||||
'copilot',
|
||||
'github-copilot',
|
||||
'aws-bedrock',
|
||||
'hf',
|
||||
'hugging-face',
|
||||
'ai-gateway'
|
||||
]
|
||||
|
||||
// 验证所有期望的 ID 都存在于运行时
|
||||
;[...expectedMainIds, ...expectedAliases].forEach((id) => {
|
||||
expect(runtimeIds).toContain(id)
|
||||
})
|
||||
|
||||
// 验证数量一致
|
||||
const uniqueRuntimeIds = [...new Set(runtimeIds)]
|
||||
expect(uniqueRuntimeIds.length).toBeGreaterThanOrEqual(expectedMainIds.length + expectedAliases.length)
|
||||
})
|
||||
|
||||
it('每个 extension 应该至少有一个 provider ID', () => {
|
||||
extensions.forEach((ext) => {
|
||||
const ids = ext.getProviderIds()
|
||||
expect(ids.length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Inference - Auto-Extracted', () => {
|
||||
// 🎯 Zero maintenance! These tests validate compile-time type inference
|
||||
// 类型从 projectExtensions 数组自动提取,无需手动维护
|
||||
|
||||
it('应该接受核心 provider IDs', () => {
|
||||
// 编译时类型检查 - AppProviderId 包含所有 core IDs
|
||||
const coreIds: AppProviderId[] = [
|
||||
'openai',
|
||||
'anthropic',
|
||||
'google',
|
||||
'azure',
|
||||
'deepseek',
|
||||
'xai',
|
||||
'openai-compatible',
|
||||
'openrouter',
|
||||
'cherryin'
|
||||
]
|
||||
|
||||
// 运行时验证(确保类型存在)
|
||||
expect(coreIds.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('应该接受项目特定 provider IDs', () => {
|
||||
// 编译时类型检查 - 自动从 projectExtensions 提取
|
||||
const projectIds: AppProviderId[] = [
|
||||
'google-vertex',
|
||||
'google-vertex-anthropic',
|
||||
'github-copilot-openai-compatible',
|
||||
'bedrock',
|
||||
'perplexity',
|
||||
'mistral',
|
||||
'huggingface',
|
||||
'gateway',
|
||||
'cerebras',
|
||||
'ollama'
|
||||
]
|
||||
|
||||
// 运行时验证
|
||||
expect(projectIds.length).toBe(10)
|
||||
})
|
||||
|
||||
it('应该接受项目特定 provider 别名', () => {
|
||||
// 编译时类型检查 - 别名也自动提取
|
||||
const aliases: AppProviderId[] = [
|
||||
'vertexai',
|
||||
'vertexai-anthropic',
|
||||
'copilot',
|
||||
'github-copilot',
|
||||
'aws-bedrock',
|
||||
'hf',
|
||||
'hugging-face',
|
||||
'ai-gateway'
|
||||
]
|
||||
|
||||
// 运行时验证
|
||||
expect(aliases.length).toBe(8)
|
||||
})
|
||||
|
||||
it('AppProviderId 应该包含项目和核心的所有 IDs', () => {
|
||||
// 编译时验证 - 统一类型系统测试
|
||||
// ✅ 项目 IDs 应该在 AppProviderId 中
|
||||
type Check1 = 'google-vertex' extends AppProviderId ? true : false
|
||||
type Check2 = 'ollama' extends AppProviderId ? true : false
|
||||
type Check3 = 'vertexai' extends AppProviderId ? true : false
|
||||
|
||||
// ✅ 核心 IDs 也应该在 AppProviderId 中(统一类型系统)
|
||||
type Check4 = 'openai' extends AppProviderId ? true : false
|
||||
type Check5 = 'anthropic' extends AppProviderId ? true : false
|
||||
|
||||
expectTypeOf<Check1>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check2>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check3>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check4>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check5>().toEqualTypeOf<true>()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,235 +0,0 @@
|
||||
/**
|
||||
* Cherry Studio 项目特定的 Provider Extensions
|
||||
* 用于支持运行时动态导入的 AI Providers
|
||||
*/
|
||||
|
||||
import type { AmazonBedrockProvider } from '@ai-sdk/amazon-bedrock'
|
||||
import { type AmazonBedrockProviderSettings, createAmazonBedrock } from '@ai-sdk/amazon-bedrock'
|
||||
import { type CerebrasProviderSettings, createCerebras } from '@ai-sdk/cerebras'
|
||||
import { createGateway, type GatewayProviderSettings } from '@ai-sdk/gateway'
|
||||
import { createVertexAnthropic, type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic/edge'
|
||||
import { createVertex, type GoogleVertexProvider, type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
|
||||
import { createGroq, type GroqProviderSettings } from '@ai-sdk/groq'
|
||||
import { createHuggingFace, type HuggingFaceProviderSettings } from '@ai-sdk/huggingface'
|
||||
import { createMistral, type MistralProviderSettings } from '@ai-sdk/mistral'
|
||||
import { createPerplexity, type PerplexityProviderSettings } from '@ai-sdk/perplexity'
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createTogetherAI, type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
|
||||
import { ProviderExtension, type ProviderExtensionConfig } from '@cherrystudio/ai-core/provider'
|
||||
import {
|
||||
createGitHubCopilotOpenAICompatible,
|
||||
type GitHubCopilotProviderSettings
|
||||
} from '@opeoginni/github-copilot-openai-compatible'
|
||||
import { SystemProviderIds } from '@types'
|
||||
import type { OllamaProviderSettings } from 'ollama-ai-provider-v2'
|
||||
import { createOllama } from 'ollama-ai-provider-v2'
|
||||
import { createVoyage, type VoyageProviderSettings } from 'voyage-ai-provider'
|
||||
|
||||
import { type AihubmixProviderSettings, createAihubmix } from '../custom/aihubmix-provider'
|
||||
import { createNewApi, type NewApiProviderSettings } from '../custom/newapi-provider'
|
||||
|
||||
/**
|
||||
* Google Vertex AI Extension
|
||||
*/
|
||||
export const GoogleVertexExtension = ProviderExtension.create({
|
||||
name: 'google-vertex',
|
||||
aliases: ['vertexai'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createVertex,
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: GoogleVertexProvider) =>
|
||||
(config: NonNullable<Parameters<GoogleVertexProvider['tools']['googleSearch']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.googleSearch(config) }
|
||||
}),
|
||||
urlContext:
|
||||
(provider: GoogleVertexProvider) =>
|
||||
(config: NonNullable<Parameters<GoogleVertexProvider['tools']['urlContext']>[0]>) => ({
|
||||
tools: { urlContext: provider.tools.urlContext(config) }
|
||||
})
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<GoogleVertexProviderSettings, GoogleVertexProvider, 'google-vertex'>)
|
||||
|
||||
/**
|
||||
* Google Vertex AI Anthropic Extension
|
||||
*/
|
||||
export const GoogleVertexAnthropicExtension = ProviderExtension.create({
|
||||
name: 'google-vertex-anthropic',
|
||||
aliases: ['vertexai-anthropic'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createVertexAnthropic,
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: GoogleVertexAnthropicProvider) =>
|
||||
(config: NonNullable<Parameters<GoogleVertexAnthropicProvider['tools']['webSearch_20250305']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webSearch_20250305(config) }
|
||||
})
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<
|
||||
GoogleVertexProviderSettings,
|
||||
GoogleVertexAnthropicProvider,
|
||||
'google-vertex-anthropic'
|
||||
>)
|
||||
|
||||
/**
|
||||
* GitHub Copilot Extension
|
||||
*/
|
||||
export const GitHubCopilotExtension = ProviderExtension.create({
|
||||
name: 'github-copilot-openai-compatible',
|
||||
aliases: ['copilot', 'github-copilot'] as const,
|
||||
supportsImageGeneration: false,
|
||||
create: (options?: GitHubCopilotProviderSettings) =>
|
||||
// GitHubCopilot并没有完整的实现ProviderV3
|
||||
createGitHubCopilotOpenAICompatible(options) as unknown as ProviderV3
|
||||
} as const satisfies ProviderExtensionConfig<
|
||||
GitHubCopilotProviderSettings,
|
||||
ProviderV3,
|
||||
'github-copilot-openai-compatible'
|
||||
>)
|
||||
|
||||
/**
|
||||
* Amazon Bedrock Extension
|
||||
*/
|
||||
export const BedrockExtension = ProviderExtension.create({
|
||||
name: 'bedrock',
|
||||
aliases: ['aws-bedrock'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createAmazonBedrock,
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: AmazonBedrockProvider) =>
|
||||
(config: NonNullable<Parameters<AmazonBedrockProvider['tools']['webSearch_20260209']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webSearch_20260209(config) }
|
||||
}),
|
||||
urlContext:
|
||||
(provider: AmazonBedrockProvider) =>
|
||||
(config: NonNullable<Parameters<AmazonBedrockProvider['tools']['webFetch_20260209']>[0]>) => ({
|
||||
tools: { urlContext: provider.tools.webFetch_20260209(config) }
|
||||
})
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<AmazonBedrockProviderSettings, AmazonBedrockProvider, 'bedrock'>)
|
||||
|
||||
/**
|
||||
* Perplexity Extension
|
||||
*/
|
||||
export const PerplexityExtension = ProviderExtension.create({
|
||||
name: 'perplexity',
|
||||
supportsImageGeneration: false,
|
||||
create: createPerplexity
|
||||
} as const satisfies ProviderExtensionConfig<PerplexityProviderSettings, ProviderV3, 'perplexity'>)
|
||||
|
||||
/**
|
||||
* Mistral Extension
|
||||
*/
|
||||
export const MistralExtension = ProviderExtension.create({
|
||||
name: 'mistral',
|
||||
supportsImageGeneration: false,
|
||||
create: createMistral
|
||||
} as const satisfies ProviderExtensionConfig<MistralProviderSettings, ProviderV3, 'mistral'>)
|
||||
|
||||
/**
|
||||
* HuggingFace Extension
|
||||
*/
|
||||
export const HuggingFaceExtension = ProviderExtension.create({
|
||||
name: 'huggingface',
|
||||
aliases: ['hf', 'hugging-face'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createHuggingFace
|
||||
} as const satisfies ProviderExtensionConfig<HuggingFaceProviderSettings, ProviderV3, 'huggingface'>)
|
||||
|
||||
/**
|
||||
* Vercel AI Gateway Extension
|
||||
*/
|
||||
export const GatewayExtension = ProviderExtension.create({
|
||||
name: 'gateway',
|
||||
aliases: ['ai-gateway'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createGateway
|
||||
} as const satisfies ProviderExtensionConfig<GatewayProviderSettings, ProviderV3, 'gateway'>)
|
||||
|
||||
/**
|
||||
* Cerebras Extension
|
||||
*/
|
||||
export const CerebrasExtension = ProviderExtension.create({
|
||||
name: 'cerebras',
|
||||
supportsImageGeneration: false,
|
||||
create: createCerebras
|
||||
} as const satisfies ProviderExtensionConfig<CerebrasProviderSettings, ProviderV3, 'cerebras'>)
|
||||
|
||||
/**
|
||||
* Groq Extension
|
||||
*/
|
||||
export const GroqExtension = ProviderExtension.create({
|
||||
name: 'groq',
|
||||
supportsImageGeneration: false,
|
||||
create: createGroq
|
||||
} as const satisfies ProviderExtensionConfig<GroqProviderSettings, ProviderV3, 'groq'>)
|
||||
|
||||
/**
|
||||
* Ollama Extension
|
||||
*/
|
||||
export const OllamaExtension = ProviderExtension.create({
|
||||
name: 'ollama',
|
||||
supportsImageGeneration: false,
|
||||
create: (options?: OllamaProviderSettings) => createOllama(options)
|
||||
} as const satisfies ProviderExtensionConfig<OllamaProviderSettings, ProviderV3, 'ollama'>)
|
||||
|
||||
/**
|
||||
* AiHubMix Extension - multi-backend gateway (claude->anthropic, gemini->google, gpt->openai-responses)
|
||||
*/
|
||||
export const AiHubMixExtension = ProviderExtension.create({
|
||||
name: 'aihubmix',
|
||||
supportsImageGeneration: true,
|
||||
create: createAihubmix
|
||||
} as const satisfies ProviderExtensionConfig<AihubmixProviderSettings, ProviderV3, 'aihubmix'>)
|
||||
|
||||
/**
|
||||
* NewAPI Extension - multi-backend gateway routed by endpoint_type
|
||||
*/
|
||||
export const NewApiExtension = ProviderExtension.create({
|
||||
name: 'newapi',
|
||||
aliases: ['new-api'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createNewApi
|
||||
} as const satisfies ProviderExtensionConfig<NewApiProviderSettings, ProviderV3, 'newapi'>)
|
||||
|
||||
/**
|
||||
* Together AI Extension - chat and image generation
|
||||
*/
|
||||
export const TogetherAIExtension = ProviderExtension.create({
|
||||
name: 'togetherai',
|
||||
aliases: [SystemProviderIds.together] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createTogetherAI
|
||||
} as const satisfies ProviderExtensionConfig<TogetherAIProviderSettings, ProviderV3, 'togetherai'>)
|
||||
|
||||
/**
|
||||
* Voyage AI Extension - embeddings and reranking
|
||||
*/
|
||||
export const VoyageExtension = ProviderExtension.create({
|
||||
name: 'voyage',
|
||||
aliases: [SystemProviderIds.voyageai] as const,
|
||||
supportsImageGeneration: false,
|
||||
create: createVoyage
|
||||
} as const satisfies ProviderExtensionConfig<VoyageProviderSettings, ProviderV3, 'voyage'>)
|
||||
|
||||
/**
|
||||
* 所有项目特定的 Extensions
|
||||
*/
|
||||
export const extensions = [
|
||||
GoogleVertexExtension,
|
||||
GoogleVertexAnthropicExtension,
|
||||
GitHubCopilotExtension,
|
||||
BedrockExtension,
|
||||
PerplexityExtension,
|
||||
MistralExtension,
|
||||
HuggingFaceExtension,
|
||||
GatewayExtension,
|
||||
CerebrasExtension,
|
||||
OllamaExtension,
|
||||
AiHubMixExtension,
|
||||
NewApiExtension,
|
||||
VoyageExtension,
|
||||
TogetherAIExtension,
|
||||
GroqExtension
|
||||
] as const
|
||||
@@ -1,54 +0,0 @@
|
||||
import { extensionRegistry } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
|
||||
|
||||
import { type AppProviderId, appProviderIds } from '../types'
|
||||
import { extensions } from './extensions'
|
||||
|
||||
const logger = loggerService.withContext('ProviderFactory')
|
||||
|
||||
for (const extension of extensions) {
|
||||
if (!extensionRegistry.has(extension.config.name)) {
|
||||
extensionRegistry.register(extension)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 AI SDK Provider ID
|
||||
*
|
||||
* 使用运行时类型安全的 appProviderIds 统一解析
|
||||
* 特殊处理 Azure 端点检测和 OpenAI API 域名检测
|
||||
*
|
||||
* @param provider - Provider 配置对象
|
||||
* @returns AI SDK 标准 provider ID
|
||||
*/
|
||||
export function getAiSdkProviderId(provider: Provider): AppProviderId {
|
||||
// 1. 特殊处理:Azure 的 responses 端点检测(必须在别名解析之前)
|
||||
if (isAzureOpenAIProvider(provider)) {
|
||||
return isAzureResponsesEndpoint(provider) ? appProviderIds['azure-responses'] : appProviderIds.azure
|
||||
}
|
||||
|
||||
if (provider.id === SystemProviderIds.grok) {
|
||||
return appProviderIds['xai-responses']
|
||||
}
|
||||
|
||||
if (provider.id in appProviderIds) {
|
||||
return appProviderIds[provider.id]
|
||||
}
|
||||
|
||||
if (provider.type !== 'openai' && provider.type in appProviderIds) {
|
||||
return appProviderIds[provider.type]
|
||||
}
|
||||
|
||||
if (provider.apiHost.includes('api.openai.com')) {
|
||||
return appProviderIds['openai-chat']
|
||||
}
|
||||
|
||||
logger.warn('Provider ID not found in registered extensions, using as-is', {
|
||||
providerId: provider.id,
|
||||
providerType: provider.type,
|
||||
registeredIds: Object.keys(appProviderIds)
|
||||
})
|
||||
return provider.id
|
||||
}
|
||||
@@ -1,409 +0,0 @@
|
||||
import { formatPrivateKey, hasProviderConfig, type StringKeys } from '@cherrystudio/ai-core/provider'
|
||||
import type { AppProviderId, AppProviderSettingsMap } from '@renderer/aiCore/types'
|
||||
import {
|
||||
getAwsBedrockAccessKeyId,
|
||||
getAwsBedrockApiKey,
|
||||
getAwsBedrockAuthType,
|
||||
getAwsBedrockRegion,
|
||||
getAwsBedrockSecretAccessKey
|
||||
} from '@renderer/hooks/useAwsBedrock'
|
||||
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { getProviderById } from '@renderer/services/ProviderService'
|
||||
import store from '@renderer/store'
|
||||
import { type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import {
|
||||
formatApiHost,
|
||||
formatOllamaApiHost,
|
||||
formatVertexApiHost,
|
||||
isWithTrailingSharp,
|
||||
routeToEndpoint
|
||||
} from '@renderer/utils/api'
|
||||
import {
|
||||
isAnthropicProvider,
|
||||
isAzureOpenAIProvider,
|
||||
isCherryAIProvider,
|
||||
isGeminiProvider,
|
||||
isOllamaProvider,
|
||||
isPerplexityProvider,
|
||||
isSupportStreamOptionsProvider,
|
||||
isVertexProvider
|
||||
} from '@renderer/utils/provider'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import { cloneDeep, isEmpty } from 'lodash'
|
||||
|
||||
import type { ProviderConfig } from '../types'
|
||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
// === Types ===
|
||||
|
||||
interface BaseConfig {
|
||||
baseURL: string
|
||||
apiKey: string
|
||||
}
|
||||
|
||||
interface BuilderContext {
|
||||
actualProvider: Provider
|
||||
model: Model
|
||||
baseConfig: BaseConfig
|
||||
endpoint?: string
|
||||
aiSdkProviderId: AppProviderId
|
||||
}
|
||||
|
||||
// === Host Formatting ===
|
||||
|
||||
type HostFormatter = {
|
||||
match: (provider: Provider) => boolean
|
||||
format: (provider: Provider, appendApiVersion: boolean) => string
|
||||
}
|
||||
|
||||
// WARNING: if any changes are made here, please sync it to src/main/aiCore/provider/providerConfig.ts:formatProviderApiHost
|
||||
export function formatProviderApiHost(provider: Provider): Provider {
|
||||
const formatted = { ...provider }
|
||||
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
|
||||
|
||||
if (formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
|
||||
}
|
||||
|
||||
// Anthropic is special: uses anthropicApiHost as source and syncs both fields
|
||||
if (isAnthropicProvider(provider)) {
|
||||
const baseHost = formatted.anthropicApiHost || formatted.apiHost
|
||||
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
|
||||
if (!formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatted.apiHost
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
const formatters: HostFormatter[] = [
|
||||
{
|
||||
match: (p) => p.id === SystemProviderIds.copilot || p.id === SystemProviderIds.github,
|
||||
format: (p) => formatApiHost(p.apiHost, false)
|
||||
},
|
||||
{ match: isCherryAIProvider, format: (p) => formatApiHost(p.apiHost, false) },
|
||||
{ match: isPerplexityProvider, format: (p) => formatApiHost(p.apiHost, false) },
|
||||
{ match: isOllamaProvider, format: (p) => formatOllamaApiHost(p.apiHost) },
|
||||
{ match: isGeminiProvider, format: (p, av) => formatApiHost(p.apiHost, av, 'v1beta') },
|
||||
{ match: isAzureOpenAIProvider, format: (p) => formatApiHost(p.apiHost, false) },
|
||||
{ match: isVertexProvider, format: (p) => formatVertexApiHost(p as Parameters<typeof formatVertexApiHost>[0]) }
|
||||
]
|
||||
|
||||
const formatter = formatters.find((f) => f.match(provider))
|
||||
formatted.apiHost = formatter
|
||||
? formatter.format(formatted, appendApiVersion)
|
||||
: formatApiHost(formatted.apiHost, appendApiVersion)
|
||||
|
||||
return formatted
|
||||
}
|
||||
|
||||
// === SDK Config Building ===
|
||||
|
||||
type ConfigBuilderEntry = {
|
||||
match: (provider: Provider, aiSdkProviderId: AppProviderId) => boolean
|
||||
build: (ctx: BuilderContext) => ProviderConfig | Promise<ProviderConfig>
|
||||
}
|
||||
|
||||
export function providerToAiSdkConfig(
|
||||
actualProvider: Provider,
|
||||
model: Model
|
||||
): ProviderConfig | Promise<ProviderConfig> {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
|
||||
|
||||
const ctx: BuilderContext = {
|
||||
actualProvider,
|
||||
model,
|
||||
baseConfig: { baseURL, apiKey: actualProvider.apiKey },
|
||||
endpoint,
|
||||
aiSdkProviderId
|
||||
}
|
||||
|
||||
const builders: ConfigBuilderEntry[] = [
|
||||
{ match: (p) => p.id === SystemProviderIds.copilot, build: buildCopilotConfig },
|
||||
{ match: (p) => p.id === 'cherryai', build: buildCherryAIConfig },
|
||||
{ match: (p) => p.id === 'anthropic' && p.authType === 'oauth', build: buildAnthropicConfig },
|
||||
{ match: (p) => isOllamaProvider(p), build: buildOllamaConfig },
|
||||
{ match: (p) => isAzureOpenAIProvider(p), build: buildAzureConfig },
|
||||
{ match: (_, id) => id === 'bedrock', build: buildBedrockConfig },
|
||||
{ match: (_, id) => id === 'google-vertex', build: buildVertexConfig },
|
||||
{ match: (_, id) => id === 'cherryin', build: buildCherryinConfig },
|
||||
{ match: (_, id) => id === 'newapi', build: buildNewApiConfig },
|
||||
{ match: (_, id) => id === 'aihubmix', build: buildAiHubMixConfig }
|
||||
]
|
||||
|
||||
const builder = builders.find((b) => b.match(actualProvider, aiSdkProviderId))
|
||||
if (builder) {
|
||||
return builder.build(ctx)
|
||||
}
|
||||
|
||||
// SDK-supported provider → generic config; otherwise → openai-compatible fallback
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
return buildGenericProviderConfig(ctx)
|
||||
}
|
||||
return buildOpenAICompatibleConfig(ctx)
|
||||
}
|
||||
|
||||
// === Public API ===
|
||||
|
||||
export function getActualProvider(model: Model): Provider {
|
||||
return adaptProvider({ provider: getProviderByModel(model), model })
|
||||
}
|
||||
|
||||
export function adaptProvider({ provider }: { provider: Provider; model?: Model }): Provider {
|
||||
return formatProviderApiHost(cloneDeep(provider))
|
||||
}
|
||||
|
||||
export function isModernSdkSupported(provider: Provider): boolean {
|
||||
return hasProviderConfig(getAiSdkProviderId(provider))
|
||||
}
|
||||
|
||||
// === Config Builders ===
|
||||
|
||||
function buildCommonOptions(ctx: BuilderContext) {
|
||||
const options: Record<string, any> = {
|
||||
headers: {
|
||||
...defaultAppHeaders(),
|
||||
...ctx.actualProvider.extra_headers
|
||||
}
|
||||
}
|
||||
if (ctx.aiSdkProviderId === 'openai') {
|
||||
options.headers['X-Api-Key'] = ctx.baseConfig.apiKey
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
async function buildCopilotConfig(ctx: BuilderContext): Promise<ProviderConfig<'github-copilot-openai-compatible'>> {
|
||||
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
|
||||
const headers = { ...COPILOT_DEFAULT_HEADERS, ...storedHeaders }
|
||||
const { token } = await window.api.copilot.getToken(headers)
|
||||
|
||||
return {
|
||||
providerId: 'github-copilot-openai-compatible',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
apiKey: token,
|
||||
headers: { ...headers, ...ctx.actualProvider.extra_headers },
|
||||
name: ctx.actualProvider.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function buildOllamaConfig(ctx: BuilderContext): ProviderConfig<'ollama'> {
|
||||
const headers: ProviderConfig<'ollama'>['providerSettings']['headers'] = {
|
||||
...defaultAppHeaders(),
|
||||
...ctx.actualProvider.extra_headers
|
||||
}
|
||||
if (!isEmpty(ctx.baseConfig.apiKey)) {
|
||||
headers.Authorization = `Bearer ${ctx.baseConfig.apiKey}`
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'ollama',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: { ...ctx.baseConfig, headers }
|
||||
}
|
||||
}
|
||||
|
||||
function buildBedrockConfig(ctx: BuilderContext): ProviderConfig<'bedrock'> {
|
||||
const authType = getAwsBedrockAuthType()
|
||||
const region = getAwsBedrockRegion()
|
||||
|
||||
const base = { providerId: 'bedrock' as const, endpoint: ctx.endpoint }
|
||||
|
||||
if (authType === 'apiKey') {
|
||||
return { ...base, providerSettings: { ...ctx.baseConfig, region, apiKey: getAwsBedrockApiKey() } }
|
||||
}
|
||||
return {
|
||||
...base,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
region,
|
||||
accessKeyId: getAwsBedrockAccessKeyId(),
|
||||
secretAccessKey: getAwsBedrockSecretAccessKey()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function buildVertexConfig(
|
||||
ctx: BuilderContext
|
||||
): ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'> {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
|
||||
const { project, location, googleCredentials } = createVertexProvider(ctx.actualProvider)
|
||||
// Vertex 上的 Claude 模型走 google-vertex-anthropic variant
|
||||
const isAnthropic = ctx.aiSdkProviderId === 'google-vertex-anthropic' || ctx.model.id.startsWith('claude')
|
||||
const baseURL = ctx.baseConfig.baseURL + (isAnthropic ? '/publishers/anthropic/models' : '/publishers/google')
|
||||
const creds = { ...googleCredentials, privateKey: formatPrivateKey(googleCredentials.privateKey) }
|
||||
|
||||
return {
|
||||
providerId: isAnthropic ? 'google-vertex-anthropic' : 'google-vertex',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: { ...ctx.baseConfig, baseURL, project, location, googleCredentials: creds }
|
||||
} as ProviderConfig<'google-vertex'> | ProviderConfig<'google-vertex-anthropic'>
|
||||
}
|
||||
|
||||
function buildCherryinConfig(ctx: BuilderContext): ProviderConfig<'cherryin'> {
|
||||
const cherryinProvider = getProviderById(SystemProviderIds.cherryin)
|
||||
|
||||
return {
|
||||
providerId: 'cherryin',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
endpointType: ctx.model.endpoint_type,
|
||||
anthropicBaseURL: cherryinProvider ? cherryinProvider.anthropicApiHost + '/v1' : undefined,
|
||||
geminiBaseURL: cherryinProvider ? cherryinProvider.apiHost + '/v1beta' : undefined,
|
||||
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function buildCherryAIConfig(ctx: BuilderContext): Promise<ProviderConfig<'openai-compatible'>> {
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
name: ctx.actualProvider.id,
|
||||
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers },
|
||||
fetch: async (input: RequestInfo | URL, init?: RequestInit) => {
|
||||
const signature = await window.api.cherryai.generateSignature({
|
||||
method: 'POST',
|
||||
path: '/chat/completions',
|
||||
query: '',
|
||||
body: init?.body && typeof init.body === 'string' ? JSON.parse(init.body) : undefined
|
||||
})
|
||||
return fetch(input, { ...init, headers: { ...init?.headers, ...signature } })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function formatAzureBaseURL(baseURL: string, forAnthropic: boolean): string {
|
||||
// Normalize: strip trailing /v1 and /openai that user may have included
|
||||
const normalized = baseURL.replace(/\/v1$/, '').replace(/\/openai$/, '')
|
||||
// Azure OpenAI endpoints need /openai suffix; Azure Anthropic does not
|
||||
return forAnthropic ? normalized : normalized + '/openai'
|
||||
}
|
||||
|
||||
function buildAzureConfig(
|
||||
ctx: BuilderContext
|
||||
): ProviderConfig<'azure'> | ProviderConfig<'azure-responses'> | ProviderConfig<'azure-anthropic'> {
|
||||
// Azure 上的 Claude 模型走 azure-anthropic variant(内部使用 Anthropic SDK)
|
||||
if (ctx.model.id.startsWith('claude')) {
|
||||
return {
|
||||
providerId: 'azure-anthropic',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
baseURL: formatAzureBaseURL(ctx.baseConfig.baseURL, true),
|
||||
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const apiVersion = ctx.actualProvider.apiVersion?.trim()
|
||||
const useResponsesMode = apiVersion && ['preview', 'v1'].includes(apiVersion)
|
||||
|
||||
const providerSettings: ProviderConfig<'azure'>['providerSettings'] = {
|
||||
...ctx.baseConfig,
|
||||
baseURL: formatAzureBaseURL(ctx.baseConfig.baseURL, false),
|
||||
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
|
||||
}
|
||||
|
||||
if (apiVersion) {
|
||||
providerSettings.apiVersion = apiVersion
|
||||
if (!useResponsesMode) {
|
||||
providerSettings.useDeploymentBasedUrls = true
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: useResponsesMode ? 'azure-responses' : 'azure',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings
|
||||
} as ProviderConfig<'azure'> | ProviderConfig<'azure-responses'>
|
||||
}
|
||||
|
||||
async function buildAnthropicConfig(ctx: BuilderContext): Promise<ProviderConfig<'anthropic'>> {
|
||||
const oauthToken: string = await window.api.anthropic_oauth.getAccessToken()
|
||||
|
||||
return {
|
||||
providerId: 'anthropic',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
baseURL: 'https://api.anthropic.com/v1',
|
||||
apiKey: '',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'anthropic-version': '2023-06-01',
|
||||
Authorization: `Bearer ${oauthToken}`
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function buildOpenAICompatibleConfig(ctx: BuilderContext): ProviderConfig<'openai-compatible'> {
|
||||
const commonOptions = buildCommonOptions(ctx)
|
||||
const includeUsage = isSupportStreamOptionsProvider(ctx.actualProvider)
|
||||
? store.getState().settings.openAI?.streamOptions?.includeUsage
|
||||
: undefined
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: { ...ctx.baseConfig, ...commonOptions, name: ctx.actualProvider.id, includeUsage }
|
||||
}
|
||||
}
|
||||
|
||||
function buildGenericProviderConfig(ctx: BuilderContext): ProviderConfig {
|
||||
const commonOptions = buildCommonOptions(ctx)
|
||||
|
||||
return {
|
||||
providerId: ctx.aiSdkProviderId as StringKeys<AppProviderSettingsMap>,
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: { ...ctx.baseConfig, ...commonOptions }
|
||||
}
|
||||
}
|
||||
|
||||
function buildAiHubMixConfig(ctx: BuilderContext): ProviderConfig<'aihubmix'> {
|
||||
return {
|
||||
providerId: 'aihubmix',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function formatNewApiBaseURL(baseURL: string, endpointType?: string): string {
|
||||
switch (endpointType) {
|
||||
case 'gemini':
|
||||
return formatApiHost(baseURL, true, 'v1beta')
|
||||
case 'anthropic':
|
||||
return formatApiHost(baseURL, false)
|
||||
default:
|
||||
return formatApiHost(baseURL, true)
|
||||
}
|
||||
}
|
||||
|
||||
function buildNewApiConfig(ctx: BuilderContext): ProviderConfig<'newapi'> {
|
||||
const baseURL = formatNewApiBaseURL(ctx.baseConfig.baseURL, ctx.model.endpoint_type)
|
||||
|
||||
return {
|
||||
providerId: 'newapi',
|
||||
endpoint: ctx.endpoint,
|
||||
providerSettings: {
|
||||
...ctx.baseConfig,
|
||||
baseURL,
|
||||
endpointType: ctx.model.endpoint_type,
|
||||
headers: { ...defaultAppHeaders(), ...ctx.actualProvider.extra_headers }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,333 +0,0 @@
|
||||
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
|
||||
|
||||
exports[`listModels > AIHubMix > should convert real AIHubMix response with model_id and model_name 1`] = `
|
||||
[
|
||||
{
|
||||
"description": "Qwen 3.6, the native vision-language Plus series model.",
|
||||
"group": "aihubmix",
|
||||
"id": "qwen3.6-plus",
|
||||
"name": "Qwen3.6 Plus",
|
||||
"provider": "aihubmix",
|
||||
},
|
||||
{
|
||||
"description": "Claude Sonnet 4.6 delivers frontier intelligence at scale.",
|
||||
"group": "aihubmix",
|
||||
"id": "claude-sonnet-4-6",
|
||||
"name": "Claude Sonnet 4.6",
|
||||
"provider": "aihubmix",
|
||||
},
|
||||
{
|
||||
"description": "GPT-5.4 is our frontier model for complex professional work.",
|
||||
"group": "aihubmix",
|
||||
"id": "gpt-5.4",
|
||||
"name": "GPT 5.4",
|
||||
"provider": "aihubmix",
|
||||
},
|
||||
{
|
||||
"description": "A new-generation professional-grade multimodal video-creation model.",
|
||||
"group": "aihubmix",
|
||||
"id": "doubao-seedance-2-0-260128",
|
||||
"name": "Doubao Seedance 2.0 260128",
|
||||
"provider": "aihubmix",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > Gemini > should strip models/ prefix and use displayName from real response 1`] = `
|
||||
[
|
||||
{
|
||||
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||
"group": "gemini",
|
||||
"id": "gemini-2.5-flash",
|
||||
"name": "Gemini 2.5 Flash",
|
||||
"provider": "gemini",
|
||||
},
|
||||
{
|
||||
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||
"group": "gemini",
|
||||
"id": "gemini-2.5-pro",
|
||||
"name": "Gemini 2.5 Pro",
|
||||
"provider": "gemini",
|
||||
},
|
||||
{
|
||||
"description": "Gemini 2.0 Flash",
|
||||
"group": "gemini",
|
||||
"id": "gemini-2.0-flash",
|
||||
"name": "Gemini 2.0 Flash",
|
||||
"provider": "gemini",
|
||||
},
|
||||
{
|
||||
"description": "Stable version of Gemini 2.0 Flash, our fast and versatile multimodal model for scaling across diverse tasks, released in January of 2025.",
|
||||
"group": "gemini",
|
||||
"id": "gemini-2.0-flash-001",
|
||||
"name": "Gemini 2.0 Flash 001",
|
||||
"provider": "gemini",
|
||||
},
|
||||
{
|
||||
"description": "Stable version of Gemini 2.0 Flash-Lite",
|
||||
"group": "gemini",
|
||||
"id": "gemini-2.0-flash-lite-001",
|
||||
"name": "Gemini 2.0 Flash-Lite 001",
|
||||
"provider": "gemini",
|
||||
},
|
||||
{
|
||||
"description": "Gemini 2.0 Flash-Lite",
|
||||
"group": "gemini",
|
||||
"id": "gemini-2.0-flash-lite",
|
||||
"name": "Gemini 2.0 Flash-Lite",
|
||||
"provider": "gemini",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > OpenAI-compatible (DeepSeek) > should convert real DeepSeek response 1`] = `
|
||||
[
|
||||
{
|
||||
"group": "deepseek",
|
||||
"id": "deepseek-chat",
|
||||
"name": "deepseek-chat",
|
||||
"owned_by": "deepseek",
|
||||
"provider": "deepseek",
|
||||
},
|
||||
{
|
||||
"group": "deepseek",
|
||||
"id": "deepseek-reasoner",
|
||||
"name": "deepseek-reasoner",
|
||||
"owned_by": "deepseek",
|
||||
"provider": "deepseek",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > OpenAI-compatible (Groq) > should convert real Groq response with owned_by 1`] = `
|
||||
[
|
||||
{
|
||||
"group": "qwen",
|
||||
"id": "qwen/qwen3-32b",
|
||||
"name": "qwen/qwen3-32b",
|
||||
"owned_by": "Alibaba Cloud",
|
||||
"provider": "groq",
|
||||
},
|
||||
{
|
||||
"group": "groq",
|
||||
"id": "groq/compound-mini",
|
||||
"name": "groq/compound-mini",
|
||||
"owned_by": "Groq",
|
||||
"provider": "groq",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > OpenAI-compatible (SiliconFlow) > should handle nested slash IDs for group extraction 1`] = `
|
||||
[
|
||||
{
|
||||
"group": "Pro",
|
||||
"id": "Pro/MiniMaxAI/MiniMax-M2.5",
|
||||
"name": "Pro/MiniMaxAI/MiniMax-M2.5",
|
||||
"owned_by": "",
|
||||
"provider": "silicon",
|
||||
},
|
||||
{
|
||||
"group": "Pro",
|
||||
"id": "Pro/zai-org/GLM-5",
|
||||
"name": "Pro/zai-org/GLM-5",
|
||||
"owned_by": "",
|
||||
"provider": "silicon",
|
||||
},
|
||||
{
|
||||
"group": "Pro",
|
||||
"id": "Pro/moonshotai/Kimi-K2.5",
|
||||
"name": "Pro/moonshotai/Kimi-K2.5",
|
||||
"owned_by": "",
|
||||
"provider": "silicon",
|
||||
},
|
||||
{
|
||||
"group": "Pro",
|
||||
"id": "Pro/zai-org/GLM-4.7",
|
||||
"name": "Pro/zai-org/GLM-4.7",
|
||||
"owned_by": "",
|
||||
"provider": "silicon",
|
||||
},
|
||||
{
|
||||
"group": "deepseek-ai",
|
||||
"id": "deepseek-ai/DeepSeek-V3.2",
|
||||
"name": "deepseek-ai/DeepSeek-V3.2",
|
||||
"owned_by": "",
|
||||
"provider": "silicon",
|
||||
},
|
||||
{
|
||||
"group": "Pro",
|
||||
"id": "Pro/deepseek-ai/DeepSeek-V3.2",
|
||||
"name": "Pro/deepseek-ai/DeepSeek-V3.2",
|
||||
"owned_by": "",
|
||||
"provider": "silicon",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > OpenRouter > should merge chat and embedding endpoints from real response 1`] = `
|
||||
[
|
||||
{
|
||||
"group": "xiaomi",
|
||||
"id": "xiaomi/mimo-v2-omni",
|
||||
"name": "xiaomi/mimo-v2-omni",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "xiaomi",
|
||||
"id": "xiaomi/mimo-v2-pro",
|
||||
"name": "xiaomi/mimo-v2-pro",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "minimax",
|
||||
"id": "minimax/minimax-m2.7",
|
||||
"name": "minimax/minimax-m2.7",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "openai",
|
||||
"id": "openai/gpt-5.4-nano",
|
||||
"name": "openai/gpt-5.4-nano",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "openai",
|
||||
"id": "openai/gpt-5.4-mini",
|
||||
"name": "openai/gpt-5.4-mini",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "mistralai",
|
||||
"id": "mistralai/mistral-small-2603",
|
||||
"name": "mistralai/mistral-small-2603",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "z-ai",
|
||||
"id": "z-ai/glm-5-turbo",
|
||||
"name": "z-ai/glm-5-turbo",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "x-ai",
|
||||
"id": "x-ai/grok-4.20-multi-agent-beta",
|
||||
"name": "x-ai/grok-4.20-multi-agent-beta",
|
||||
"owned_by": null,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
{
|
||||
"group": "openai",
|
||||
"id": "openai/text-embedding-3-large",
|
||||
"name": "openai/text-embedding-3-large",
|
||||
"owned_by": undefined,
|
||||
"provider": "openrouter",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > PPIO > should merge all three endpoints from real response 1`] = `
|
||||
[
|
||||
{
|
||||
"group": "minimax",
|
||||
"id": "minimax/minimax-m2.7",
|
||||
"name": "minimax/minimax-m2.7",
|
||||
"owned_by": "unknown",
|
||||
"provider": "ppio",
|
||||
},
|
||||
{
|
||||
"group": "minimax",
|
||||
"id": "minimax/minimax-m2.5-highspeed",
|
||||
"name": "minimax/minimax-m2.5-highspeed",
|
||||
"owned_by": "unknown",
|
||||
"provider": "ppio",
|
||||
},
|
||||
{
|
||||
"group": "qwen",
|
||||
"id": "qwen/qwen3.5-27b",
|
||||
"name": "qwen/qwen3.5-27b",
|
||||
"owned_by": "unknown",
|
||||
"provider": "ppio",
|
||||
},
|
||||
{
|
||||
"group": "qwen",
|
||||
"id": "qwen/qwen3.5-122b-a10b",
|
||||
"name": "qwen/qwen3.5-122b-a10b",
|
||||
"owned_by": "unknown",
|
||||
"provider": "ppio",
|
||||
},
|
||||
{
|
||||
"group": "qwen",
|
||||
"id": "qwen/qwen3.5-35b-a3b",
|
||||
"name": "qwen/qwen3.5-35b-a3b",
|
||||
"owned_by": "unknown",
|
||||
"provider": "ppio",
|
||||
},
|
||||
{
|
||||
"group": "BAAI",
|
||||
"id": "BAAI/bge-m3",
|
||||
"name": "BAAI/bge-m3",
|
||||
"owned_by": "BAAI",
|
||||
"provider": "ppio",
|
||||
},
|
||||
{
|
||||
"group": "BAAI",
|
||||
"id": "BAAI/bge-reranker-v2-m3",
|
||||
"name": "BAAI/bge-reranker-v2-m3",
|
||||
"owned_by": "BAAI",
|
||||
"provider": "ppio",
|
||||
},
|
||||
]
|
||||
`;
|
||||
|
||||
exports[`listModels > Together > should use display_name and organization from real response 1`] = `
|
||||
[
|
||||
{
|
||||
"description": null,
|
||||
"group": "hexgrad",
|
||||
"id": "hexgrad/Kokoro-82M",
|
||||
"name": "Kokoro 82M",
|
||||
"owned_by": "Hexgrad",
|
||||
"provider": "together",
|
||||
},
|
||||
{
|
||||
"description": null,
|
||||
"group": "cartesia",
|
||||
"id": "cartesia/sonic",
|
||||
"name": "Cartesia Sonic",
|
||||
"owned_by": "Cartesia",
|
||||
"provider": "together",
|
||||
},
|
||||
{
|
||||
"description": null,
|
||||
"group": "black-forest-labs",
|
||||
"id": "black-forest-labs/FLUX.1-krea-dev",
|
||||
"name": "FLUX.1 Krea [dev]",
|
||||
"owned_by": "Black Forest Labs",
|
||||
"provider": "together",
|
||||
},
|
||||
{
|
||||
"description": null,
|
||||
"group": "google",
|
||||
"id": "google/imagen-4.0-preview",
|
||||
"name": "Google Imagen 4.0 Preview",
|
||||
"owned_by": "Google",
|
||||
"provider": "together",
|
||||
},
|
||||
{
|
||||
"description": null,
|
||||
"group": "cartesia",
|
||||
"id": "cartesia/sonic-2",
|
||||
"name": "Cartesia Sonic 2",
|
||||
"owned_by": "Cartesia",
|
||||
"provider": "together",
|
||||
},
|
||||
]
|
||||
`;
|
||||
@@ -1,411 +0,0 @@
|
||||
/**
|
||||
* ModelListService conversion tests
|
||||
* Uses real API responses captured from providers to verify model conversion
|
||||
*/
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const mockGetFromApi = vi.fn()
|
||||
vi.mock('@ai-sdk/provider-utils', () => ({
|
||||
createJsonResponseHandler: vi.fn(() => 'json-handler'),
|
||||
createJsonErrorResponseHandler: vi.fn(() => 'error-handler'),
|
||||
getFromApi: (...args: unknown[]) => mockGetFromApi(...args),
|
||||
zodSchema: vi.fn((s: unknown) => s)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils', () => ({
|
||||
formatApiHost: (host: string) => host?.replace(/\/$/, ''),
|
||||
withoutTrailingSlash: (s: string) => s?.replace(/\/$/, ''),
|
||||
getLowerBaseModelName: (id: string) => id.toLowerCase()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', () => ({
|
||||
isAIGatewayProvider: (p: Provider) => p.id === 'gateway',
|
||||
isGeminiProvider: (p: Provider) => p.id === 'gemini' || p.type === 'gemini',
|
||||
isOllamaProvider: (p: Provider) => p.id === 'ollama' || p.type === 'ollama'
|
||||
}))
|
||||
|
||||
vi.mock('@shared/utils', () => ({
|
||||
defaultAppHeaders: () => ({ 'X-App': 'CherryStudio' })
|
||||
}))
|
||||
|
||||
const { listModels } = await import('../listModels')
|
||||
|
||||
// === Real API response fixtures (captured 2026-03-19) ===
|
||||
|
||||
// From https://openrouter.ai/api/v1/models (public, no auth)
|
||||
const REAL_OPENROUTER = {
|
||||
data: [
|
||||
{ id: 'xiaomi/mimo-v2-omni', object: 'model', created: 1773863703, owned_by: null },
|
||||
{ id: 'xiaomi/mimo-v2-pro', object: 'model', created: 1773863643, owned_by: null },
|
||||
{ id: 'minimax/minimax-m2.7', object: 'model', created: 1773836697, owned_by: null },
|
||||
{ id: 'openai/gpt-5.4-nano', object: 'model', created: 1773748187, owned_by: null },
|
||||
{ id: 'openai/gpt-5.4-mini', object: 'model', created: 1773748178, owned_by: null },
|
||||
{ id: 'mistralai/mistral-small-2603', object: 'model', created: 1773695685, owned_by: null },
|
||||
{ id: 'z-ai/glm-5-turbo', object: 'model', created: 1773583573, owned_by: null },
|
||||
{ id: 'x-ai/grok-4.20-multi-agent-beta', object: 'model', created: 1773325367, owned_by: null }
|
||||
]
|
||||
}
|
||||
|
||||
// From https://api.deepseek.com/v1/models
|
||||
const REAL_DEEPSEEK = {
|
||||
object: 'list',
|
||||
data: [
|
||||
{ id: 'deepseek-chat', object: 'model', owned_by: 'deepseek' },
|
||||
{ id: 'deepseek-reasoner', object: 'model', owned_by: 'deepseek' }
|
||||
]
|
||||
}
|
||||
|
||||
// From https://generativelanguage.googleapis.com/v1beta/models
|
||||
const REAL_GEMINI = {
|
||||
models: [
|
||||
{
|
||||
name: 'models/gemini-2.5-flash',
|
||||
displayName: 'Gemini 2.5 Flash',
|
||||
description:
|
||||
'Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.'
|
||||
},
|
||||
{
|
||||
name: 'models/gemini-2.5-pro',
|
||||
displayName: 'Gemini 2.5 Pro',
|
||||
description: 'Stable release (June 17th, 2025) of Gemini 2.5 Pro'
|
||||
},
|
||||
{
|
||||
name: 'models/gemini-2.0-flash',
|
||||
displayName: 'Gemini 2.0 Flash',
|
||||
description: 'Gemini 2.0 Flash'
|
||||
},
|
||||
{
|
||||
name: 'models/gemini-2.0-flash-001',
|
||||
displayName: 'Gemini 2.0 Flash 001',
|
||||
description:
|
||||
'Stable version of Gemini 2.0 Flash, our fast and versatile multimodal model for scaling across diverse tasks, released in January of 2025.'
|
||||
},
|
||||
{
|
||||
name: 'models/gemini-2.0-flash-lite-001',
|
||||
displayName: 'Gemini 2.0 Flash-Lite 001',
|
||||
description: 'Stable version of Gemini 2.0 Flash-Lite'
|
||||
},
|
||||
{
|
||||
name: 'models/gemini-2.0-flash-lite',
|
||||
displayName: 'Gemini 2.0 Flash-Lite',
|
||||
description: 'Gemini 2.0 Flash-Lite'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
// From https://api.together.xyz/v1/models
|
||||
const REAL_TOGETHER = [
|
||||
{ id: 'hexgrad/Kokoro-82M', display_name: 'Kokoro 82M', organization: 'Hexgrad', description: null },
|
||||
{ id: 'cartesia/sonic', display_name: 'Cartesia Sonic', organization: 'Cartesia', description: null },
|
||||
{
|
||||
id: 'black-forest-labs/FLUX.1-krea-dev',
|
||||
display_name: 'FLUX.1 Krea [dev]',
|
||||
organization: 'Black Forest Labs',
|
||||
description: null
|
||||
},
|
||||
{
|
||||
id: 'google/imagen-4.0-preview',
|
||||
display_name: 'Google Imagen 4.0 Preview',
|
||||
organization: 'Google',
|
||||
description: null
|
||||
},
|
||||
{ id: 'cartesia/sonic-2', display_name: 'Cartesia Sonic 2', organization: 'Cartesia', description: null }
|
||||
]
|
||||
|
||||
// From https://api.siliconflow.cn/v1/models (OpenAI-compatible)
|
||||
const REAL_SILICONFLOW = {
|
||||
data: [
|
||||
{ id: 'Pro/MiniMaxAI/MiniMax-M2.5', object: 'model', owned_by: '' },
|
||||
{ id: 'Pro/zai-org/GLM-5', object: 'model', owned_by: '' },
|
||||
{ id: 'Pro/moonshotai/Kimi-K2.5', object: 'model', owned_by: '' },
|
||||
{ id: 'Pro/zai-org/GLM-4.7', object: 'model', owned_by: '' },
|
||||
{ id: 'deepseek-ai/DeepSeek-V3.2', object: 'model', owned_by: '' },
|
||||
{ id: 'Pro/deepseek-ai/DeepSeek-V3.2', object: 'model', owned_by: '' }
|
||||
]
|
||||
}
|
||||
|
||||
// From https://api.groq.com/openai/v1/models
|
||||
const REAL_GROQ = {
|
||||
data: [
|
||||
{ id: 'qwen/qwen3-32b', object: 'model', created: 1748396646, owned_by: 'Alibaba Cloud' },
|
||||
{ id: 'groq/compound-mini', object: 'model', created: 1756949707, owned_by: 'Groq' }
|
||||
]
|
||||
}
|
||||
|
||||
// From https://api.ppinfra.com/v3/openai/models
|
||||
const REAL_PPIO_CHAT = {
|
||||
data: [
|
||||
{ id: 'minimax/minimax-m2.7', object: 'model', owned_by: 'unknown' },
|
||||
{ id: 'minimax/minimax-m2.5-highspeed', object: 'model', owned_by: 'unknown' },
|
||||
{ id: 'qwen/qwen3.5-27b', object: 'model', owned_by: 'unknown' },
|
||||
{ id: 'qwen/qwen3.5-122b-a10b', object: 'model', owned_by: 'unknown' },
|
||||
{ id: 'qwen/qwen3.5-35b-a3b', object: 'model', owned_by: 'unknown' }
|
||||
]
|
||||
}
|
||||
|
||||
// From https://aihubmix.com/api/v1/models (custom schema with model_id/model_name)
|
||||
const REAL_AIHUBMIX = {
|
||||
data: [
|
||||
{
|
||||
model_id: 'qwen3.6-plus',
|
||||
model_name: 'Qwen3.6 Plus',
|
||||
developer_id: 13,
|
||||
desc: 'Qwen 3.6, the native vision-language Plus series model.',
|
||||
pricing: { cache_read: 0.0282, cache_write: 0.3525, input: 0.282, output: 1.692 },
|
||||
types: 'llm',
|
||||
features: 'tools,function_calling,structured_outputs,web,long_context,thinking',
|
||||
input_modalities: 'text,image,video',
|
||||
endpoints: '',
|
||||
max_output: 64000,
|
||||
context_length: 991000
|
||||
},
|
||||
{
|
||||
model_id: 'claude-sonnet-4-6',
|
||||
model_name: 'Claude Sonnet 4.6',
|
||||
developer_id: 2,
|
||||
desc: 'Claude Sonnet 4.6 delivers frontier intelligence at scale.',
|
||||
pricing: { cache_read: 0.3, cache_write: 3.75, input: 3, output: 15 },
|
||||
types: 'llm',
|
||||
features: 'thinking,tools,function_calling,structured_outputs',
|
||||
input_modalities: 'text,image',
|
||||
endpoints: 'chat_completions,gemini_api,claude_api',
|
||||
max_output: 64000,
|
||||
context_length: 1000000
|
||||
},
|
||||
{
|
||||
model_id: 'gpt-5.4',
|
||||
model_name: 'GPT 5.4',
|
||||
developer_id: 12,
|
||||
desc: 'GPT-5.4 is our frontier model for complex professional work.',
|
||||
pricing: { cache_read: 0.25, input: 2.5, output: 15 },
|
||||
types: 'llm',
|
||||
features: 'thinking,function_calling,structured_outputs,web,tools',
|
||||
input_modalities: 'text,image',
|
||||
endpoints: '',
|
||||
max_output: 128000,
|
||||
context_length: 400000
|
||||
},
|
||||
{
|
||||
model_id: 'doubao-seedance-2-0-260128',
|
||||
model_name: 'Doubao Seedance 2.0 260128',
|
||||
developer_id: 4,
|
||||
desc: 'A new-generation professional-grade multimodal video-creation model.',
|
||||
pricing: { input: 2, output: 0 },
|
||||
types: 'video',
|
||||
features: '',
|
||||
input_modalities: 'image,text',
|
||||
endpoints: '',
|
||||
max_output: 0,
|
||||
context_length: 0
|
||||
}
|
||||
],
|
||||
message: '',
|
||||
success: true
|
||||
}
|
||||
|
||||
// === Helpers ===
|
||||
|
||||
function makeProvider(overrides: Partial<Provider> & { id: string }): Provider {
|
||||
return {
|
||||
name: overrides.id,
|
||||
type: 'openai',
|
||||
apiKey: 'sk-test',
|
||||
apiHost: 'https://api.example.com/v1',
|
||||
models: [],
|
||||
isSystem: true,
|
||||
enabled: true,
|
||||
...overrides
|
||||
} as Provider
|
||||
}
|
||||
|
||||
function assertValidModels(models: { id: string; name: string; provider: string; group: string }[]) {
|
||||
expect(models.length).toBeGreaterThan(0)
|
||||
for (const m of models) {
|
||||
expect(m.id).toBeTruthy()
|
||||
expect(typeof m.id).toBe('string')
|
||||
expect(m.id).toBe(m.id.trim())
|
||||
expect(m.name).toBeTruthy()
|
||||
expect(typeof m.provider).toBe('string')
|
||||
expect(typeof m.group).toBe('string')
|
||||
}
|
||||
}
|
||||
|
||||
// === Tests ===
|
||||
|
||||
beforeEach(() => {
|
||||
mockGetFromApi.mockReset()
|
||||
vi.stubGlobal('window', { ...globalThis.window, keyv: { get: vi.fn(), set: vi.fn() } })
|
||||
})
|
||||
|
||||
describe('listModels', () => {
|
||||
describe('OpenAI-compatible (DeepSeek)', () => {
|
||||
it('should convert real DeepSeek response', async () => {
|
||||
mockGetFromApi.mockResolvedValue({ value: REAL_DEEPSEEK })
|
||||
const models = await listModels(makeProvider({ id: 'deepseek' }))
|
||||
assertValidModels(models)
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI-compatible (SiliconFlow)', () => {
|
||||
it('should handle nested slash IDs for group extraction', async () => {
|
||||
mockGetFromApi.mockResolvedValue({ value: REAL_SILICONFLOW })
|
||||
const models = await listModels(makeProvider({ id: 'silicon' }))
|
||||
assertValidModels(models)
|
||||
// "Pro/MiniMaxAI/MiniMax-M2.5" -> group "Pro"
|
||||
expect(models[0].group).toBe('Pro')
|
||||
// "deepseek-ai/DeepSeek-V3.2" -> group "deepseek-ai"
|
||||
expect(models[4].group).toBe('deepseek-ai')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI-compatible (Groq)', () => {
|
||||
it('should convert real Groq response with owned_by', async () => {
|
||||
mockGetFromApi.mockResolvedValue({ value: REAL_GROQ })
|
||||
const models = await listModels(makeProvider({ id: 'groq' }))
|
||||
assertValidModels(models)
|
||||
expect(models[0].owned_by).toBe('Alibaba Cloud')
|
||||
expect(models[1].owned_by).toBe('Groq')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Gemini', () => {
|
||||
it('should strip models/ prefix and use displayName from real response', async () => {
|
||||
mockGetFromApi.mockResolvedValue({ value: REAL_GEMINI })
|
||||
const models = await listModels(
|
||||
makeProvider({ id: 'gemini', type: 'gemini', apiHost: 'https://generativelanguage.googleapis.com/v1beta' })
|
||||
)
|
||||
assertValidModels(models)
|
||||
for (const m of models) {
|
||||
expect(m.id).not.toMatch(/^models\//)
|
||||
}
|
||||
// displayName should be used as name
|
||||
expect(models[0].name).toBe('Gemini 2.5 Flash')
|
||||
expect(models[0].id).toBe('gemini-2.5-flash')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Together', () => {
|
||||
it('should use display_name and organization from real response', async () => {
|
||||
mockGetFromApi.mockResolvedValue({ value: REAL_TOGETHER })
|
||||
const models = await listModels(makeProvider({ id: 'together' }))
|
||||
assertValidModels(models)
|
||||
expect(models[0].name).toBe('Kokoro 82M')
|
||||
expect(models[0].owned_by).toBe('Hexgrad')
|
||||
expect(models[0].group).toBe('hexgrad')
|
||||
// FLUX model with org "Black Forest Labs"
|
||||
expect(models[2].name).toBe('FLUX.1 Krea [dev]')
|
||||
expect(models[2].owned_by).toBe('Black Forest Labs')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenRouter', () => {
|
||||
it('should merge chat and embedding endpoints from real response', async () => {
|
||||
mockGetFromApi
|
||||
.mockResolvedValueOnce({ value: REAL_OPENROUTER })
|
||||
.mockResolvedValueOnce({ value: { data: [{ id: 'openai/text-embedding-3-large', object: 'model' }] } })
|
||||
const models = await listModels(makeProvider({ id: 'openrouter' }))
|
||||
assertValidModels(models)
|
||||
expect(models).toHaveLength(REAL_OPENROUTER.data.length + 1)
|
||||
// Slash IDs should produce correct group
|
||||
expect(models.find((m) => m.id === 'xiaomi/mimo-v2-omni')?.group).toBe('xiaomi')
|
||||
expect(models.find((m) => m.id === 'openai/gpt-5.4-nano')?.group).toBe('openai')
|
||||
expect(models.find((m) => m.id === 'x-ai/grok-4.20-multi-agent-beta')?.group).toBe('x-ai')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should deduplicate across endpoints', async () => {
|
||||
mockGetFromApi
|
||||
.mockResolvedValueOnce({ value: { data: [REAL_OPENROUTER.data[0]] } })
|
||||
.mockResolvedValueOnce({ value: { data: [REAL_OPENROUTER.data[0]] } })
|
||||
const models = await listModels(makeProvider({ id: 'openrouter' }))
|
||||
expect(models).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should handle embedding endpoint failure', async () => {
|
||||
mockGetFromApi.mockResolvedValueOnce({ value: REAL_OPENROUTER }).mockRejectedValueOnce(new Error('404 Not Found'))
|
||||
const models = await listModels(makeProvider({ id: 'openrouter' }))
|
||||
expect(models).toHaveLength(REAL_OPENROUTER.data.length)
|
||||
})
|
||||
})
|
||||
|
||||
describe('PPIO', () => {
|
||||
it('should merge all three endpoints from real response', async () => {
|
||||
mockGetFromApi
|
||||
.mockResolvedValueOnce({ value: REAL_PPIO_CHAT })
|
||||
.mockResolvedValueOnce({ value: { data: [{ id: 'BAAI/bge-m3', object: 'model', owned_by: 'BAAI' }] } })
|
||||
.mockResolvedValueOnce({
|
||||
value: { data: [{ id: 'BAAI/bge-reranker-v2-m3', object: 'model', owned_by: 'BAAI' }] }
|
||||
})
|
||||
const models = await listModels(makeProvider({ id: 'ppio' }))
|
||||
assertValidModels(models)
|
||||
expect(models).toHaveLength(7)
|
||||
expect(models.find((m) => m.id === 'BAAI/bge-m3')?.group).toBe('BAAI')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should handle partial endpoint failures', async () => {
|
||||
mockGetFromApi
|
||||
.mockResolvedValueOnce({ value: REAL_PPIO_CHAT })
|
||||
.mockRejectedValueOnce(new Error('timeout'))
|
||||
.mockRejectedValueOnce(new Error('timeout'))
|
||||
const models = await listModels(makeProvider({ id: 'ppio' }))
|
||||
expect(models).toHaveLength(REAL_PPIO_CHAT.data.length)
|
||||
})
|
||||
})
|
||||
|
||||
describe('AIHubMix', () => {
|
||||
it('should convert real AIHubMix response with model_id and model_name', async () => {
|
||||
mockGetFromApi.mockResolvedValue({ value: REAL_AIHUBMIX })
|
||||
const models = await listModels(makeProvider({ id: 'aihubmix' }))
|
||||
assertValidModels(models)
|
||||
expect(models).toHaveLength(4)
|
||||
// model_name should be used as name
|
||||
expect(models[0].name).toBe('Qwen3.6 Plus')
|
||||
expect(models[0].id).toBe('qwen3.6-plus')
|
||||
expect(models[0].description).toBe('Qwen 3.6, the native vision-language Plus series model.')
|
||||
// No slash in ID -> group falls back to provider id
|
||||
expect(models[0].group).toBe('aihubmix')
|
||||
expect(models[1].name).toBe('Claude Sonnet 4.6')
|
||||
expect(models[2].name).toBe('GPT 5.4')
|
||||
expect(models[3].name).toBe('Doubao Seedance 2.0 260128')
|
||||
expect(models).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should deduplicate by model_id', async () => {
|
||||
const duped = {
|
||||
...REAL_AIHUBMIX,
|
||||
data: [REAL_AIHUBMIX.data[0], REAL_AIHUBMIX.data[0], REAL_AIHUBMIX.data[1]]
|
||||
}
|
||||
mockGetFromApi.mockResolvedValue({ value: duped })
|
||||
const models = await listModels(makeProvider({ id: 'aihubmix' }))
|
||||
expect(models).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Unsupported providers', () => {
|
||||
it.each([
|
||||
['gateway', { id: 'gateway' }],
|
||||
['aws-bedrock', { id: 'aws-bedrock' }],
|
||||
['anthropic', { id: 'anthropic' }],
|
||||
['vertex-anthropic', { id: 'vertex-anthro', type: 'vertex-anthropic' as any }]
|
||||
])('should return empty for %s', async (_, overrides) => {
|
||||
const models = await listModels(makeProvider(overrides as any))
|
||||
expect(models).toEqual([])
|
||||
expect(mockGetFromApi).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should return empty on network error', async () => {
|
||||
mockGetFromApi.mockRejectedValue(new Error('ECONNREFUSED'))
|
||||
const models = await listModels(makeProvider({ id: 'openai' }))
|
||||
expect(models).toEqual([])
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1 +0,0 @@
|
||||
export { listModels } from './listModels'
|
||||
@@ -1,393 +0,0 @@
|
||||
/**
|
||||
* ModelListService - Unified model listing service
|
||||
* Uses Strategy Registry pattern for provider-specific model fetching
|
||||
*/
|
||||
|
||||
import {
|
||||
createJsonErrorResponseHandler,
|
||||
createJsonResponseHandler,
|
||||
getFromApi as aiSdkGetFromApi,
|
||||
zodSchema
|
||||
} from '@ai-sdk/provider-utils'
|
||||
import { cacheService } from '@data/CacheService'
|
||||
import { loggerService } from '@logger'
|
||||
import type { EndpointType, Model, Provider } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
import { formatApiHost, withoutTrailingSlash } from '@renderer/utils'
|
||||
import { isAIGatewayProvider, isGeminiProvider, isOllamaProvider } from '@renderer/utils/provider'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import * as z from 'zod'
|
||||
|
||||
import {
|
||||
AIHubMixModelsResponseSchema,
|
||||
GeminiModelsResponseSchema,
|
||||
GitHubModelsResponseSchema,
|
||||
NewApiModelsResponseSchema,
|
||||
OllamaTagsResponseSchema,
|
||||
OpenAIModelsResponseSchema,
|
||||
OVMSConfigResponseSchema,
|
||||
TogetherModelsResponseSchema
|
||||
} from './schemas'
|
||||
|
||||
const logger = loggerService.withContext('ModelListService')
|
||||
|
||||
// === Types ===
|
||||
|
||||
type ModelFetcher = {
|
||||
match: (provider: Provider) => boolean
|
||||
fetch: (provider: Provider, signal?: AbortSignal) => Promise<Model[]>
|
||||
}
|
||||
|
||||
// === API Layer ===
|
||||
|
||||
const ApiErrorSchema = z.object({
|
||||
error: z
|
||||
.object({
|
||||
message: z.string().optional(),
|
||||
code: z.string().optional()
|
||||
})
|
||||
.optional(),
|
||||
message: z.string().optional()
|
||||
})
|
||||
|
||||
type ApiError = z.infer<typeof ApiErrorSchema>
|
||||
|
||||
async function getFromApi<T>({
|
||||
url,
|
||||
headers,
|
||||
responseSchema,
|
||||
abortSignal
|
||||
}: {
|
||||
url: string
|
||||
headers?: Record<string, string>
|
||||
responseSchema: z.ZodType<T>
|
||||
abortSignal?: AbortSignal
|
||||
}): Promise<T> {
|
||||
const { value } = await aiSdkGetFromApi({
|
||||
url,
|
||||
headers,
|
||||
successfulResponseHandler: createJsonResponseHandler(zodSchema(responseSchema)),
|
||||
failedResponseHandler: createJsonErrorResponseHandler({
|
||||
errorSchema: zodSchema(ApiErrorSchema),
|
||||
errorToMessage: (error: ApiError) => error.error?.message || error.message || 'Unknown error'
|
||||
}),
|
||||
abortSignal
|
||||
})
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// === Helpers ===
|
||||
|
||||
function getApiKey(provider: Provider): string {
|
||||
const keys = provider.apiKey.split(',').map((key) => key.trim())
|
||||
const keyName = `provider:${provider.id}:last_used_key`
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const lastUsedKey = cacheService.getCasual<string>(keyName)
|
||||
if (!lastUsedKey) {
|
||||
cacheService.setCasual(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
cacheService.setCasual(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
|
||||
function defaultHeaders(provider: Provider): Record<string, string> {
|
||||
const apiKey = getApiKey(provider)
|
||||
return {
|
||||
...defaultAppHeaders(),
|
||||
...(apiKey ? { Authorization: `Bearer ${apiKey}`, 'X-Api-Key': apiKey } : {}),
|
||||
...provider.extra_headers
|
||||
}
|
||||
}
|
||||
|
||||
function defaultGroup(modelId: string, providerId: string): string {
|
||||
const parts = modelId.split('/')
|
||||
return parts.length > 1 ? parts[0] : providerId
|
||||
}
|
||||
|
||||
function toModel(id: string, provider: Provider, extra?: Partial<Model>): Model {
|
||||
return {
|
||||
id,
|
||||
name: extra?.name || id,
|
||||
provider: provider.id,
|
||||
group: extra?.group || defaultGroup(id, provider.id),
|
||||
...extra
|
||||
}
|
||||
}
|
||||
|
||||
function dedup<T>(items: T[], getId: (item: T) => string | undefined): T[] {
|
||||
const seen = new Set<string>()
|
||||
return items.filter((item) => {
|
||||
const id = getId(item)?.trim()
|
||||
if (!id || seen.has(id)) return false
|
||||
seen.add(id)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
function pickPreferredString(values: Array<unknown>): string | undefined {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'string') {
|
||||
const trimmed = value.trim()
|
||||
if (trimmed.length > 0) {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// === Fetchers ===
|
||||
|
||||
const ollamaFetcher: ModelFetcher = {
|
||||
match: (p) => isOllamaProvider(p),
|
||||
fetch: async (provider, signal) => {
|
||||
const baseUrl = withoutTrailingSlash(provider.apiHost)
|
||||
.replace(/\/v1$/, '')
|
||||
.replace(/\/api$/, '')
|
||||
const response = await getFromApi({
|
||||
url: `${baseUrl}/api/tags`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OllamaTagsResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
return dedup(response.models, (m) => m.name).map((m) => toModel(m.name, provider, { owned_by: 'ollama' }))
|
||||
}
|
||||
}
|
||||
|
||||
const geminiFetcher: ModelFetcher = {
|
||||
match: (p) => isGeminiProvider(p),
|
||||
fetch: async (provider, signal) => {
|
||||
let baseUrl = withoutTrailingSlash(provider.apiHost)
|
||||
baseUrl = baseUrl.replace(/\/v1(beta)?$/, '')
|
||||
const response = await getFromApi({
|
||||
url: `${baseUrl}/v1beta/models?key=${getApiKey(provider)}`,
|
||||
headers: { ...defaultAppHeaders(), ...provider.extra_headers },
|
||||
responseSchema: GeminiModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
return dedup(response.models, (m) => m.name).map((m) => {
|
||||
const id = m.name.startsWith('models/') ? m.name.slice(7) : m.name
|
||||
return toModel(id, provider, { name: m.displayName || id, description: m.description })
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const githubFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds.github,
|
||||
fetch: async (provider, signal) => {
|
||||
const [catalogResponse, v1Response] = await Promise.all([
|
||||
getFromApi({
|
||||
url: 'https://models.github.ai/catalog/models',
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: GitHubModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}),
|
||||
getFromApi({
|
||||
url: 'https://models.github.ai/v1/models',
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}).catch(() => ({ data: [] as { id: string; owned_by?: string }[] }))
|
||||
])
|
||||
const registryModels = catalogResponse.map((m) =>
|
||||
toModel(m.id, provider, {
|
||||
name: m.name || m.id,
|
||||
description: pickPreferredString([m.summary, m.description]),
|
||||
owned_by: m.publisher
|
||||
})
|
||||
)
|
||||
const v1Models = v1Response.data.map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
|
||||
return dedup([...registryModels, ...v1Models], (m) => m.id)
|
||||
}
|
||||
}
|
||||
|
||||
const ovmsFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds.ovms,
|
||||
fetch: async (provider, signal) => {
|
||||
const baseUrl = formatApiHost(withoutTrailingSlash(provider.apiHost).replace(/\/v1$/, ''), true, 'v1')
|
||||
const response = await getFromApi({
|
||||
url: `${baseUrl}/config`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OVMSConfigResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
const entries = Object.entries(response).filter(([, info]) =>
|
||||
info?.model_version_status?.some((v) => v?.state === 'AVAILABLE')
|
||||
)
|
||||
return dedup(entries, ([name]) => name).map(([name]) => toModel(name, provider, { owned_by: 'ovms' }))
|
||||
}
|
||||
}
|
||||
|
||||
const togetherFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds.together,
|
||||
fetch: async (provider, signal) => {
|
||||
const baseUrl = formatApiHost(provider.apiHost)
|
||||
const response = await getFromApi({
|
||||
url: `${baseUrl}/models`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: TogetherModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
return dedup(response, (m) => m.id).map((m) =>
|
||||
toModel(m.id, provider, {
|
||||
name: m.display_name || m.id,
|
||||
description: m.description,
|
||||
owned_by: m.organization
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const newApiFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds['new-api'] || p.type === 'new-api' || p.id === SystemProviderIds.cherryin,
|
||||
fetch: async (provider, signal) => {
|
||||
const baseUrl = formatApiHost(provider.apiHost)
|
||||
const response = await getFromApi({
|
||||
url: `${baseUrl}/models`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: NewApiModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
return dedup(response.data, (m) => m.id).map((m) =>
|
||||
toModel(m.id, provider, {
|
||||
owned_by: m.owned_by,
|
||||
supported_endpoint_types: m.supported_endpoint_types as EndpointType[] | undefined
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const openRouterFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds.openrouter,
|
||||
fetch: async (provider, signal) => {
|
||||
const [modelsResponse, embedModelsResponse] = await Promise.all([
|
||||
getFromApi({
|
||||
url: 'https://openrouter.ai/api/v1/models',
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}),
|
||||
getFromApi({
|
||||
url: 'https://openrouter.ai/api/v1/embeddings/models',
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}).catch(() => ({ data: [] }))
|
||||
])
|
||||
const all = [...modelsResponse.data, ...embedModelsResponse.data]
|
||||
return dedup(all, (m) => m.id).map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
|
||||
}
|
||||
}
|
||||
|
||||
const ppioFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds.ppio,
|
||||
fetch: async (provider, signal) => {
|
||||
const baseUrl = formatApiHost(provider.apiHost)
|
||||
const [chat, embed, reranker] = await Promise.all([
|
||||
getFromApi({
|
||||
url: `${baseUrl}/models`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}),
|
||||
getFromApi({
|
||||
url: `${baseUrl}/models?model_type=embedding`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}).catch(() => ({ data: [] })),
|
||||
getFromApi({
|
||||
url: `${baseUrl}/models?model_type=reranker`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
}).catch(() => ({ data: [] }))
|
||||
])
|
||||
const all = [...chat.data, ...embed.data, ...reranker.data]
|
||||
return dedup(all, (m) => m.id).map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
|
||||
}
|
||||
}
|
||||
|
||||
const aiHubMixFetcher: ModelFetcher = {
|
||||
match: (p) => p.id === SystemProviderIds.aihubmix,
|
||||
fetch: async (provider, signal) => {
|
||||
const response = await getFromApi({
|
||||
url: `https://aihubmix.com/api/v1/models`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: AIHubMixModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
return dedup(response.data, (m) => m.model_id).map((m) =>
|
||||
toModel(m.model_id, provider, {
|
||||
name: m.model_name || m.model_id,
|
||||
description: m.desc
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/** Default fallback: OpenAI-compatible /models endpoint */
|
||||
const openAICompatibleFetcher: ModelFetcher = {
|
||||
match: () => true,
|
||||
fetch: async (provider, signal) => {
|
||||
const baseUrl = formatApiHost(provider.apiHost)
|
||||
const response = await getFromApi({
|
||||
url: `${baseUrl}/models`,
|
||||
headers: defaultHeaders(provider),
|
||||
responseSchema: OpenAIModelsResponseSchema,
|
||||
abortSignal: signal
|
||||
})
|
||||
return dedup(response.data, (m) => m.id).map((m) => toModel(m.id, provider, { owned_by: m.owned_by }))
|
||||
}
|
||||
}
|
||||
|
||||
// === Registry (order matters: first match wins) ===
|
||||
|
||||
const fetchers: ModelFetcher[] = [
|
||||
aiHubMixFetcher,
|
||||
ollamaFetcher,
|
||||
geminiFetcher,
|
||||
githubFetcher,
|
||||
ovmsFetcher,
|
||||
togetherFetcher,
|
||||
newApiFetcher,
|
||||
openRouterFetcher,
|
||||
ppioFetcher,
|
||||
openAICompatibleFetcher // always-match fallback, must be last
|
||||
]
|
||||
|
||||
// === Unsupported providers (skip before registry lookup) ===
|
||||
|
||||
const UNSUPPORTED_PROVIDERS = new Set<string>([SystemProviderIds['aws-bedrock'], SystemProviderIds.anthropic])
|
||||
|
||||
function isUnsupported(provider: Provider): boolean {
|
||||
return isAIGatewayProvider(provider) || UNSUPPORTED_PROVIDERS.has(provider.id) || provider.type === 'vertex-anthropic'
|
||||
}
|
||||
|
||||
// === Public API ===
|
||||
|
||||
export async function listModels(provider: Provider, abortSignal?: AbortSignal): Promise<Model[]> {
|
||||
try {
|
||||
if (isUnsupported(provider)) {
|
||||
logger.warn('Provider does not support model listing via listModels', { providerId: provider.id })
|
||||
return []
|
||||
}
|
||||
|
||||
const fetcher = fetchers.find((f) => f.match(provider))!
|
||||
return await fetcher.fetch(provider, abortSignal)
|
||||
} catch (error) {
|
||||
logger.error('Error listing models:', error as Error, { providerId: provider.id })
|
||||
return []
|
||||
}
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
/**
|
||||
* API Response Schemas for model listing
|
||||
* Used exclusively by listModels.ts
|
||||
*
|
||||
* All object schemas use z.looseObject() to tolerate unknown fields
|
||||
* from providers — prevents parse failures when APIs add new fields.
|
||||
*/
|
||||
import * as z from 'zod'
|
||||
|
||||
// === OpenAI-compatible (also used by OpenRouter, PPIO, etc.) ===
|
||||
|
||||
export const OpenAIModelsResponseSchema = z.object({
|
||||
data: z.array(
|
||||
z.looseObject({
|
||||
id: z.string(),
|
||||
object: z.string().optional().default('model'),
|
||||
created: z.number().optional(),
|
||||
owned_by: z.string().optional()
|
||||
})
|
||||
),
|
||||
object: z.string().optional()
|
||||
})
|
||||
|
||||
// === Ollama ===
|
||||
|
||||
export const OllamaTagsResponseSchema = z.object({
|
||||
models: z.array(
|
||||
z.looseObject({
|
||||
name: z.string(),
|
||||
model: z.string().optional(),
|
||||
modified_at: z.string().optional(),
|
||||
size: z.number().optional(),
|
||||
digest: z.string().optional(),
|
||||
details: z
|
||||
.looseObject({
|
||||
parent_model: z.string().optional(),
|
||||
format: z.string().optional(),
|
||||
family: z.string().optional(),
|
||||
families: z.array(z.string()).optional(),
|
||||
parameter_size: z.string().optional(),
|
||||
quantization_level: z.string().optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
// === Gemini ===
|
||||
|
||||
export const GeminiModelsResponseSchema = z.object({
|
||||
models: z.array(
|
||||
z.looseObject({
|
||||
name: z.string(),
|
||||
displayName: z.string().optional(),
|
||||
description: z.string().optional(),
|
||||
version: z.string().optional(),
|
||||
baseModelId: z.string().optional(),
|
||||
inputTokenLimit: z.number().optional(),
|
||||
outputTokenLimit: z.number().optional(),
|
||||
supportedGenerationMethods: z.array(z.string()).optional()
|
||||
})
|
||||
),
|
||||
nextPageToken: z.string().optional()
|
||||
})
|
||||
|
||||
// === GitHub Models ===
|
||||
|
||||
export const GitHubModelsResponseSchema = z.array(
|
||||
z.looseObject({
|
||||
id: z.string(),
|
||||
summary: z.string().optional(),
|
||||
publisher: z.string().optional(),
|
||||
name: z.string().optional(),
|
||||
description: z.string().optional(),
|
||||
version: z.string().optional()
|
||||
})
|
||||
)
|
||||
|
||||
// === Together ===
|
||||
|
||||
export const TogetherModelsResponseSchema = z.array(
|
||||
z.looseObject({
|
||||
id: z.string(),
|
||||
display_name: z.string().optional(),
|
||||
organization: z.string().optional(),
|
||||
description: z.string().optional(),
|
||||
context_length: z.number().optional(),
|
||||
pricing: z
|
||||
.looseObject({
|
||||
input: z.number().optional(),
|
||||
output: z.number().optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
)
|
||||
|
||||
// === NewAPI (extends OpenAI with endpoint types) ===
|
||||
|
||||
export const NewApiModelsResponseSchema = z.object({
|
||||
data: z.array(
|
||||
z.looseObject({
|
||||
id: z.string(),
|
||||
object: z.string().optional().default('model'),
|
||||
created: z.number().optional(),
|
||||
owned_by: z.string().optional(),
|
||||
supported_endpoint_types: z
|
||||
.array(z.string())
|
||||
.nullable()
|
||||
.optional()
|
||||
.transform((v) => v ?? undefined)
|
||||
})
|
||||
),
|
||||
object: z.string().optional()
|
||||
})
|
||||
|
||||
// === OVMS (OpenVINO Model Server) ===
|
||||
|
||||
export const OVMSConfigResponseSchema = z.record(
|
||||
z.string(),
|
||||
z.object({
|
||||
model_version_status: z
|
||||
.array(
|
||||
z.looseObject({
|
||||
state: z.string(),
|
||||
status: z
|
||||
.looseObject({
|
||||
error_code: z.string().optional(),
|
||||
error_message: z.string().optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
)
|
||||
.optional()
|
||||
})
|
||||
)
|
||||
|
||||
// === AIHubMix ===
|
||||
|
||||
export const AIHubMixModelsResponseSchema = z.object({
|
||||
data: z.array(
|
||||
z.looseObject({
|
||||
model_id: z.string(),
|
||||
model_name: z.string().optional(),
|
||||
developer_id: z.number().optional(),
|
||||
desc: z.string().optional(),
|
||||
pricing: z
|
||||
.looseObject({
|
||||
cache_read: z.number().optional(),
|
||||
cache_write: z.number().optional(),
|
||||
input: z.number().optional(),
|
||||
output: z.number().optional()
|
||||
})
|
||||
.optional(),
|
||||
types: z.string().optional(),
|
||||
features: z.string().optional(),
|
||||
input_modalities: z.string().optional(),
|
||||
endpoints: z.string().optional(),
|
||||
max_output: z.number().optional(),
|
||||
context_length: z.number().optional()
|
||||
})
|
||||
),
|
||||
message: z.string().optional(),
|
||||
success: z.boolean().optional()
|
||||
})
|
||||
@@ -1,140 +0,0 @@
|
||||
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
|
||||
import type { Assistant, KnowledgeReference } from '@renderer/types'
|
||||
import type { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract'
|
||||
import { REFERENCE_PROMPT } from '@shared/config/prompts'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* 知识库搜索工具
|
||||
* 使用预提取关键词,直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
*/
|
||||
export const knowledgeSearchTool = (
|
||||
assistant: Assistant,
|
||||
extractedKeywords: KnowledgeExtractResults,
|
||||
topicId: string,
|
||||
userMessage?: string
|
||||
) => {
|
||||
return tool({
|
||||
description: `Knowledge base search tool for retrieving information from user's private knowledge base. This searches your local collection of documents, web content, notes, and other materials you have stored.
|
||||
|
||||
This tool has been configured with search parameters based on the conversation context:
|
||||
- Prepared queries: ${extractedKeywords.question.map((q) => `"${q}"`).join(', ')}
|
||||
- Query rewrite: "${extractedKeywords.rewrite}"
|
||||
|
||||
You can use this tool as-is, or provide additionalContext to refine the search focus within the knowledge base.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
additionalContext: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Optional additional context or specific focus to enhance the knowledge search')
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
// try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
|
||||
// 检查是否有知识库
|
||||
if (!hasKnowledgeBase) {
|
||||
return []
|
||||
}
|
||||
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
let finalRewrite = extractedKeywords.rewrite
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalRewrite = cleanContext
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容
|
||||
const directContent = userMessage || finalQueries[0] || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果
|
||||
searchCriteria = {
|
||||
question: finalQueries,
|
||||
rewrite: finalRewrite
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
|
||||
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
|
||||
id: ref.id,
|
||||
content: ref.content,
|
||||
sourceUrl: ref.sourceUrl,
|
||||
type: ref.type,
|
||||
file: ref.file,
|
||||
metadata: ref.metadata
|
||||
}))
|
||||
|
||||
// TODO 在工具函数中添加搜索缓存机制
|
||||
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
|
||||
|
||||
// 返回结果
|
||||
return knowledgeReferencesData
|
||||
},
|
||||
toModelOutput: ({ output: results }) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
if (results.length > 0) {
|
||||
summary = `Found ${results.length} relevant sources. Use [number] format to cite specific information.`
|
||||
}
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(results, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the knowledge references, please answer the user's question with proper citations."
|
||||
).replace('{references}', referenceContent)
|
||||
|
||||
return {
|
||||
type: 'content',
|
||||
value: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: summary
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: fullInstructions
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
|
||||
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
|
||||
|
||||
export default knowledgeSearchTool
|
||||
@@ -1,47 +0,0 @@
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import store from '@renderer/store'
|
||||
import { selectMemoryConfig } from '@renderer/store/memory'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
|
||||
/**
|
||||
* 🧠 基础记忆搜索工具
|
||||
* AI 可以主动调用的简单记忆搜索
|
||||
*/
|
||||
export const memorySearchTool = (assistantId: string) => {
|
||||
return tool({
|
||||
description: 'Search through conversation memories and stored facts for relevant context',
|
||||
inputSchema: z.object({
|
||||
query: z.string().describe('Search query to find relevant memories'),
|
||||
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
|
||||
}),
|
||||
execute: async ({ query, limit = 5 }) => {
|
||||
const globalMemoryEnabled = await preferenceService.get('feature.memory.enabled')
|
||||
if (!globalMemoryEnabled) {
|
||||
return []
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
|
||||
if (!memoryConfig.llmModel || !memoryConfig.embeddingModel) {
|
||||
return []
|
||||
}
|
||||
|
||||
const currentUserId = await preferenceService.get('feature.memory.current_user_id')
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistantId, currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
return relevantMemories
|
||||
}
|
||||
return []
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export type MemorySearchToolInput = InferToolInput<ReturnType<typeof memorySearchTool>>
|
||||
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>
|
||||
@@ -1,204 +0,0 @@
|
||||
import { webSearchService } from '@renderer/services/WebSearchService'
|
||||
import type { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||
import type { ExtractResults } from '@renderer/utils/extract'
|
||||
import { REFERENCE_PROMPT } from '@shared/config/prompts'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* 使用预提取关键词的网络搜索工具
|
||||
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
*/
|
||||
export const webSearchToolWithPreExtractedKeywords = (
|
||||
webSearchProviderId: WebSearchProvider['id'],
|
||||
extractedKeywords: {
|
||||
question: string[]
|
||||
links?: string[]
|
||||
},
|
||||
requestId: string
|
||||
) => {
|
||||
const webSearchProvider = webSearchService.getWebSearchProvider(webSearchProviderId)
|
||||
|
||||
return tool({
|
||||
description: `Web search tool for finding current information, news, and real-time data from the internet.
|
||||
|
||||
This tool has been configured with search parameters based on the conversation context:
|
||||
- Prepared queries: ${extractedKeywords.question.map((q) => `"${q}"`).join(', ')}${
|
||||
extractedKeywords.links?.length
|
||||
? `
|
||||
- Relevant URLs: ${extractedKeywords.links.join(', ')}`
|
||||
: ''
|
||||
}
|
||||
|
||||
You can use this tool as-is to search with the prepared queries, or provide additionalContext to refine or replace the search terms.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
additionalContext: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Optional additional context, keywords, or specific focus to enhance the search')
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
}
|
||||
}
|
||||
|
||||
let searchResults: WebSearchProviderResponse = {
|
||||
query: '',
|
||||
results: []
|
||||
}
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return searchResults
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: finalQueries,
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
searchResults = await webSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
|
||||
return searchResults
|
||||
},
|
||||
toModelOutput: ({ output: results }) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
if (results.query && results.results.length > 0) {
|
||||
summary = `Found ${results.results.length} relevant sources. Use [number] format to cite specific information.`
|
||||
}
|
||||
|
||||
const citationData = results.results.map((result, index) => ({
|
||||
number: index + 1,
|
||||
title: result.title,
|
||||
content: result.content,
|
||||
url: result.url
|
||||
}))
|
||||
|
||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the search results, please answer the user's question with proper citations."
|
||||
).replace('{references}', referenceContent)
|
||||
return {
|
||||
type: 'content',
|
||||
value: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: summary
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: fullInstructions
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// export const webSearchToolWithExtraction = (
|
||||
// webSearchProviderId: WebSearchProvider['id'],
|
||||
// requestId: string,
|
||||
// assistant: Assistant
|
||||
// ) => {
|
||||
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
|
||||
// return tool({
|
||||
// name: 'web_search_with_extraction',
|
||||
// description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||
// inputSchema: z.object({
|
||||
// userMessage: z.object({
|
||||
// content: z.string().describe('The main content of the message'),
|
||||
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
// }),
|
||||
// lastAnswer: z.object({
|
||||
// content: z.string().describe('The main content of the message'),
|
||||
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
// })
|
||||
// }),
|
||||
// outputSchema: z.object({
|
||||
// extractedKeywords: z.object({
|
||||
// question: z.array(z.string()),
|
||||
// links: z.array(z.string()).optional()
|
||||
// }),
|
||||
// searchResults: z.array(
|
||||
// z.object({
|
||||
// query: z.string(),
|
||||
// results: WebSearchProviderResult
|
||||
// })
|
||||
// )
|
||||
// }),
|
||||
// execute: async ({ userMessage, lastAnswer }) => {
|
||||
// const lastUserMessage: Message = {
|
||||
// id: requestId,
|
||||
// role: userMessage.role,
|
||||
// assistantId: assistant.id,
|
||||
// topicId: 'temp',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// status: UserMessageStatus.SUCCESS,
|
||||
// blocks: []
|
||||
// }
|
||||
|
||||
// const lastAnswerMessage: Message | undefined = lastAnswer
|
||||
// ? {
|
||||
// id: requestId + '_answer',
|
||||
// role: lastAnswer.role,
|
||||
// assistantId: assistant.id,
|
||||
// topicId: 'temp',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// status: UserMessageStatus.SUCCESS,
|
||||
// blocks: []
|
||||
// }
|
||||
// : undefined
|
||||
|
||||
// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||
// shouldWebSearch: true,
|
||||
// shouldKnowledgeSearch: false,
|
||||
// lastAnswer: lastAnswerMessage
|
||||
// })
|
||||
|
||||
// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||
// return 'No search needed or extraction failed'
|
||||
// }
|
||||
|
||||
// const searchQueries = extractResults.websearch.question
|
||||
// const searchResults: Array<{ query: string; results: any }> = []
|
||||
|
||||
// for (const query of searchQueries) {
|
||||
// // 构建单个查询的ExtractResults结构
|
||||
// const queryExtractResults: ExtractResults = {
|
||||
// websearch: {
|
||||
// question: [query],
|
||||
// links: extractResults.websearch.links
|
||||
// }
|
||||
// }
|
||||
// const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||
// searchResults.push({
|
||||
// query,
|
||||
// results: response
|
||||
// })
|
||||
// }
|
||||
|
||||
// return { extractedKeywords: extractResults.websearch, searchResults }
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||
|
||||
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||
@@ -1,656 +0,0 @@
|
||||
/**
|
||||
* AI SDK Span Adapter
|
||||
*
|
||||
* 将 AI SDK 的 telemetry 数据转换为现有的 SpanEntity 格式
|
||||
* 注意 AI SDK 的层级结构:ai.xxx 是一个层级,ai.xxx.xxx 是对应层级下的子集
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
import type { Span } from '@opentelemetry/api'
|
||||
import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkSpanAdapter')
|
||||
|
||||
export interface AiSdkSpanData {
|
||||
span: Span
|
||||
topicId?: string
|
||||
modelName?: string
|
||||
}
|
||||
|
||||
// 扩展接口用于访问span的内部数据
|
||||
interface SpanWithInternals extends Span {
|
||||
_spanProcessor?: any
|
||||
_attributes?: Record<string, any>
|
||||
_events?: any[]
|
||||
name?: string
|
||||
startTime?: [number, number]
|
||||
endTime?: [number, number] | null
|
||||
status?: { code: SpanStatusCode; message?: string }
|
||||
kind?: SpanKind
|
||||
ended?: boolean
|
||||
parentSpanId?: string
|
||||
links?: any[]
|
||||
}
|
||||
|
||||
export class AiSdkSpanAdapter {
|
||||
/**
|
||||
* 将 AI SDK span 转换为 SpanEntity 格式
|
||||
*/
|
||||
static convertToSpanEntity(spanData: AiSdkSpanData): SpanEntity {
|
||||
const { span, topicId, modelName } = spanData
|
||||
const spanContext = span.spanContext()
|
||||
|
||||
// 尝试从不同方式获取span数据
|
||||
const spanWithInternals = span as SpanWithInternals
|
||||
let attributes: Record<string, any> = {}
|
||||
let events: any[] = []
|
||||
let spanName = 'unknown'
|
||||
let spanStatus = { code: SpanStatusCode.UNSET }
|
||||
let spanKind = SpanKind.INTERNAL
|
||||
let startTime: [number, number] = [0, 0]
|
||||
let endTime: [number, number] | null = null
|
||||
let ended = false
|
||||
let parentSpanId = ''
|
||||
let links: any[] = []
|
||||
|
||||
// 详细记录span的结构信息用于调试
|
||||
logger.debug('Debugging span structure', {
|
||||
hasInternalAttributes: !!spanWithInternals._attributes,
|
||||
hasGetAttributes: typeof (span as any).getAttributes === 'function',
|
||||
spanKeys: Object.keys(span),
|
||||
spanInternalKeys: Object.keys(spanWithInternals),
|
||||
spanContext: span.spanContext(),
|
||||
// 尝试获取所有可能的属性路径
|
||||
attributesPath1: spanWithInternals._attributes,
|
||||
attributesPath2: (span as any).attributes,
|
||||
attributesPath3: (span as any)._spanData?.attributes,
|
||||
attributesPath4: (span as any).resource?.attributes
|
||||
})
|
||||
|
||||
// 尝试多种方式获取attributes
|
||||
if (spanWithInternals._attributes) {
|
||||
attributes = spanWithInternals._attributes
|
||||
logger.debug('Found attributes via _attributes', { attributeCount: Object.keys(attributes).length })
|
||||
} else if (typeof (span as any).getAttributes === 'function') {
|
||||
attributes = (span as any).getAttributes()
|
||||
logger.debug('Found attributes via getAttributes()', { attributeCount: Object.keys(attributes).length })
|
||||
} else if ((span as any).attributes) {
|
||||
attributes = (span as any).attributes
|
||||
logger.debug('Found attributes via direct attributes property', {
|
||||
attributeCount: Object.keys(attributes).length
|
||||
})
|
||||
} else if ((span as any)._spanData?.attributes) {
|
||||
attributes = (span as any)._spanData.attributes
|
||||
logger.debug('Found attributes via _spanData.attributes', { attributeCount: Object.keys(attributes).length })
|
||||
} else {
|
||||
// 尝试从span的其他属性获取
|
||||
logger.warn('无法获取span attributes,尝试备用方法', {
|
||||
availableKeys: Object.keys(span),
|
||||
spanType: span.constructor.name
|
||||
})
|
||||
}
|
||||
|
||||
// 获取其他属性
|
||||
if (spanWithInternals._events) {
|
||||
events = spanWithInternals._events
|
||||
}
|
||||
if (spanWithInternals.name) {
|
||||
spanName = spanWithInternals.name
|
||||
}
|
||||
if (spanWithInternals.status) {
|
||||
spanStatus = spanWithInternals.status
|
||||
}
|
||||
if (spanWithInternals.kind !== undefined) {
|
||||
spanKind = spanWithInternals.kind
|
||||
}
|
||||
if (spanWithInternals.startTime) {
|
||||
startTime = spanWithInternals.startTime
|
||||
}
|
||||
if (spanWithInternals.endTime) {
|
||||
endTime = spanWithInternals.endTime
|
||||
}
|
||||
if (spanWithInternals.ended !== undefined) {
|
||||
ended = spanWithInternals.ended
|
||||
}
|
||||
if (spanWithInternals.parentSpanId) {
|
||||
parentSpanId = spanWithInternals.parentSpanId
|
||||
}
|
||||
// 兜底:尝试从 attributes 中读取我们注入的父信息
|
||||
if (!parentSpanId && attributes['trace.parentSpanId']) {
|
||||
parentSpanId = attributes['trace.parentSpanId']
|
||||
}
|
||||
if (spanWithInternals.links) {
|
||||
links = spanWithInternals.links
|
||||
}
|
||||
|
||||
// 提取 AI SDK 特有的数据
|
||||
const tokenUsage = this.extractTokenUsage(attributes)
|
||||
const { inputs, outputs } = this.extractInputsOutputs(attributes)
|
||||
const formattedSpanName = this.formatSpanName(spanName)
|
||||
const spanTag = this.extractSpanTag(spanName, attributes)
|
||||
const typeSpecificData = this.extractSpanTypeSpecificData(attributes)
|
||||
|
||||
// 详细记录转换过程
|
||||
const operationId = attributes['ai.operationId']
|
||||
logger.debug('Converting AI SDK span to SpanEntity', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
spanTag,
|
||||
hasTokenUsage: !!tokenUsage,
|
||||
hasInputs: !!inputs,
|
||||
hasOutputs: !!outputs,
|
||||
hasTypeSpecificData: Object.keys(typeSpecificData).length > 0,
|
||||
attributesCount: Object.keys(attributes).length,
|
||||
topicId,
|
||||
modelName,
|
||||
spanId: spanContext.spanId,
|
||||
traceId: spanContext.traceId
|
||||
})
|
||||
|
||||
if (tokenUsage) {
|
||||
logger.debug('Token usage data found', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
usage: tokenUsage,
|
||||
spanId: spanContext.spanId
|
||||
})
|
||||
}
|
||||
|
||||
if (inputs || outputs) {
|
||||
logger.debug('Input/Output data extracted', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
hasInputs: !!inputs,
|
||||
hasOutputs: !!outputs,
|
||||
inputKeys: inputs ? Object.keys(inputs) : [],
|
||||
outputKeys: outputs ? Object.keys(outputs) : [],
|
||||
spanId: spanContext.spanId
|
||||
})
|
||||
}
|
||||
|
||||
if (Object.keys(typeSpecificData).length > 0) {
|
||||
logger.debug('Type-specific data extracted', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
typeSpecificKeys: Object.keys(typeSpecificData),
|
||||
spanId: spanContext.spanId
|
||||
})
|
||||
}
|
||||
|
||||
// 转换为 SpanEntity 格式
|
||||
const spanEntity: SpanEntity = {
|
||||
id: spanContext.spanId,
|
||||
name: formattedSpanName,
|
||||
parentId: parentSpanId,
|
||||
traceId: spanContext.traceId,
|
||||
status: this.convertSpanStatus(spanStatus.code),
|
||||
kind: this.convertSpanKind(spanKind),
|
||||
attributes: {
|
||||
...this.filterRelevantAttributes(attributes),
|
||||
...typeSpecificData,
|
||||
inputs: inputs,
|
||||
outputs: outputs,
|
||||
tags: spanTag,
|
||||
modelName: modelName || this.extractModelFromAttributes(attributes) || ''
|
||||
},
|
||||
isEnd: ended,
|
||||
events: events,
|
||||
startTime: this.convertTimestamp(startTime),
|
||||
endTime: endTime ? this.convertTimestamp(endTime) : null,
|
||||
links: links,
|
||||
topicId: topicId,
|
||||
usage: tokenUsage,
|
||||
modelName: modelName || this.extractModelFromAttributes(attributes)
|
||||
}
|
||||
|
||||
logger.debug('AI SDK span successfully converted to SpanEntity', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
spanId: spanContext.spanId,
|
||||
traceId: spanContext.traceId,
|
||||
topicId,
|
||||
entityId: spanEntity.id,
|
||||
hasUsage: !!spanEntity.usage,
|
||||
status: spanEntity.status,
|
||||
tags: spanEntity.attributes?.tags
|
||||
})
|
||||
|
||||
return spanEntity
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 AI SDK attributes 中提取 token usage
|
||||
* 支持多种格式:
|
||||
* - AI SDK 标准格式: ai.usage.completionTokens, ai.usage.promptTokens
|
||||
* - 完整usage对象格式: ai.usage (JSON字符串或对象)
|
||||
*/
|
||||
private static extractTokenUsage(attributes: Record<string, any>): TokenUsage | undefined {
|
||||
logger.debug('Extracting token usage from attributes', {
|
||||
attributeKeys: Object.keys(attributes),
|
||||
usageRelatedKeys: Object.keys(attributes).filter((key) => key.includes('usage') || key.includes('token')),
|
||||
fullAttributes: attributes
|
||||
})
|
||||
|
||||
const inputsTokenKeys = [
|
||||
// base span
|
||||
'ai.usage.promptTokens',
|
||||
// LLM span
|
||||
'gen_ai.usage.input_tokens'
|
||||
]
|
||||
const outputTokenKeys = [
|
||||
// base span
|
||||
'ai.usage.completionTokens',
|
||||
// LLM span
|
||||
'gen_ai.usage.output_tokens'
|
||||
]
|
||||
|
||||
const promptTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
|
||||
const completionTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
|
||||
|
||||
if (completionTokens !== undefined || promptTokens !== undefined) {
|
||||
const usage: TokenUsage = {
|
||||
prompt_tokens: Number(promptTokens) || 0,
|
||||
completion_tokens: Number(completionTokens) || 0,
|
||||
total_tokens: (Number(promptTokens) || 0) + (Number(completionTokens) || 0)
|
||||
}
|
||||
|
||||
logger.debug('Extracted token usage from AI SDK standard attributes', {
|
||||
usage,
|
||||
foundStandardAttributes: {
|
||||
'ai.usage.completionTokens': completionTokens,
|
||||
'ai.usage.promptTokens': promptTokens
|
||||
}
|
||||
})
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
// 对于不包含token usage的spans(如tool calls),这是正常的
|
||||
logger.debug('No token usage found in span attributes (normal for tool calls)', {
|
||||
availableKeys: Object.keys(attributes),
|
||||
usageKeys: Object.keys(attributes).filter((key) => key.includes('usage') || key.includes('token'))
|
||||
})
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 AI SDK attributes 中提取 inputs 和 outputs
|
||||
* 根据AI SDK文档按不同span类型精确映射
|
||||
*/
|
||||
private static extractInputsOutputs(attributes: Record<string, any>): { inputs?: any; outputs?: any } {
|
||||
const operationId = attributes['ai.operationId']
|
||||
let inputs: any = undefined
|
||||
let outputs: any = undefined
|
||||
|
||||
logger.debug('Extracting inputs/outputs by operation type', {
|
||||
operationId,
|
||||
availableKeys: Object.keys(attributes).filter(
|
||||
(key) => key.includes('prompt') || key.includes('response') || key.includes('toolCall')
|
||||
)
|
||||
})
|
||||
|
||||
// 根据AI SDK文档按操作类型提取数据
|
||||
switch (operationId) {
|
||||
case 'ai.generateText':
|
||||
case 'ai.streamText':
|
||||
// 顶层LLM spans: ai.prompt 包含输入
|
||||
inputs = {
|
||||
prompt: this.parseAttributeValue(attributes['ai.prompt'])
|
||||
}
|
||||
outputs = this.extractLLMOutputs(attributes)
|
||||
break
|
||||
|
||||
case 'ai.generateText.doGenerate':
|
||||
case 'ai.streamText.doStream':
|
||||
// Provider spans: ai.prompt.messages 包含详细输入
|
||||
inputs = {
|
||||
messages: this.parseAttributeValue(attributes['ai.prompt.messages']),
|
||||
tools: this.parseAttributeValue(attributes['ai.prompt.tools']),
|
||||
toolChoice: this.parseAttributeValue(attributes['ai.prompt.toolChoice'])
|
||||
}
|
||||
outputs = this.extractProviderOutputs(attributes)
|
||||
break
|
||||
|
||||
case 'ai.toolCall':
|
||||
// Tool call spans: 工具参数和结果
|
||||
inputs = {
|
||||
toolName: attributes['ai.toolCall.name'],
|
||||
toolId: attributes['ai.toolCall.id'],
|
||||
args: this.parseAttributeValue(attributes['ai.toolCall.args'])
|
||||
}
|
||||
outputs = {
|
||||
result: this.parseAttributeValue(attributes['ai.toolCall.result'])
|
||||
}
|
||||
break
|
||||
|
||||
default:
|
||||
// 回退到通用逻辑
|
||||
inputs = this.extractGenericInputs(attributes)
|
||||
outputs = this.extractGenericOutputs(attributes)
|
||||
break
|
||||
}
|
||||
|
||||
logger.debug('Extracted inputs/outputs', {
|
||||
operationId,
|
||||
hasInputs: !!inputs,
|
||||
hasOutputs: !!outputs,
|
||||
inputKeys: inputs ? Object.keys(inputs) : [],
|
||||
outputKeys: outputs ? Object.keys(outputs) : []
|
||||
})
|
||||
|
||||
return { inputs, outputs }
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取LLM顶层spans的输出
|
||||
*/
|
||||
private static extractLLMOutputs(attributes: Record<string, any>): any {
|
||||
const outputs: any = {}
|
||||
|
||||
if (attributes['ai.response.text']) {
|
||||
outputs.text = attributes['ai.response.text']
|
||||
}
|
||||
if (attributes['ai.response.toolCalls']) {
|
||||
outputs.toolCalls = this.parseAttributeValue(attributes['ai.response.toolCalls'])
|
||||
}
|
||||
if (attributes['ai.response.finishReason']) {
|
||||
outputs.finishReason = attributes['ai.response.finishReason']
|
||||
}
|
||||
if (attributes['ai.settings.maxOutputTokens']) {
|
||||
outputs.maxOutputTokens = attributes['ai.settings.maxOutputTokens']
|
||||
}
|
||||
|
||||
return Object.keys(outputs).length > 0 ? outputs : undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取Provider spans的输出
|
||||
*/
|
||||
private static extractProviderOutputs(attributes: Record<string, any>): any {
|
||||
const outputs: any = {}
|
||||
|
||||
if (attributes['ai.response.text']) {
|
||||
outputs.text = attributes['ai.response.text']
|
||||
}
|
||||
if (attributes['ai.response.toolCalls']) {
|
||||
outputs.toolCalls = this.parseAttributeValue(attributes['ai.response.toolCalls'])
|
||||
}
|
||||
if (attributes['ai.response.finishReason']) {
|
||||
outputs.finishReason = attributes['ai.response.finishReason']
|
||||
}
|
||||
|
||||
// doStream特有的性能指标
|
||||
if (attributes['ai.response.msToFirstChunk']) {
|
||||
outputs.msToFirstChunk = attributes['ai.response.msToFirstChunk']
|
||||
}
|
||||
if (attributes['ai.response.msToFinish']) {
|
||||
outputs.msToFinish = attributes['ai.response.msToFinish']
|
||||
}
|
||||
if (attributes['ai.response.avgCompletionTokensPerSecond']) {
|
||||
outputs.avgCompletionTokensPerSecond = attributes['ai.response.avgCompletionTokensPerSecond']
|
||||
}
|
||||
|
||||
return Object.keys(outputs).length > 0 ? outputs : undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用输入提取(回退逻辑)
|
||||
*/
|
||||
private static extractGenericInputs(attributes: Record<string, any>): any {
|
||||
const inputKeys = ['ai.prompt', 'ai.prompt.messages', 'ai.request', 'inputs']
|
||||
|
||||
for (const key of inputKeys) {
|
||||
if (attributes[key]) {
|
||||
return this.parseAttributeValue(attributes[key])
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用输出提取(回退逻辑)
|
||||
*/
|
||||
private static extractGenericOutputs(attributes: Record<string, any>): any {
|
||||
const outputKeys = ['ai.response.text', 'ai.response', 'ai.output', 'outputs']
|
||||
|
||||
for (const key of outputKeys) {
|
||||
if (attributes[key]) {
|
||||
return this.parseAttributeValue(attributes[key])
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析属性值,处理字符串化的 JSON
|
||||
*/
|
||||
private static parseAttributeValue(value: any): any {
|
||||
if (typeof value === 'string') {
|
||||
try {
|
||||
return JSON.parse(value)
|
||||
} catch (e) {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化 span 名称,处理 AI SDK 的层级结构
|
||||
*/
|
||||
private static formatSpanName(name: string): string {
|
||||
// AI SDK 的 span 名称可能是 ai.generateText, ai.streamText.doStream 等
|
||||
// 保持原始名称,但可以添加一些格式化逻辑
|
||||
if (name.startsWith('ai.')) {
|
||||
return name
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
/**
|
||||
* 从AI SDK operationId中提取精确的span标签
|
||||
*/
|
||||
private static extractSpanTag(name: string, attributes: Record<string, any>): string {
|
||||
const operationId = attributes['ai.operationId']
|
||||
|
||||
logger.debug('Extracting span tag', {
|
||||
spanName: name,
|
||||
operationId,
|
||||
operationName: attributes['operation.name']
|
||||
})
|
||||
|
||||
// 根据AI SDK文档的operationId精确分类
|
||||
switch (operationId) {
|
||||
case 'ai.generateText':
|
||||
return 'LLM-GENERATE'
|
||||
case 'ai.streamText':
|
||||
return 'LLM-STREAM'
|
||||
case 'ai.generateText.doGenerate':
|
||||
return 'PROVIDER-GENERATE'
|
||||
case 'ai.streamText.doStream':
|
||||
return 'PROVIDER-STREAM'
|
||||
case 'ai.toolCall':
|
||||
return 'TOOL-CALL'
|
||||
case 'ai.generateImage':
|
||||
return 'IMAGE'
|
||||
case 'ai.embed':
|
||||
return 'EMBEDDING'
|
||||
default:
|
||||
// 回退逻辑:基于span名称
|
||||
if (name.includes('generateText') || name.includes('streamText')) {
|
||||
return 'LLM'
|
||||
}
|
||||
if (name.includes('generateImage')) {
|
||||
return 'IMAGE'
|
||||
}
|
||||
if (name.includes('embed')) {
|
||||
return 'EMBEDDING'
|
||||
}
|
||||
if (name.includes('toolCall')) {
|
||||
return 'TOOL'
|
||||
}
|
||||
|
||||
// 最终回退
|
||||
return attributes['ai.operationType'] || attributes['operation.type'] || 'AI_SDK'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据span类型提取特定的额外数据
|
||||
*/
|
||||
private static extractSpanTypeSpecificData(attributes: Record<string, any>): Record<string, any> {
|
||||
const operationId = attributes['ai.operationId']
|
||||
const specificData: Record<string, any> = {}
|
||||
|
||||
switch (operationId) {
|
||||
case 'ai.generateText':
|
||||
case 'ai.streamText':
|
||||
// LLM顶层spans的特定数据
|
||||
if (attributes['ai.settings.maxOutputTokens']) {
|
||||
specificData.maxOutputTokens = attributes['ai.settings.maxOutputTokens']
|
||||
}
|
||||
if (attributes['resource.name']) {
|
||||
specificData.functionId = attributes['resource.name']
|
||||
}
|
||||
break
|
||||
|
||||
case 'ai.generateText.doGenerate':
|
||||
case 'ai.streamText.doStream':
|
||||
// Provider spans的特定数据
|
||||
if (attributes['ai.model.id']) {
|
||||
specificData.providerId = attributes['ai.model.provider'] || 'unknown'
|
||||
specificData.modelId = attributes['ai.model.id']
|
||||
}
|
||||
|
||||
// doStream特有的性能数据
|
||||
if (operationId === 'ai.streamText.doStream') {
|
||||
if (attributes['ai.response.msToFirstChunk']) {
|
||||
specificData.msToFirstChunk = attributes['ai.response.msToFirstChunk']
|
||||
}
|
||||
if (attributes['ai.response.msToFinish']) {
|
||||
specificData.msToFinish = attributes['ai.response.msToFinish']
|
||||
}
|
||||
if (attributes['ai.response.avgCompletionTokensPerSecond']) {
|
||||
specificData.tokensPerSecond = attributes['ai.response.avgCompletionTokensPerSecond']
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
case 'ai.toolCall':
|
||||
// Tool call spans的特定数据
|
||||
specificData.toolName = attributes['ai.toolCall.name']
|
||||
specificData.toolId = attributes['ai.toolCall.id']
|
||||
|
||||
// 根据文档,tool call可能有不同的操作类型
|
||||
if (attributes['operation.name']) {
|
||||
specificData.operationName = attributes['operation.name']
|
||||
}
|
||||
break
|
||||
|
||||
default:
|
||||
// 通用的AI SDK属性
|
||||
if (attributes['ai.telemetry.functionId']) {
|
||||
specificData.telemetryFunctionId = attributes['ai.telemetry.functionId']
|
||||
}
|
||||
if (attributes['ai.telemetry.metadata']) {
|
||||
specificData.telemetryMetadata = this.parseAttributeValue(attributes['ai.telemetry.metadata'])
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 添加通用的操作标识
|
||||
if (operationId) {
|
||||
specificData.operationType = operationId
|
||||
}
|
||||
if (attributes['operation.name']) {
|
||||
specificData.operationName = attributes['operation.name']
|
||||
}
|
||||
|
||||
logger.debug('Extracted type-specific data', {
|
||||
operationId,
|
||||
specificDataKeys: Object.keys(specificData),
|
||||
specificData
|
||||
})
|
||||
|
||||
return specificData
|
||||
}
|
||||
|
||||
/**
|
||||
* 从属性中提取模型名称
|
||||
*/
|
||||
private static extractModelFromAttributes(attributes: Record<string, any>): string | undefined {
|
||||
return (
|
||||
attributes['ai.model.id'] ||
|
||||
attributes['ai.model'] ||
|
||||
attributes['model.id'] ||
|
||||
attributes['model'] ||
|
||||
attributes['modelName']
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 过滤相关的属性,移除不需要的系统属性
|
||||
*/
|
||||
private static filterRelevantAttributes(attributes: Record<string, any>): Record<string, any> {
|
||||
const filtered: Record<string, any> = {}
|
||||
|
||||
// 保留有用的属性,过滤掉已经单独处理的属性
|
||||
const excludeKeys = ['ai.usage', 'ai.prompt', 'ai.response', 'ai.input', 'ai.output', 'inputs', 'outputs']
|
||||
|
||||
Object.entries(attributes).forEach(([key, value]) => {
|
||||
if (!excludeKeys.includes(key)) {
|
||||
filtered[key] = value
|
||||
}
|
||||
})
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 span 状态
|
||||
*/
|
||||
private static convertSpanStatus(statusCode?: SpanStatusCode): string {
|
||||
switch (statusCode) {
|
||||
case SpanStatusCode.OK:
|
||||
return 'OK'
|
||||
case SpanStatusCode.ERROR:
|
||||
return 'ERROR'
|
||||
case SpanStatusCode.UNSET:
|
||||
default:
|
||||
return 'UNSET'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 span 类型
|
||||
*/
|
||||
private static convertSpanKind(kind?: SpanKind): string {
|
||||
switch (kind) {
|
||||
case SpanKind.INTERNAL:
|
||||
return 'INTERNAL'
|
||||
case SpanKind.CLIENT:
|
||||
return 'CLIENT'
|
||||
case SpanKind.SERVER:
|
||||
return 'SERVER'
|
||||
case SpanKind.PRODUCER:
|
||||
return 'PRODUCER'
|
||||
case SpanKind.CONSUMER:
|
||||
return 'CONSUMER'
|
||||
default:
|
||||
return 'INTERNAL'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换时间戳格式
|
||||
*/
|
||||
private static convertTimestamp(timestamp: [number, number] | number): number {
|
||||
if (Array.isArray(timestamp)) {
|
||||
// OpenTelemetry 高精度时间戳 [seconds, nanoseconds]
|
||||
return timestamp[0] * 1000 + timestamp[1] / 1000000
|
||||
}
|
||||
return timestamp
|
||||
}
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
import type { Span } from '@opentelemetry/api'
|
||||
import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { AiSdkSpanAdapter } from '../AiSdkSpanAdapter'
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
info: vi.fn(),
|
||||
warn: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
describe('AiSdkSpanAdapter', () => {
|
||||
const createMockSpan = (attributes: Record<string, unknown>): Span => {
|
||||
const span = {
|
||||
spanContext: () => ({
|
||||
traceId: 'trace-id',
|
||||
spanId: 'span-id'
|
||||
}),
|
||||
_attributes: attributes,
|
||||
_events: [],
|
||||
name: 'test span',
|
||||
status: { code: SpanStatusCode.OK },
|
||||
kind: SpanKind.CLIENT,
|
||||
startTime: [0, 0] as [number, number],
|
||||
endTime: [0, 1] as [number, number],
|
||||
ended: true,
|
||||
parentSpanId: '',
|
||||
links: []
|
||||
}
|
||||
return span as unknown as Span
|
||||
}
|
||||
|
||||
it('maps prompt and completion usage tokens to the correct fields', () => {
|
||||
const attributes = {
|
||||
'ai.usage.promptTokens': 321,
|
||||
'ai.usage.completionTokens': 654
|
||||
}
|
||||
|
||||
const span = createMockSpan(attributes)
|
||||
const result = AiSdkSpanAdapter.convertToSpanEntity({ span })
|
||||
|
||||
expect(result.usage).toBeDefined()
|
||||
expect(result.usage?.prompt_tokens).toBe(321)
|
||||
expect(result.usage?.completion_tokens).toBe(654)
|
||||
expect(result.usage?.total_tokens).toBe(975)
|
||||
})
|
||||
})
|
||||
@@ -1,156 +0,0 @@
|
||||
/**
|
||||
* Type Tests for Merged Provider Types
|
||||
*
|
||||
* These tests validate that the auto-extraction and merging of provider types works correctly.
|
||||
* They use type-level assertions to ensure compile-time type safety.
|
||||
*/
|
||||
|
||||
import { describe, expectTypeOf, it } from 'vitest'
|
||||
|
||||
import type { AppProviderId, AppProviderSettingsMap } from '../merged'
|
||||
import { appProviderIds } from '../merged'
|
||||
|
||||
describe('Unified Provider Types', () => {
|
||||
describe('appProviderIds literal access', () => {
|
||||
it('should return canonical IDs with literal types', () => {
|
||||
// 别名 → 基础名
|
||||
expectTypeOf(appProviderIds.vertexai).toEqualTypeOf<'google-vertex'>()
|
||||
// 变体 → 自身(自反映射)
|
||||
expectTypeOf(appProviderIds['openai-chat']).toEqualTypeOf<'openai-chat'>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('AppProviderId - All Providers', () => {
|
||||
it('should include all core extension names', () => {
|
||||
type Check1 = 'openai' extends AppProviderId ? true : false
|
||||
type Check2 = 'anthropic' extends AppProviderId ? true : false
|
||||
type Check3 = 'google' extends AppProviderId ? true : false
|
||||
type Check4 = 'azure' extends AppProviderId ? true : false
|
||||
type Check5 = 'deepseek' extends AppProviderId ? true : false
|
||||
type Check6 = 'xai' extends AppProviderId ? true : false
|
||||
|
||||
expectTypeOf<Check1>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check2>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check3>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check4>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check5>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check6>().toEqualTypeOf<true>()
|
||||
})
|
||||
|
||||
it('should include all project extension names', () => {
|
||||
type Check1 = 'google-vertex' extends AppProviderId ? true : false
|
||||
type Check2 = 'bedrock' extends AppProviderId ? true : false
|
||||
type Check3 = 'github-copilot-openai-compatible' extends AppProviderId ? true : false
|
||||
type Check4 = 'perplexity' extends AppProviderId ? true : false
|
||||
type Check5 = 'mistral' extends AppProviderId ? true : false
|
||||
type Check6 = 'huggingface' extends AppProviderId ? true : false
|
||||
type Check7 = 'gateway' extends AppProviderId ? true : false
|
||||
type Check8 = 'cerebras' extends AppProviderId ? true : false
|
||||
type Check9 = 'ollama' extends AppProviderId ? true : false
|
||||
|
||||
expectTypeOf<Check1>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check2>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check3>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check4>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check5>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check6>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check7>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check8>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check9>().toEqualTypeOf<true>()
|
||||
})
|
||||
|
||||
it('should include all aliases (core + project)', () => {
|
||||
// Core aliases
|
||||
type Check2 = 'claude' extends AppProviderId ? true : false
|
||||
|
||||
// Project aliases
|
||||
type Check3 = 'vertexai' extends AppProviderId ? true : false
|
||||
type Check4 = 'aws-bedrock' extends AppProviderId ? true : false
|
||||
type Check5 = 'copilot' extends AppProviderId ? true : false
|
||||
type Check6 = 'github-copilot' extends AppProviderId ? true : false
|
||||
type Check7 = 'hf' extends AppProviderId ? true : false
|
||||
type Check8 = 'hugging-face' extends AppProviderId ? true : false
|
||||
type Check9 = 'ai-gateway' extends AppProviderId ? true : false
|
||||
|
||||
expectTypeOf<Check2>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check3>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check4>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check5>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check6>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check7>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check8>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check9>().toEqualTypeOf<true>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('AppProviderId', () => {
|
||||
it('should merge core and project IDs', () => {
|
||||
// Core providers
|
||||
type Check1 = 'openai' extends AppProviderId ? true : false
|
||||
type Check2 = 'anthropic' extends AppProviderId ? true : false
|
||||
type Check3 = 'google' extends AppProviderId ? true : false
|
||||
type Check4 = 'azure' extends AppProviderId ? true : false
|
||||
type Check5 = 'xai' extends AppProviderId ? true : false
|
||||
|
||||
// Project providers
|
||||
type Check6 = 'google-vertex' extends AppProviderId ? true : false
|
||||
type Check7 = 'bedrock' extends AppProviderId ? true : false
|
||||
type Check8 = 'ollama' extends AppProviderId ? true : false
|
||||
|
||||
expectTypeOf<Check1>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check2>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check3>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check4>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check5>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check6>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check7>().toEqualTypeOf<true>()
|
||||
expectTypeOf<Check8>().toEqualTypeOf<true>()
|
||||
})
|
||||
|
||||
it('should accept string for dynamic providers', () => {
|
||||
type Check = string extends AppProviderId ? true : false
|
||||
expectTypeOf<Check>().toEqualTypeOf<true>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('AppProviderSettingsMap', () => {
|
||||
it('should map core provider IDs to their settings', () => {
|
||||
// OpenAI settings should have OpenAI-specific fields
|
||||
type OpenAISettings = AppProviderSettingsMap['openai']
|
||||
type HasBaseURL = 'baseURL' extends keyof OpenAISettings ? true : false
|
||||
type HasApiKey = 'apiKey' extends keyof OpenAISettings ? true : false
|
||||
|
||||
expectTypeOf<HasBaseURL>().toEqualTypeOf<true>()
|
||||
expectTypeOf<HasApiKey>().toEqualTypeOf<true>()
|
||||
})
|
||||
|
||||
it('should map project provider IDs to their settings', () => {
|
||||
// Project providers should have settings
|
||||
type VertexSettings = AppProviderSettingsMap['google-vertex']
|
||||
type BedrockSettings = AppProviderSettingsMap['bedrock']
|
||||
type OllamaSettings = AppProviderSettingsMap['ollama']
|
||||
|
||||
// These should not be never
|
||||
type VertexNotNever = [VertexSettings] extends [never] ? false : true
|
||||
type BedrockNotNever = [BedrockSettings] extends [never] ? false : true
|
||||
type OllamaNotNever = [OllamaSettings] extends [never] ? false : true
|
||||
|
||||
expectTypeOf<VertexNotNever>().toEqualTypeOf<true>()
|
||||
expectTypeOf<BedrockNotNever>().toEqualTypeOf<true>()
|
||||
expectTypeOf<OllamaNotNever>().toEqualTypeOf<true>()
|
||||
})
|
||||
|
||||
it('should map aliases to same settings as main ID', () => {
|
||||
type OpenRouterByName = AppProviderSettingsMap['openrouter']
|
||||
type OpenRouterByAlias = AppProviderSettingsMap['tokenflux']
|
||||
|
||||
expectTypeOf<OpenRouterByName>().toEqualTypeOf<OpenRouterByAlias>()
|
||||
|
||||
// Vertex AI aliases should have the same settings
|
||||
type VertexByName = AppProviderSettingsMap['google-vertex']
|
||||
type VertexByAlias = AppProviderSettingsMap['vertexai']
|
||||
|
||||
expectTypeOf<VertexByName>().toEqualTypeOf<VertexByAlias>()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,76 +0,0 @@
|
||||
/**
|
||||
* This type definition file is only for renderer.
|
||||
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
|
||||
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
|
||||
* (ai-core package is set as browser-enviroment-only)
|
||||
*
|
||||
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
|
||||
*/
|
||||
|
||||
import type { StringKeys } from '@cherrystudio/ai-core/provider'
|
||||
|
||||
import type { AppProviderSettingsMap, AppRuntimeConfig } from './merged'
|
||||
|
||||
/**
|
||||
* Provider 配置
|
||||
* 基于 RuntimeConfig,用于构建 provider 实例的基础配置
|
||||
*/
|
||||
export type ProviderConfig<T extends StringKeys<AppProviderSettingsMap> = StringKeys<AppProviderSettingsMap>> = Omit<
|
||||
AppRuntimeConfig<T>,
|
||||
'plugins' | 'provider'
|
||||
> & {
|
||||
/**
|
||||
* API endpoint path extracted from baseURL
|
||||
* Used for identifying image generation endpoints and other special cases
|
||||
* @example 'chat/completions', 'images/generations', 'predict'
|
||||
*/
|
||||
endpoint?: string
|
||||
}
|
||||
|
||||
export type { AppProviderId, AppProviderSettingsMap, AppRuntimeConfig } from './merged'
|
||||
export { appProviderIds, getAllProviderIds, isRegisteredProviderId } from './merged'
|
||||
/**
|
||||
* Model capability flags computed from model properties and assistant settings.
|
||||
* Used by provider-specific option builders to decide which parameters to include.
|
||||
*/
|
||||
export interface ProviderCapabilities {
|
||||
/**
|
||||
* Whether reasoning/thinking parameters should be sent to the provider.
|
||||
*
|
||||
* True when the model supports reasoning control (thinking token or reasoning effort)
|
||||
* AND the user has configured `reasoning_effort` (not `undefined`),
|
||||
* or when the model is a fixed reasoning model (e.g. DeepSeek R1).
|
||||
*
|
||||
* Note: This can be `true` even when `reasoning_effort` is `'none'` — in that case,
|
||||
* providers should explicitly disable thinking (e.g. Ollama sets `think: false`).
|
||||
*/
|
||||
enableReasoning: boolean
|
||||
|
||||
/**
|
||||
* Whether provider-native web search should be enabled.
|
||||
* True when no external search provider is configured AND the model supports built-in web search.
|
||||
*/
|
||||
enableWebSearch: boolean
|
||||
|
||||
/**
|
||||
* Whether the model should generate images inline.
|
||||
* True when the model supports image generation AND the assistant has it enabled.
|
||||
*/
|
||||
enableGenerateImage: boolean
|
||||
|
||||
/**
|
||||
* Whether provider-native URL context should be enabled.
|
||||
* True when the assistant has it enabled, the provider supports it,
|
||||
* and the model is compatible (currently Gemini or Anthropic).
|
||||
*/
|
||||
enableUrlContext: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of completions operation
|
||||
* Simple interface with getText method to retrieve the generated text
|
||||
*/
|
||||
export type CompletionsResult = {
|
||||
getText: () => string
|
||||
usage?: any
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
/**
|
||||
* Application-Level Provider Type Merge Point
|
||||
*/
|
||||
|
||||
import type { RuntimeConfig } from '@cherrystudio/ai-core/core'
|
||||
import type { ModelConfig } from '@cherrystudio/ai-core/core/models/types'
|
||||
import type { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime'
|
||||
import type {
|
||||
ExtensionConfigToIdResolutionMap,
|
||||
ExtensionToSettingsMap,
|
||||
ExtractProviderIds,
|
||||
ProviderExtensionConfig,
|
||||
StringKeys,
|
||||
UnionToIntersection
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { coreExtensions } from '@cherrystudio/ai-core/provider'
|
||||
|
||||
import { extensions } from '../provider/extensions'
|
||||
|
||||
/**
|
||||
* All provider extensions merged into one array
|
||||
*/
|
||||
const allExtensions = [...coreExtensions, ...extensions] as const
|
||||
|
||||
type AllExtensionConfigs = (typeof allExtensions)[number]['config']
|
||||
|
||||
// ==================== Unified Application Types ====================
|
||||
|
||||
/**
|
||||
* Complete Application Provider ID Type
|
||||
*/
|
||||
type KnownAppProviderId = ExtractProviderIds<AllExtensionConfigs>
|
||||
export type AppProviderId = KnownAppProviderId | (string & {})
|
||||
|
||||
/**
|
||||
* Application Provider Settings Map
|
||||
* 使用 UnionToIntersection 将所有 extension 的 settings map 合并为单一对象类型
|
||||
*/
|
||||
export type AppProviderSettingsMap = UnionToIntersection<ExtensionToSettingsMap<(typeof allExtensions)[number]>>
|
||||
// ==================== Runtime Utilities ====================
|
||||
|
||||
/**
|
||||
* Check if a provider ID belongs to the registered extensions
|
||||
*/
|
||||
export function isRegisteredProviderId(id: string): boolean {
|
||||
return allExtensions.some((ext) => ext.hasProviderId(id))
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered provider IDs (for debugging/logging)
|
||||
*/
|
||||
export function getAllProviderIds(): string[] {
|
||||
return allExtensions.flatMap((ext) => ext.getProviderIds())
|
||||
}
|
||||
|
||||
type ProviderIdsMap = UnionToIntersection<ExtensionConfigToIdResolutionMap<AllExtensionConfigs>>
|
||||
|
||||
/**
|
||||
* 应用层 Provider IDs 常量
|
||||
*/
|
||||
function buildAppProviderIds(): ProviderIdsMap {
|
||||
const map = {} as ProviderIdsMap
|
||||
|
||||
allExtensions.forEach((ext) => {
|
||||
const config = ext.config as ProviderExtensionConfig<any, any, KnownAppProviderId>
|
||||
const name = config.name
|
||||
;(map as Record<string, KnownAppProviderId>)[name] = name
|
||||
|
||||
if (config.aliases) {
|
||||
config.aliases.forEach((alias) => {
|
||||
;(map as Record<string, KnownAppProviderId>)[alias] = name
|
||||
})
|
||||
}
|
||||
|
||||
if (config.variants) {
|
||||
config.variants.forEach((variant) => {
|
||||
// 变体自反映射:'azure-responses' -> 'azure-responses'
|
||||
const variantId = `${name}-${variant.suffix}` as KnownAppProviderId
|
||||
;(map as Record<string, KnownAppProviderId>)[variantId] = variantId
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return map
|
||||
}
|
||||
|
||||
export const appProviderIds = buildAppProviderIds()
|
||||
|
||||
export type AppModelConfig<T extends StringKeys<AppProviderSettingsMap> = StringKeys<AppProviderSettingsMap>> =
|
||||
ModelConfig<T, AppProviderSettingsMap>
|
||||
|
||||
/**
|
||||
* 应用层运行时配置 - 支持完整的 App provider IDs 和 settings
|
||||
*/
|
||||
export type AppRuntimeConfig<T extends StringKeys<AppProviderSettingsMap> = StringKeys<AppProviderSettingsMap>> =
|
||||
RuntimeConfig<AppProviderSettingsMap, T>
|
||||
|
||||
/**
|
||||
* 应用层运行时执行器 - 支持完整的 App provider IDs 和 settings
|
||||
*/
|
||||
export type AppRuntimeExecutor = RuntimeExecutor<AppProviderSettingsMap>
|
||||
@@ -1,29 +0,0 @@
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import type { Assistant, Message } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
* AI SDK 中间件配置项(用于插件构建)
|
||||
*
|
||||
* 注意:provider 和 model 不在此接口中。
|
||||
* 它们是 AiProvider 的固有属性(构造时确定),
|
||||
* 由 AiProvider 内部注入到 buildPlugins,避免调用方遗漏。
|
||||
*/
|
||||
export interface AiSdkMiddlewareConfig {
|
||||
streamOutput: boolean
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
assistant?: Assistant
|
||||
enableReasoning: boolean
|
||||
isPromptToolUse: boolean
|
||||
isSupportedToolUse: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
enableUrlContext: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
uiMessages?: Message[]
|
||||
webSearchPluginConfig?: WebSearchPluginConfig
|
||||
urlContextConfig?: Record<string, any>
|
||||
knowledgeRecognition?: 'off' | 'on'
|
||||
mcpMode?: string
|
||||
}
|
||||
@@ -1,656 +0,0 @@
|
||||
/**
|
||||
* extractAiSdkStandardParams Unit Tests
|
||||
* Tests for extracting AI SDK standard parameters from custom parameters
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { extractAiSdkStandardParams } from '../options'
|
||||
|
||||
// Mock logger to prevent errors
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
info: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock settings store
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: (state = { settings: {} }) => state
|
||||
}))
|
||||
|
||||
// Mock hooks to prevent uuid errors
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn(() => ({}))
|
||||
}))
|
||||
|
||||
// Mock uuid to prevent errors
|
||||
vi.mock('uuid', () => ({
|
||||
v4: vi.fn(() => 'test-uuid')
|
||||
}))
|
||||
|
||||
// Mock AssistantService to prevent uuid errors
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getDefaultAssistant: vi.fn(() => ({
|
||||
id: 'test-assistant',
|
||||
name: 'Test Assistant',
|
||||
settings: {}
|
||||
})),
|
||||
getDefaultTopic: vi.fn(() => ({
|
||||
id: 'test-topic',
|
||||
assistantId: 'test-assistant',
|
||||
createdAt: new Date().toISOString()
|
||||
}))
|
||||
}))
|
||||
|
||||
// Mock provider service
|
||||
vi.mock('@renderer/services/ProviderService', () => ({
|
||||
getProviderById: vi.fn(() => ({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider'
|
||||
}))
|
||||
}))
|
||||
|
||||
// Mock config modules
|
||||
vi.mock('@renderer/config/models', async (importOriginal) => {
|
||||
const actual: any = await importOriginal()
|
||||
return {
|
||||
...actual,
|
||||
isOpenAIModel: vi.fn(() => false),
|
||||
isQwenMTModel: vi.fn(() => false),
|
||||
isSupportFlexServiceTierModel: vi.fn(() => false),
|
||||
isSupportVerbosityModel: vi.fn(() => false),
|
||||
getModelSupportedVerbosity: vi.fn(() => [])
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@renderer/config/translate', () => ({
|
||||
mapLanguageToQwenMTModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', () => ({
|
||||
isSupportServiceTierProvider: vi.fn(() => false),
|
||||
isSupportVerbosityProvider: vi.fn(() => false)
|
||||
}))
|
||||
|
||||
describe('extractAiSdkStandardParams', () => {
|
||||
describe('Positive cases - Standard parameters extraction', () => {
|
||||
it('should extract all AI SDK standard parameters', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 1000,
|
||||
temperature: 0.7,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
presencePenalty: 0.5,
|
||||
frequencyPenalty: 0.3,
|
||||
stopSequences: ['STOP', 'END'],
|
||||
seed: 42
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 1000,
|
||||
temperature: 0.7,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
presencePenalty: 0.5,
|
||||
frequencyPenalty: 0.3,
|
||||
stopSequences: ['STOP', 'END'],
|
||||
seed: 42
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract single standard parameter', () => {
|
||||
const customParams = {
|
||||
temperature: 0.8
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.8
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract topK parameter', () => {
|
||||
const customParams = {
|
||||
topK: 50
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topK: 50
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract frequencyPenalty parameter', () => {
|
||||
const customParams = {
|
||||
frequencyPenalty: 0.6
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
frequencyPenalty: 0.6
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract presencePenalty parameter', () => {
|
||||
const customParams = {
|
||||
presencePenalty: 0.4
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
presencePenalty: 0.4
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract stopSequences parameter', () => {
|
||||
const customParams = {
|
||||
stopSequences: ['HALT', 'TERMINATE']
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
stopSequences: ['HALT', 'TERMINATE']
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract seed parameter', () => {
|
||||
const customParams = {
|
||||
seed: 12345
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
seed: 12345
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract maxOutputTokens parameter', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 2048
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 2048
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract topP parameter', () => {
|
||||
const customParams = {
|
||||
topP: 0.95
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topP: 0.95
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Negative cases - Provider-specific parameters', () => {
|
||||
it('should place all non-standard parameters in providerParams', () => {
|
||||
const customParams = {
|
||||
customParam: 'value',
|
||||
anotherParam: 123,
|
||||
thirdParam: true
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customParam: 'value',
|
||||
anotherParam: 123,
|
||||
thirdParam: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should place single provider-specific parameter in providerParams', () => {
|
||||
const customParams = {
|
||||
reasoningEffort: 'high'
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
reasoningEffort: 'high'
|
||||
})
|
||||
})
|
||||
|
||||
it('should place model-specific parameter in providerParams', () => {
|
||||
const customParams = {
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 }
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 }
|
||||
})
|
||||
})
|
||||
|
||||
it('should place serviceTier in providerParams', () => {
|
||||
const customParams = {
|
||||
serviceTier: 'auto'
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
serviceTier: 'auto'
|
||||
})
|
||||
})
|
||||
|
||||
it('should place textVerbosity in providerParams', () => {
|
||||
const customParams = {
|
||||
textVerbosity: 'high'
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
textVerbosity: 'high'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Mixed parameters', () => {
|
||||
it('should correctly separate mixed standard and provider-specific parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
topK: 40,
|
||||
customParam: 'custom_value',
|
||||
reasoningEffort: 'medium',
|
||||
frequencyPenalty: 0.5,
|
||||
seed: 999
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40,
|
||||
frequencyPenalty: 0.5,
|
||||
seed: 999
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customParam: 'custom_value',
|
||||
reasoningEffort: 'medium'
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle complex mixed parameters with nested objects', () => {
|
||||
const customParams = {
|
||||
topP: 0.9,
|
||||
presencePenalty: 0.3,
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 },
|
||||
stopSequences: ['STOP'],
|
||||
serviceTier: 'auto',
|
||||
maxOutputTokens: 4096
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topP: 0.9,
|
||||
presencePenalty: 0.3,
|
||||
stopSequences: ['STOP'],
|
||||
maxOutputTokens: 4096
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 },
|
||||
serviceTier: 'auto'
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle all standard params with some provider params', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 2000,
|
||||
temperature: 0.8,
|
||||
topP: 0.95,
|
||||
topK: 50,
|
||||
presencePenalty: 0.6,
|
||||
frequencyPenalty: 0.4,
|
||||
stopSequences: ['END', 'DONE'],
|
||||
seed: 777,
|
||||
customApiParam: 'value',
|
||||
anotherCustomParam: 123
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 2000,
|
||||
temperature: 0.8,
|
||||
topP: 0.95,
|
||||
topK: 50,
|
||||
presencePenalty: 0.6,
|
||||
frequencyPenalty: 0.4,
|
||||
stopSequences: ['END', 'DONE'],
|
||||
seed: 777
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customApiParam: 'value',
|
||||
anotherCustomParam: 123
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle empty object', () => {
|
||||
const customParams = {}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle zero values for numeric parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0,
|
||||
topK: 0,
|
||||
seed: 0
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0,
|
||||
topK: 0,
|
||||
seed: 0
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle negative values for numeric parameters', () => {
|
||||
const customParams = {
|
||||
presencePenalty: -0.5,
|
||||
frequencyPenalty: -0.3,
|
||||
seed: -1
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
presencePenalty: -0.5,
|
||||
frequencyPenalty: -0.3,
|
||||
seed: -1
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle empty arrays for stopSequences', () => {
|
||||
const customParams = {
|
||||
stopSequences: []
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
stopSequences: []
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle null values in mixed parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
customNull: null,
|
||||
topK: 40
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customNull: null
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle undefined values in mixed parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
customUndefined: undefined,
|
||||
topK: 40
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customUndefined: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle boolean values for standard parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
customBoolean: false,
|
||||
topK: 40
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customBoolean: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle very large numeric values', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 999999,
|
||||
seed: 2147483647,
|
||||
topK: 10000
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 999999,
|
||||
seed: 2147483647,
|
||||
topK: 10000
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle decimal values with high precision', () => {
|
||||
const customParams = {
|
||||
temperature: 0.123456789,
|
||||
topP: 0.987654321,
|
||||
presencePenalty: 0.111111111
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.123456789,
|
||||
topP: 0.987654321,
|
||||
presencePenalty: 0.111111111
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Case sensitivity', () => {
|
||||
it('should NOT extract parameters with incorrect case - uppercase first letter', () => {
|
||||
const customParams = {
|
||||
Temperature: 0.7,
|
||||
TopK: 40,
|
||||
FrequencyPenalty: 0.5
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
Temperature: 0.7,
|
||||
TopK: 40,
|
||||
FrequencyPenalty: 0.5
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT extract parameters with incorrect case - all uppercase', () => {
|
||||
const customParams = {
|
||||
TEMPERATURE: 0.7,
|
||||
TOPK: 40,
|
||||
SEED: 42
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
TEMPERATURE: 0.7,
|
||||
TOPK: 40,
|
||||
SEED: 42
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT extract parameters with incorrect case - all lowercase', () => {
|
||||
const customParams = {
|
||||
maxoutputtokens: 1000,
|
||||
frequencypenalty: 0.5,
|
||||
stopsequences: ['STOP']
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
maxoutputtokens: 1000,
|
||||
frequencypenalty: 0.5,
|
||||
stopsequences: ['STOP']
|
||||
})
|
||||
})
|
||||
|
||||
it('should correctly extract exact case match while rejecting incorrect case', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
Temperature: 0.8,
|
||||
TEMPERATURE: 0.9,
|
||||
topK: 40,
|
||||
TopK: 50
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
Temperature: 0.8,
|
||||
TEMPERATURE: 0.9,
|
||||
TopK: 50
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Parameter name variations', () => {
|
||||
it('should NOT extract similar but incorrect parameter names', () => {
|
||||
const customParams = {
|
||||
temp: 0.7, // should not match temperature
|
||||
top_k: 40, // should not match topK
|
||||
max_tokens: 1000, // should not match maxOutputTokens
|
||||
freq_penalty: 0.5 // should not match frequencyPenalty
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
temp: 0.7,
|
||||
top_k: 40,
|
||||
max_tokens: 1000,
|
||||
freq_penalty: 0.5
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT extract snake_case versions of standard parameters', () => {
|
||||
const customParams = {
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
presence_penalty: 0.5,
|
||||
frequency_penalty: 0.3,
|
||||
stop_sequences: ['STOP'],
|
||||
max_output_tokens: 1000
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
presence_penalty: 0.5,
|
||||
frequency_penalty: 0.3,
|
||||
stop_sequences: ['STOP'],
|
||||
max_output_tokens: 1000
|
||||
})
|
||||
})
|
||||
|
||||
it('should extract exact camelCase parameters only', () => {
|
||||
const customParams = {
|
||||
topK: 40, // correct
|
||||
top_k: 50, // incorrect
|
||||
topP: 0.9, // correct
|
||||
top_p: 0.8, // incorrect
|
||||
frequencyPenalty: 0.5, // correct
|
||||
frequency_penalty: 0.4 // incorrect
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topK: 40,
|
||||
topP: 0.9,
|
||||
frequencyPenalty: 0.5
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
top_k: 50,
|
||||
top_p: 0.8,
|
||||
frequency_penalty: 0.4
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,28 +0,0 @@
|
||||
/**
|
||||
* image.ts Unit Tests
|
||||
* Tests for Gemini image generation utilities
|
||||
*/
|
||||
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { buildGeminiGenerateImageParams } from '../image'
|
||||
|
||||
describe('image utils', () => {
|
||||
describe('buildGeminiGenerateImageParams', () => {
|
||||
it('should return correct response modalities', () => {
|
||||
const result = buildGeminiGenerateImageParams()
|
||||
|
||||
expect(result).toEqual({
|
||||
responseModalities: ['TEXT', 'IMAGE']
|
||||
})
|
||||
})
|
||||
|
||||
it('should return an object with responseModalities property', () => {
|
||||
const result = buildGeminiGenerateImageParams()
|
||||
|
||||
expect(result).toHaveProperty('responseModalities')
|
||||
expect(Array.isArray(result.responseModalities)).toBe(true)
|
||||
expect(result.responseModalities).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,573 +0,0 @@
|
||||
/**
|
||||
* mcp.ts Unit Tests
|
||||
* Tests for MCP tools configuration and conversion utilities
|
||||
*/
|
||||
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import type { Tool } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { convertMcpToolsToAiSdkTools, hasMultimodalContent, mcpResultToTextSummary, setupToolsConfig } from '../mcp'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
info: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/mcp-tools', () => ({
|
||||
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
|
||||
isToolAutoApproved: vi.fn(() => false),
|
||||
callMCPTool: vi.fn(async () => ({
|
||||
content: [{ type: 'text', text: 'Tool executed successfully' }],
|
||||
isError: false
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/userConfirmation', () => ({
|
||||
requestToolConfirmation: vi.fn(async () => true),
|
||||
sendToolApprovalNotification: vi.fn(),
|
||||
setToolIdToNameMapping: vi.fn(),
|
||||
confirmSameNameTools: vi.fn()
|
||||
}))
|
||||
|
||||
describe('mcp utils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('setupToolsConfig', () => {
|
||||
it('should return undefined when no MCP tools provided', () => {
|
||||
const result = setupToolsConfig()
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined when empty MCP tools array provided', () => {
|
||||
const result = setupToolsConfig([])
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should convert MCP tools to AI SDK tools format', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'test-tool-1',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'test-tool',
|
||||
description: 'A test tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string' }
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = setupToolsConfig(mcpTools)
|
||||
|
||||
expect(result).not.toBeUndefined()
|
||||
// Tools are now keyed by id (which includes serverId suffix) for uniqueness
|
||||
expect(Object.keys(result!)).toEqual(['test-tool-1'])
|
||||
expect(result!['test-tool-1']).toHaveProperty('description')
|
||||
expect(result!['test-tool-1']).toHaveProperty('inputSchema')
|
||||
expect(result!['test-tool-1']).toHaveProperty('execute')
|
||||
})
|
||||
|
||||
it('should handle multiple MCP tools', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'tool1-id',
|
||||
serverId: 'server1',
|
||||
serverName: 'server1',
|
||||
name: 'tool1',
|
||||
description: 'First tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'tool2-id',
|
||||
serverId: 'server2',
|
||||
serverName: 'server2',
|
||||
name: 'tool2',
|
||||
description: 'Second tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = setupToolsConfig(mcpTools)
|
||||
|
||||
expect(result).not.toBeUndefined()
|
||||
expect(Object.keys(result!)).toHaveLength(2)
|
||||
// Tools are keyed by id for uniqueness
|
||||
expect(Object.keys(result!)).toEqual(['tool1-id', 'tool2-id'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertMcpToolsToAiSdkTools', () => {
|
||||
it('should convert single MCP tool to AI SDK tool', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'get-weather-id',
|
||||
serverId: 'weather-server',
|
||||
serverName: 'weather-server',
|
||||
name: 'get-weather',
|
||||
description: 'Get weather information',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string' }
|
||||
},
|
||||
required: ['location']
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
// Tools are keyed by id for uniqueness when multiple server instances exist
|
||||
expect(Object.keys(result)).toEqual(['get-weather-id'])
|
||||
|
||||
const tool = result['get-weather-id'] as Tool
|
||||
expect(tool.description).toBe('Get weather information')
|
||||
expect(tool.inputSchema).toBeDefined()
|
||||
expect(typeof tool.execute).toBe('function')
|
||||
})
|
||||
|
||||
it('should handle tool without description', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'no-desc-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'no-desc-tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
expect(Object.keys(result)).toEqual(['no-desc-tool-id'])
|
||||
const tool = result['no-desc-tool-id'] as Tool
|
||||
expect(tool.description).toBe('Tool from test-server')
|
||||
})
|
||||
|
||||
it('should convert empty tools array', () => {
|
||||
const result = convertMcpToolsToAiSdkTools([])
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should handle complex input schemas', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'complex-tool-id',
|
||||
serverId: 'server',
|
||||
serverName: 'server',
|
||||
name: 'complex-tool',
|
||||
description: 'Tool with complex schema',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
age: { type: 'number' },
|
||||
tags: {
|
||||
type: 'array',
|
||||
items: { type: 'string' }
|
||||
},
|
||||
metadata: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
key: { type: 'string' }
|
||||
}
|
||||
}
|
||||
},
|
||||
required: ['name']
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
expect(Object.keys(result)).toEqual(['complex-tool-id'])
|
||||
const tool = result['complex-tool-id'] as Tool
|
||||
expect(tool.inputSchema).toBeDefined()
|
||||
expect(typeof tool.execute).toBe('function')
|
||||
})
|
||||
|
||||
it('should preserve tool id with special characters', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'special-tool-id',
|
||||
serverId: 'server',
|
||||
serverName: 'server',
|
||||
name: 'tool_with-special.chars',
|
||||
description: 'Special chars tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
// Tools are keyed by id for uniqueness
|
||||
expect(Object.keys(result)).toEqual(['special-tool-id'])
|
||||
})
|
||||
|
||||
it('should handle multiple tools with different schemas', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'string-tool-id',
|
||||
serverId: 'server1',
|
||||
serverName: 'server1',
|
||||
name: 'string-tool',
|
||||
description: 'String tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' }
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'number-tool-id',
|
||||
serverId: 'server2',
|
||||
serverName: 'server2',
|
||||
name: 'number-tool',
|
||||
description: 'Number tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
count: { type: 'number' }
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'boolean-tool-id',
|
||||
serverId: 'server3',
|
||||
serverName: 'server3',
|
||||
name: 'boolean-tool',
|
||||
description: 'Boolean tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
enabled: { type: 'boolean' }
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
// Tools are keyed by id for uniqueness
|
||||
expect(Object.keys(result).sort()).toEqual(['boolean-tool-id', 'number-tool-id', 'string-tool-id'])
|
||||
expect(result['string-tool-id']).toBeDefined()
|
||||
expect(result['number-tool-id']).toBeDefined()
|
||||
expect(result['boolean-tool-id']).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasMultimodalContent', () => {
|
||||
it('should return false for pure text content', () => {
|
||||
expect(hasMultimodalContent({ content: [{ type: 'text', text: 'hello' }] })).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true for image content', () => {
|
||||
expect(hasMultimodalContent({ content: [{ type: 'image', data: 'base64...', mimeType: 'image/png' }] })).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('should return true for audio content', () => {
|
||||
expect(hasMultimodalContent({ content: [{ type: 'audio', data: 'base64...', mimeType: 'audio/mp3' }] })).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('should return true for resource with blob', () => {
|
||||
expect(
|
||||
hasMultimodalContent({
|
||||
content: [{ type: 'resource', resource: { blob: 'base64...', mimeType: 'image/png', uri: 'file://a.png' } }]
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for resource without blob', () => {
|
||||
expect(
|
||||
hasMultimodalContent({
|
||||
content: [{ type: 'resource', resource: { text: 'plain text', uri: 'file://a.txt' } }]
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true for mixed content with at least one multimodal item', () => {
|
||||
expect(
|
||||
hasMultimodalContent({
|
||||
content: [
|
||||
{ type: 'text', text: 'hello' },
|
||||
{ type: 'image', data: 'base64...', mimeType: 'image/png' }
|
||||
]
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for empty content array', () => {
|
||||
expect(hasMultimodalContent({ content: [] })).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for null/undefined result', () => {
|
||||
expect(hasMultimodalContent(null as any)).toBe(false)
|
||||
expect(hasMultimodalContent(undefined as any)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when content is not an array', () => {
|
||||
expect(hasMultimodalContent({ content: 'not-array' } as any)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('mcpResultToTextSummary', () => {
|
||||
it('should extract text from text content', () => {
|
||||
expect(mcpResultToTextSummary({ content: [{ type: 'text', text: 'hello world' }] })).toBe('hello world')
|
||||
})
|
||||
|
||||
it('should replace image with placeholder', () => {
|
||||
expect(mcpResultToTextSummary({ content: [{ type: 'image', data: 'base64...', mimeType: 'image/jpeg' }] })).toBe(
|
||||
'[Image: image/jpeg, delivered to user]'
|
||||
)
|
||||
})
|
||||
|
||||
it('should use default mimeType for image without mimeType', () => {
|
||||
expect(mcpResultToTextSummary({ content: [{ type: 'image', data: 'base64...' }] })).toBe(
|
||||
'[Image: image/png, delivered to user]'
|
||||
)
|
||||
})
|
||||
|
||||
it('should replace audio with placeholder', () => {
|
||||
expect(mcpResultToTextSummary({ content: [{ type: 'audio', data: 'base64...', mimeType: 'audio/wav' }] })).toBe(
|
||||
'[Audio: audio/wav, delivered to user]'
|
||||
)
|
||||
})
|
||||
|
||||
it('should use default mimeType for audio without mimeType', () => {
|
||||
expect(mcpResultToTextSummary({ content: [{ type: 'audio', data: 'base64...' }] })).toBe(
|
||||
'[Audio: audio/mp3, delivered to user]'
|
||||
)
|
||||
})
|
||||
|
||||
it('should replace resource with blob with placeholder', () => {
|
||||
expect(
|
||||
mcpResultToTextSummary({
|
||||
content: [
|
||||
{ type: 'resource', resource: { blob: 'base64...', mimeType: 'application/pdf', uri: 'file://doc.pdf' } }
|
||||
]
|
||||
})
|
||||
).toBe('[Resource: application/pdf, uri=file://doc.pdf, delivered to user]')
|
||||
})
|
||||
|
||||
it('should use resource text when no blob', () => {
|
||||
expect(
|
||||
mcpResultToTextSummary({
|
||||
content: [{ type: 'resource', resource: { text: 'resource content', uri: 'file://a.txt' } }]
|
||||
})
|
||||
).toBe('resource content')
|
||||
})
|
||||
|
||||
it('should JSON.stringify unknown content types', () => {
|
||||
const item = { type: 'unknown' as any, foo: 'bar' }
|
||||
expect(mcpResultToTextSummary({ content: [item] })).toBe(JSON.stringify(item))
|
||||
})
|
||||
|
||||
it('should join multiple content parts with newline', () => {
|
||||
const result = mcpResultToTextSummary({
|
||||
content: [
|
||||
{ type: 'text', text: 'Description' },
|
||||
{ type: 'image', data: 'base64...', mimeType: 'image/png' }
|
||||
]
|
||||
})
|
||||
expect(result).toBe('Description\n[Image: image/png, delivered to user]')
|
||||
})
|
||||
|
||||
it('should JSON.stringify result when content is missing', () => {
|
||||
expect(mcpResultToTextSummary(null as any)).toBe('null')
|
||||
expect(mcpResultToTextSummary({} as any)).toBe('{}')
|
||||
})
|
||||
|
||||
it('should handle empty text gracefully', () => {
|
||||
expect(mcpResultToTextSummary({ content: [{ type: 'text' }] })).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('tool execution', () => {
|
||||
it('should execute tool with user confirmation', async () => {
|
||||
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
|
||||
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
|
||||
vi.mocked(callMCPTool).mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
isError: false
|
||||
})
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'test-exec-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'test-exec-tool',
|
||||
description: 'Test execution tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['test-exec-tool-id'] as Tool
|
||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
|
||||
|
||||
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||
expect(callMCPTool).toHaveBeenCalled()
|
||||
expect(result).toEqual({
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
isError: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle user cancellation', async () => {
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||
|
||||
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'cancelled-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'cancelled-tool',
|
||||
description: 'Tool to cancel',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['cancelled-tool-id'] as Tool
|
||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
|
||||
|
||||
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||
expect(callMCPTool).not.toHaveBeenCalled()
|
||||
expect(result).toEqual({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'User declined to execute tool "cancelled-tool".'
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tool execution error', async () => {
|
||||
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
|
||||
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
|
||||
vi.mocked(callMCPTool).mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Error occurred' }],
|
||||
isError: true
|
||||
})
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'error-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'error-tool',
|
||||
description: 'Tool that errors',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['error-tool-id'] as Tool
|
||||
|
||||
await expect(
|
||||
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
|
||||
).rejects.toEqual({
|
||||
content: [{ type: 'text', text: 'Error occurred' }],
|
||||
isError: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should auto-approve when enabled', async () => {
|
||||
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
|
||||
vi.mocked(isToolAutoApproved).mockReturnValue(true)
|
||||
vi.mocked(callMCPTool).mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Auto-approved success' }],
|
||||
isError: false
|
||||
})
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'auto-approve-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'auto-approve-tool',
|
||||
description: 'Auto-approved tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['auto-approve-tool-id'] as Tool
|
||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
|
||||
|
||||
expect(requestToolConfirmation).not.toHaveBeenCalled()
|
||||
expect(callMCPTool).toHaveBeenCalled()
|
||||
expect(result).toEqual({
|
||||
content: [{ type: 'text', text: 'Auto-approved success' }],
|
||||
isError: false
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,288 +0,0 @@
|
||||
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getReasoningEffort } from '../reasoning'
|
||||
|
||||
// Mock logger
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
warn: vi.fn(),
|
||||
info: vi.fn(),
|
||||
error: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: {},
|
||||
settingsSlice: {
|
||||
name: 'settings',
|
||||
reducer: vi.fn(),
|
||||
actions: {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/assistants', () => {
|
||||
const mockAssistantsSlice = {
|
||||
name: 'assistants',
|
||||
reducer: vi.fn((state = { entities: {}, ids: [] }) => state),
|
||||
actions: {
|
||||
updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' }))
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
default: mockAssistantsSlice.reducer,
|
||||
updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' })),
|
||||
assistantsSlice: mockAssistantsSlice
|
||||
}
|
||||
})
|
||||
|
||||
// Mock provider service
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: (model: Model) => ({
|
||||
id: model.provider,
|
||||
name: 'Poe',
|
||||
type: 'openai'
|
||||
}),
|
||||
getAssistantSettings: (assistant: Assistant) => assistant.settings || {}
|
||||
}))
|
||||
|
||||
describe('Poe Provider Reasoning Support', () => {
|
||||
const createPoeModel = (id: string): Model => ({
|
||||
id,
|
||||
name: id,
|
||||
provider: SystemProviderIds.poe,
|
||||
group: 'poe'
|
||||
})
|
||||
|
||||
const createAssistant = (reasoning_effort?: ReasoningEffortOption, maxTokens?: number): Assistant => ({
|
||||
id: 'test-assistant',
|
||||
name: 'Test Assistant',
|
||||
emoji: '🤖',
|
||||
prompt: '',
|
||||
topics: [],
|
||||
messages: [],
|
||||
type: 'assistant',
|
||||
regularPhrases: [],
|
||||
settings: {
|
||||
reasoning_effort,
|
||||
maxTokens
|
||||
}
|
||||
})
|
||||
|
||||
describe('GPT-5 Series Models', () => {
|
||||
it('should return reasoning_effort in extra_body for GPT-5 model with low effort', () => {
|
||||
const model = createPoeModel('gpt-5')
|
||||
const assistant = createAssistant('low')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
reasoning_effort: 'low'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return reasoning_effort in extra_body for GPT-5 model with medium effort', () => {
|
||||
const model = createPoeModel('gpt-5')
|
||||
const assistant = createAssistant('medium')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return reasoning_effort in extra_body for GPT-5 model with high effort', () => {
|
||||
const model = createPoeModel('gpt-5')
|
||||
const assistant = createAssistant('high')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
reasoning_effort: 'high'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should convert auto to medium for GPT-5 model in extra_body', () => {
|
||||
const model = createPoeModel('gpt-5')
|
||||
const assistant = createAssistant('auto')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return reasoning_effort in extra_body for GPT-5.1 model', () => {
|
||||
const model = createPoeModel('gpt-5.1')
|
||||
const assistant = createAssistant('medium')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Claude Models', () => {
|
||||
it('should return thinking_budget in extra_body for Claude 3.7 Sonnet', () => {
|
||||
const model = createPoeModel('claude-3.7-sonnet')
|
||||
const assistant = createAssistant('medium', 4096)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toHaveProperty('extra_body')
|
||||
expect(result.extra_body).toHaveProperty('thinking_budget')
|
||||
expect(typeof result.extra_body?.thinking_budget).toBe('number')
|
||||
expect(result.extra_body?.thinking_budget).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should return thinking_budget in extra_body for Claude Sonnet 4', () => {
|
||||
const model = createPoeModel('claude-sonnet-4')
|
||||
const assistant = createAssistant('high', 8192)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toHaveProperty('extra_body')
|
||||
expect(result.extra_body).toHaveProperty('thinking_budget')
|
||||
expect(typeof result.extra_body?.thinking_budget).toBe('number')
|
||||
})
|
||||
|
||||
it('should calculate thinking_budget based on effort ratio and maxTokens', () => {
|
||||
const model = createPoeModel('claude-3.7-sonnet')
|
||||
const assistant = createAssistant('low', 4096)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Gemini Models', () => {
|
||||
it('should return thinking_budget in extra_body for Gemini 2.5 Flash', () => {
|
||||
const model = createPoeModel('gemini-2.5-flash')
|
||||
const assistant = createAssistant('medium')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toHaveProperty('extra_body')
|
||||
expect(result.extra_body).toHaveProperty('thinking_budget')
|
||||
expect(typeof result.extra_body?.thinking_budget).toBe('number')
|
||||
})
|
||||
|
||||
it('should return thinking_budget in extra_body for Gemini 2.5 Pro', () => {
|
||||
const model = createPoeModel('gemini-2.5-pro')
|
||||
const assistant = createAssistant('high')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toHaveProperty('extra_body')
|
||||
expect(result.extra_body).toHaveProperty('thinking_budget')
|
||||
})
|
||||
|
||||
it('should use -1 for auto effort', () => {
|
||||
const model = createPoeModel('gemini-2.5-flash')
|
||||
const assistant = createAssistant('auto')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result.extra_body?.thinking_budget).toBe(-1)
|
||||
})
|
||||
|
||||
it('should calculate thinking_budget for non-auto effort', () => {
|
||||
const model = createPoeModel('gemini-2.5-flash')
|
||||
const assistant = createAssistant('low')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(typeof result.extra_body?.thinking_budget).toBe('number')
|
||||
})
|
||||
})
|
||||
|
||||
describe('No Reasoning Effort', () => {
|
||||
it('should return empty object when reasoning_effort is not set', () => {
|
||||
const model = createPoeModel('gpt-5')
|
||||
const assistant = createAssistant(undefined)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty object when reasoning_effort is "none"', () => {
|
||||
const model = createPoeModel('gpt-5')
|
||||
const assistant = createAssistant('none')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Non-Reasoning Models', () => {
|
||||
it('should return empty object for non-reasoning models', () => {
|
||||
const model = createPoeModel('gpt-4')
|
||||
const assistant = createAssistant('medium')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases: Models Without Token Limit Configuration', () => {
|
||||
it('should return empty object for Claude models without token limit configuration', () => {
|
||||
const model = createPoeModel('claude-unknown-variant')
|
||||
const assistant = createAssistant('medium', 4096)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
// Should return empty object when token limit is not found
|
||||
expect(result).toEqual({})
|
||||
expect(result.extra_body?.thinking_budget).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return empty object for unmatched Poe reasoning models', () => {
|
||||
// A hypothetical reasoning model that doesn't match GPT-5, Claude, or Gemini
|
||||
const model = createPoeModel('some-reasoning-model')
|
||||
// Make it appear as a reasoning model by giving it a name that won't match known categories
|
||||
const assistant = createAssistant('medium')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
// Should return empty object for unmatched models
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should fallback to -1 for Gemini models without token limit', () => {
|
||||
// Use a Gemini model variant that won't match any token limit pattern
|
||||
// The current regex patterns cover gemini-.*-flash.*$ and gemini-.*-pro.*$
|
||||
// so we need a model that matches isSupportedThinkingTokenGeminiModel but not THINKING_TOKEN_MAP
|
||||
const model = createPoeModel('gemini-2.5-flash')
|
||||
const assistant = createAssistant('auto')
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
// For 'auto' effort, should use -1
|
||||
expect(result.extra_body?.thinking_budget).toBe(-1)
|
||||
})
|
||||
|
||||
it('should enforce minimum 1024 token floor for Claude models', () => {
|
||||
const model = createPoeModel('claude-3.7-sonnet')
|
||||
// Use very small maxTokens to test the minimum floor
|
||||
const assistant = createAssistant('low', 100)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
|
||||
})
|
||||
|
||||
it('should handle undefined maxTokens for Claude models', () => {
|
||||
const model = createPoeModel('claude-3.7-sonnet')
|
||||
const assistant = createAssistant('medium', undefined)
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
|
||||
expect(result).toHaveProperty('extra_body')
|
||||
expect(result.extra_body).toHaveProperty('thinking_budget')
|
||||
expect(typeof result.extra_body?.thinking_budget).toBe('number')
|
||||
expect(result.extra_body?.thinking_budget).toBeGreaterThanOrEqual(1024)
|
||||
})
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,405 +0,0 @@
|
||||
/**
|
||||
* websearch.ts Unit Tests
|
||||
* Tests for web search parameters generation utilities
|
||||
*/
|
||||
|
||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
|
||||
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
|
||||
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
|
||||
}))
|
||||
|
||||
describe('websearch utils', () => {
|
||||
describe('getWebSearchParams', () => {
|
||||
it('should return enhancement params for hunyuan provider', () => {
|
||||
const model: Model = {
|
||||
id: 'hunyuan-model',
|
||||
name: 'Hunyuan Model',
|
||||
provider: 'hunyuan'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
enable_enhancement: true,
|
||||
citation: true,
|
||||
search_info: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should return search params for dashscope provider', () => {
|
||||
const model: Model = {
|
||||
id: 'qwen-model',
|
||||
name: 'Qwen Model',
|
||||
provider: 'dashscope'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return web_search_options for OpenAI web search models', () => {
|
||||
const model: Model = {
|
||||
id: 'o1-pro',
|
||||
name: 'O1 Pro',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
web_search_options: {}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return extra_body with web_search for poe provider', () => {
|
||||
const model: Model = {
|
||||
id: 'Gemini-3-Flash',
|
||||
name: 'Gemini 3 Flash',
|
||||
provider: 'poe'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
web_search: true
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return empty object for other providers', () => {
|
||||
const model: Model = {
|
||||
id: 'gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty object for custom provider', () => {
|
||||
const model: Model = {
|
||||
id: 'custom-model',
|
||||
name: 'Custom Model',
|
||||
provider: 'custom-provider'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildProviderBuiltinWebSearchConfig', () => {
|
||||
const defaultWebSearchConfig: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 50,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
describe('openai provider', () => {
|
||||
it('should return low search context size for low maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 20,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'low'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return medium search context size for medium maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 50,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return high search context size for high maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 80,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'high'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should use medium for deep research models regardless of maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 100,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const model: Model = {
|
||||
id: 'o3-mini',
|
||||
name: 'O3 Mini',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('openai-chat provider', () => {
|
||||
it('should return correct search context size', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 50,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
'openai-chat': {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle deep research models', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 100,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const model: Model = {
|
||||
id: 'o3-mini',
|
||||
name: 'O3 Mini',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
'openai-chat': {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('anthropic provider', () => {
|
||||
it('should return anthropic search options with maxUses', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
anthropic: {
|
||||
maxUses: 50,
|
||||
blockedDomains: undefined
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should include blockedDomains when excludeDomains provided', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 30,
|
||||
excludeDomains: ['example.com', 'test.com']
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
anthropic: {
|
||||
maxUses: 30,
|
||||
blockedDomains: ['example.com', 'test.com']
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should not include blockedDomains when empty', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
anthropic: {
|
||||
maxUses: 50,
|
||||
blockedDomains: undefined
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('xai provider', () => {
|
||||
it('should return xai-responses search options with enableImageUnderstanding when no excludeDomains', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
'xai-responses': {
|
||||
webSearch: { enableImageUnderstanding: true },
|
||||
xSearch: { enableImageUnderstanding: true }
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should include excludedDomains when excludeDomains provided', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 40,
|
||||
excludeDomains: ['site1.com', 'site2.com']
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('xai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
'xai-responses': {
|
||||
webSearch: {
|
||||
enableImageUnderstanding: true,
|
||||
excludedDomains: ['site1.com', 'site2.com']
|
||||
},
|
||||
xSearch: { enableImageUnderstanding: true }
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should limit excluded domains to 5', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 40,
|
||||
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('xai', config)
|
||||
|
||||
expect(result?.['xai-responses']?.webSearch?.excludedDomains).toHaveLength(5)
|
||||
})
|
||||
})
|
||||
|
||||
describe('openrouter provider', () => {
|
||||
it('should return openrouter plugins config', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
openrouter: {
|
||||
plugins: [
|
||||
{
|
||||
id: 'web',
|
||||
max_results: 50
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should respect custom maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 75,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openrouter: {
|
||||
plugins: [
|
||||
{
|
||||
id: 'web',
|
||||
max_results: 75
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('unsupported provider', () => {
|
||||
it('should return empty object for unsupported provider', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty object for google provider', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle maxResults at boundary values', () => {
|
||||
// Test boundary at 33 (low/medium)
|
||||
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
|
||||
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
|
||||
expect(result33?.openai?.searchContextSize).toBe('low')
|
||||
|
||||
// Test boundary at 34 (medium)
|
||||
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
|
||||
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
|
||||
expect(result34?.openai?.searchContextSize).toBe('medium')
|
||||
|
||||
// Test boundary at 66 (medium)
|
||||
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
|
||||
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
|
||||
expect(result66?.openai?.searchContextSize).toBe('medium')
|
||||
|
||||
// Test boundary at 67 (high)
|
||||
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
|
||||
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
|
||||
expect(result67?.openai?.searchContextSize).toBe('high')
|
||||
})
|
||||
|
||||
it('should handle zero maxResults', () => {
|
||||
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
expect(result?.openai?.searchContextSize).toBe('low')
|
||||
})
|
||||
|
||||
it('should handle very large maxResults', () => {
|
||||
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
expect(result?.openai?.searchContextSize).toBe('high')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,5 +0,0 @@
|
||||
export function buildGeminiGenerateImageParams(): Record<string, any> {
|
||||
return {
|
||||
responseModalities: ['TEXT', 'IMAGE']
|
||||
}
|
||||
}
|
||||
@@ -1,197 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import store from '@renderer/store'
|
||||
import type { MCPCallToolResponse, MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
|
||||
import {
|
||||
confirmSameNameTools,
|
||||
requestToolConfirmation,
|
||||
sendToolApprovalNotification,
|
||||
setToolIdToNameMapping
|
||||
} from '@renderer/utils/userConfirmation'
|
||||
import { type Tool, type ToolSet } from 'ai'
|
||||
import { jsonSchema, tool } from 'ai'
|
||||
import type { JSONSchema7 } from 'json-schema'
|
||||
|
||||
const logger = loggerService.withContext('MCP-utils')
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
export function setupToolsConfig(
|
||||
mcpTools?: MCPTool[],
|
||||
allowedTools?: string[]
|
||||
): Record<string, Tool<any, any>> | undefined {
|
||||
let tools: ToolSet = {}
|
||||
|
||||
if (!mcpTools?.length) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
tools = convertMcpToolsToAiSdkTools(mcpTools, allowedTools)
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 MCP 工具调用结果是否包含可能携带大体积 base64 数据的多模态内容。
|
||||
* 包括 image、audio 以及含 blob 的 resource 类型。
|
||||
*/
|
||||
export function hasMultimodalContent(result: MCPCallToolResponse): boolean {
|
||||
return (
|
||||
Array.isArray(result?.content) &&
|
||||
result.content.some(
|
||||
(item) => item.type === 'image' || item.type === 'audio' || (item.type === 'resource' && !!item.resource?.blob)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 MCP 工具调用结果转换为纯文本摘要,把图片/音频/resource blob 替换为文本占位描述,
|
||||
* 避免 base64 数据超出消息大小限制(如 kimi 的 4MB 限制)。
|
||||
*/
|
||||
export function mcpResultToTextSummary(result: MCPCallToolResponse): string {
|
||||
if (!result || !result.content || !Array.isArray(result.content)) {
|
||||
return JSON.stringify(result)
|
||||
}
|
||||
|
||||
const parts: string[] = []
|
||||
for (const item of result.content) {
|
||||
switch (item.type) {
|
||||
case 'text':
|
||||
parts.push(item.text || '')
|
||||
break
|
||||
case 'image':
|
||||
parts.push(`[Image: ${item.mimeType || 'image/png'}, delivered to user]`)
|
||||
break
|
||||
case 'audio':
|
||||
parts.push(`[Audio: ${item.mimeType || 'audio/mp3'}, delivered to user]`)
|
||||
break
|
||||
case 'resource':
|
||||
if (item.resource?.blob) {
|
||||
parts.push(
|
||||
`[Resource: ${item.resource.mimeType || 'application/octet-stream'}, uri=${item.resource.uri || 'unknown'}, delivered to user]`
|
||||
)
|
||||
} else {
|
||||
parts.push(item.resource?.text || JSON.stringify(item))
|
||||
}
|
||||
break
|
||||
default:
|
||||
parts.push(JSON.stringify(item))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join('\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 MCPTool 转换为 AI SDK 工具格式
|
||||
*/
|
||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[], allowedTools?: string[]): ToolSet {
|
||||
const tools: ToolSet = {}
|
||||
|
||||
for (const mcpTool of mcpTools) {
|
||||
// Use mcpTool.id (which includes serverId suffix) to ensure uniqueness
|
||||
// when multiple instances of the same MCP server type are configured
|
||||
tools[mcpTool.id] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params, { toolCallId }) => {
|
||||
// 检查是否启用自动批准
|
||||
const server = getMcpServerByTool(mcpTool)
|
||||
let isAutoApproveEnabled = isToolAutoApproved(mcpTool, server, allowedTools)
|
||||
|
||||
// For hub invoke/exec, resolve the underlying tool and check its server's auto-approve config
|
||||
if (
|
||||
!isAutoApproveEnabled &&
|
||||
mcpTool.serverId === 'hub' &&
|
||||
(mcpTool.name === 'invoke' || mcpTool.name === 'exec')
|
||||
) {
|
||||
const underlyingToolName = (params as Record<string, unknown>)?.name as string | undefined
|
||||
if (underlyingToolName) {
|
||||
try {
|
||||
const resolved = await window.api.mcp.resolveHubTool(underlyingToolName)
|
||||
if (resolved) {
|
||||
const underlyingServer = store.getState().mcp.servers.find((s) => s.id === resolved.serverId)
|
||||
if (underlyingServer) {
|
||||
isAutoApproveEnabled = !underlyingServer.disabledAutoApproveTools?.includes(resolved.toolName)
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
logger.warn('Failed to resolve hub tool for auto-approve check', err as Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let confirmed = true
|
||||
|
||||
if (!isAutoApproveEnabled) {
|
||||
// Register mapping so confirmSameNameTools can batch-confirm pending tools.
|
||||
// For hub invoke/exec, use the underlying tool name so tools targeting the
|
||||
// same underlying server+tool are grouped together.
|
||||
const mappingName =
|
||||
mcpTool.serverId === 'hub' && (mcpTool.name === 'invoke' || mcpTool.name === 'exec')
|
||||
? ((params as Record<string, unknown>)?.name as string) || mcpTool.name
|
||||
: mcpTool.name
|
||||
setToolIdToNameMapping(toolCallId, mappingName)
|
||||
|
||||
// Send system notification for tool approval
|
||||
sendToolApprovalNotification(mcpTool.name)
|
||||
|
||||
// 请求用户确认
|
||||
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
||||
confirmed = await requestToolConfirmation(toolCallId)
|
||||
|
||||
if (confirmed) {
|
||||
// Auto-confirm other pending tools with the same name
|
||||
confirmSameNameTools(mappingName)
|
||||
}
|
||||
}
|
||||
|
||||
if (!confirmed) {
|
||||
// 用户拒绝执行工具
|
||||
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `User declined to execute tool "${mcpTool.name}".`
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
}
|
||||
}
|
||||
|
||||
// 用户确认或自动批准,执行工具
|
||||
logger.debug(`Executing tool: ${mcpTool.name}`)
|
||||
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: mcpTool,
|
||||
arguments: params,
|
||||
status: 'pending',
|
||||
toolCallId
|
||||
}
|
||||
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
return Promise.reject(result)
|
||||
}
|
||||
// 返回工具执行结果
|
||||
return result
|
||||
},
|
||||
// 将多模态结果 (image/audio/resource blob) 转为文本摘要,避免 base64 超出消息大小限制。
|
||||
// 图片/音频已通过 IMAGE_COMPLETE chunk 展示给用户。
|
||||
// TODO: 待 AI SDK 支持 provider 感知后,可按 provider 返回 media 格式。
|
||||
toModelOutput(rawOutput: unknown) {
|
||||
// rawOutput 来自上方 execute 的 return result,类型始终为 MCPCallToolResponse
|
||||
// mcpResultToTextSummary 内部已有 null/content 校验,不会因意外输入崩溃
|
||||
const result = rawOutput as MCPCallToolResponse
|
||||
return { type: 'text' as const, value: mcpResultToTextSummary(result) }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
@@ -1,605 +0,0 @@
|
||||
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
|
||||
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import type { XaiProviderOptions } from '@ai-sdk/xai'
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
getModelSupportedVerbosity,
|
||||
isAnthropicModel,
|
||||
isGeminiModel,
|
||||
isGrokModel,
|
||||
isOpenAIModel,
|
||||
isQwenMTModel,
|
||||
isReasoningModel,
|
||||
isSupportFlexServiceTierModel,
|
||||
isSupportVerbosityModel
|
||||
} from '@renderer/config/models'
|
||||
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getProviderById } from '@renderer/services/ProviderService'
|
||||
import {
|
||||
type Assistant,
|
||||
type GroqServiceTier,
|
||||
GroqServiceTiers,
|
||||
type GroqSystemProvider,
|
||||
isGroqServiceTier,
|
||||
isGroqSystemProvider,
|
||||
isOpenAIServiceTier,
|
||||
isTranslateAssistant,
|
||||
type Model,
|
||||
type NotGroqProvider,
|
||||
type OpenAIServiceTier,
|
||||
OpenAIServiceTiers,
|
||||
type Provider,
|
||||
type ServiceTier,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider'
|
||||
import type { JSONValue } from 'ai'
|
||||
import { t } from 'i18next'
|
||||
import { merge } from 'lodash'
|
||||
import type { OllamaProviderOptions } from 'ollama-ai-provider-v2'
|
||||
|
||||
import { addAnthropicHeaders } from '../prepareParams/header'
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import type { ProviderCapabilities } from '../types'
|
||||
import { buildGeminiGenerateImageParams } from './image'
|
||||
import {
|
||||
getAnthropicReasoningParams,
|
||||
getBedrockReasoningParams,
|
||||
getCustomParameters,
|
||||
getGeminiReasoningParams,
|
||||
getOllamaReasoningParams,
|
||||
getOpenAIReasoningParams,
|
||||
getReasoningEffort,
|
||||
getXAIReasoningParams
|
||||
} from './reasoning'
|
||||
import { getWebSearchParams } from './websearch'
|
||||
|
||||
const logger = loggerService.withContext('aiCore.utils.options')
|
||||
|
||||
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
|
||||
if (
|
||||
!isOpenAIServiceTier(serviceTier) ||
|
||||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
} else {
|
||||
return serviceTier
|
||||
}
|
||||
}
|
||||
|
||||
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
|
||||
if (
|
||||
!isGroqServiceTier(serviceTier) ||
|
||||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
} else {
|
||||
return serviceTier
|
||||
}
|
||||
}
|
||||
|
||||
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
|
||||
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
|
||||
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
|
||||
const serviceTierSetting = provider.serviceTier
|
||||
|
||||
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 处理不同供应商需要 fallback 到默认值的情况
|
||||
if (isGroqSystemProvider(provider)) {
|
||||
return toGroqServiceTier(model, serviceTierSetting)
|
||||
} else {
|
||||
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
||||
return toOpenAIServiceTier(model, serviceTierSetting)
|
||||
}
|
||||
}
|
||||
|
||||
function getVerbosity(model: Model): OpenAIVerbosity {
|
||||
if (!isSupportVerbosityModel(model) || !isSupportVerbosityProvider(getProviderById(model.provider)!)) {
|
||||
return undefined
|
||||
}
|
||||
const openAI = getStoreSetting('openAI')
|
||||
|
||||
const userVerbosity = openAI.verbosity
|
||||
|
||||
if (userVerbosity) {
|
||||
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||
// Use user's verbosity if supported, otherwise use the first supported option
|
||||
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
|
||||
return verbosity
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract AI SDK standard parameters from custom parameters
|
||||
* These parameters should be passed directly to streamText() instead of providerOptions
|
||||
*/
|
||||
export function extractAiSdkStandardParams(customParams: Record<string, any>): {
|
||||
standardParams: Partial<Record<AiSdkParam, any>>
|
||||
providerParams: Record<string, any>
|
||||
} {
|
||||
const standardParams: Partial<Record<AiSdkParam, any>> = {}
|
||||
const providerParams: Record<string, any> = {}
|
||||
|
||||
for (const [key, value] of Object.entries(customParams)) {
|
||||
if (isAiSdkParam(key)) {
|
||||
standardParams[key] = value
|
||||
} else {
|
||||
providerParams[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return { standardParams, providerParams }
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
* 返回格式:{
|
||||
* providerOptions: { 'providerId': providerOptions },
|
||||
* standardParams: { topK, frequencyPenalty, presencePenalty, stopSequences, seed }
|
||||
* }
|
||||
*
|
||||
* Custom parameters are split into two categories:
|
||||
* 1. AI SDK standard parameters (topK, frequencyPenalty, etc.) - returned separately to be passed to streamText()
|
||||
* 2. Provider-specific parameters - merged into providerOptions
|
||||
*/
|
||||
export function buildProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
actualProvider: Provider,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): {
|
||||
providerOptions: Record<string, Record<string, JSONValue>>
|
||||
standardParams: Partial<Record<AiSdkParam, any>>
|
||||
} {
|
||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities, rawProviderId })
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
const serviceTier = getServiceTier(model, actualProvider)
|
||||
const textVerbosity = getVerbosity(model)
|
||||
|
||||
// 根据 provider ID 构建特定选项
|
||||
switch (rawProviderId) {
|
||||
case 'openai':
|
||||
case 'openai-chat':
|
||||
case 'azure':
|
||||
case 'azure-responses':
|
||||
case 'huggingface':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
|
||||
break
|
||||
case 'anthropic':
|
||||
case 'azure-anthropic':
|
||||
case 'google-vertex-anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'google':
|
||||
case 'google-vertex':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'xai':
|
||||
case 'xai-responses':
|
||||
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'bedrock':
|
||||
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'cherryin':
|
||||
providerSpecificOptions = buildCherryInProviderOptions(
|
||||
assistant,
|
||||
model,
|
||||
capabilities,
|
||||
actualProvider,
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
)
|
||||
break
|
||||
case SystemProviderIds.ollama:
|
||||
providerSpecificOptions = buildOllamaProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case SystemProviderIds.gateway:
|
||||
providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity)
|
||||
break
|
||||
case 'deepseek':
|
||||
case 'openrouter':
|
||||
case 'openai-compatible':
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities)
|
||||
// Merge serviceTier and textVerbosity
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
[rawProviderId]: {
|
||||
...providerSpecificOptions[rawProviderId],
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
logger.debug('Built providerSpecificOptions', { providerSpecificOptions })
|
||||
/**
|
||||
* Retrieve custom parameters and separate standard parameters from provider-specific parameters.
|
||||
*/
|
||||
const customParams = getCustomParameters(assistant)
|
||||
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
|
||||
logger.debug('Extracted standardParams and providerParams', { standardParams, providerParams })
|
||||
|
||||
/**
|
||||
* Get the actual AI SDK provider ID(s) from the already-built providerSpecificOptions.
|
||||
* For proxy providers (cherryin, aihubmix, newapi), this will be the actual SDK provider (e.g., 'google', 'openai', 'anthropic')
|
||||
* For regular providers, this will be the provider itself
|
||||
*/
|
||||
const actualAiSdkProviderIds = Object.keys(providerSpecificOptions)
|
||||
const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params
|
||||
|
||||
// For openai-compatible providers, auto-convert reasoning_effort (snake_case) to reasoningEffort (camelCase).
|
||||
// The AI SDK's openai-compatible provider overwrites reasoning_effort to undefined,
|
||||
// but accepts reasoningEffort. See: https://github.com/CherryHQ/cherry-studio/issues/11987
|
||||
if (primaryAiSdkProviderId === 'openai-compatible' && 'reasoning_effort' in providerParams) {
|
||||
if (!('reasoningEffort' in providerParams)) {
|
||||
providerParams.reasoningEffort = providerParams.reasoning_effort
|
||||
}
|
||||
delete providerParams.reasoning_effort
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge custom parameters into providerSpecificOptions.
|
||||
* Simple logic:
|
||||
* 1. If key is in actualAiSdkProviderIds → merge directly (user knows the actual AI SDK provider ID)
|
||||
* 2. If key == rawProviderId:
|
||||
* - If it's gateway/ollama → preserve (they need their own config for routing/options)
|
||||
* - Otherwise → map to primary (this is a proxy provider like cherryin)
|
||||
* 3. Otherwise → treat as regular parameter, merge to primary provider
|
||||
*
|
||||
* Example:
|
||||
* - User writes `cherryin: { opt: 'val' }` → mapped to `google: { opt: 'val' }` (case 2, proxy)
|
||||
* - User writes `gateway: { order: [...] }` → stays as `gateway: { order: [...] }` (case 2, routing config)
|
||||
* - User writes `google: { opt: 'val' }` → stays as `google: { opt: 'val' }` (case 1)
|
||||
* - User writes `customKey: 'val'` → merged to `google: { customKey: 'val' }` (case 3)
|
||||
*/
|
||||
for (const key of Object.keys(providerParams)) {
|
||||
if (actualAiSdkProviderIds.includes(key)) {
|
||||
// Case 1: Key is an actual AI SDK provider ID - merge directly
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
[key]: {
|
||||
...providerSpecificOptions[key],
|
||||
...providerParams[key]
|
||||
}
|
||||
}
|
||||
} else if (key === rawProviderId && !actualAiSdkProviderIds.includes(rawProviderId)) {
|
||||
// Case 2: Key is the current provider (not in actualAiSdkProviderIds, so it's a proxy or special provider)
|
||||
// Gateway is special: it needs routing config preserved
|
||||
if (key === SystemProviderIds.gateway) {
|
||||
// Preserve gateway config for routing
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
[key]: {
|
||||
...providerSpecificOptions[key],
|
||||
...providerParams[key]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Proxy provider (cherryin, etc.) - map to actual AI SDK provider
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
[primaryAiSdkProviderId]: {
|
||||
...providerSpecificOptions[primaryAiSdkProviderId],
|
||||
...providerParams[key]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Case 3: Regular parameter - merge to primary provider
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
[primaryAiSdkProviderId]: {
|
||||
...providerSpecificOptions[primaryAiSdkProviderId],
|
||||
[key]: providerParams[key]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.debug('Final providerSpecificOptions after merging providerParams', { providerSpecificOptions })
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
|
||||
return {
|
||||
providerOptions: providerSpecificOptions,
|
||||
standardParams
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 OpenAI 特定的 providerOptions
|
||||
*/
|
||||
function buildOpenAIProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>,
|
||||
serviceTier: OpenAIServiceTier,
|
||||
textVerbosity?: OpenAIVerbosity
|
||||
): Record<string, OpenAIResponsesProviderOptions> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: OpenAIResponsesProviderOptions = {}
|
||||
// OpenAI 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams,
|
||||
// TODO: Remove this workaround after migrating to @ai-sdk/open-responses (#13462)
|
||||
// Bypass @ai-sdk/openai's model ID allowlist for reasoning detection.
|
||||
// Third-party providers often use non-canonical model IDs (e.g., "openai/gpt-5.2")
|
||||
// that fail the SDK's startsWith() checks, causing reasoning params to be silently dropped.
|
||||
...(isReasoningModel(model) && { forceReasoning: true })
|
||||
}
|
||||
}
|
||||
const provider = getProviderById(model.provider)
|
||||
|
||||
if (!provider) {
|
||||
throw new Error(`Provider ${model.provider} not found`)
|
||||
}
|
||||
|
||||
if (isSupportVerbosityModel(model) && isSupportVerbosityProvider(provider)) {
|
||||
const openAI = getStoreSetting<'openAI'>('openAI')
|
||||
const userVerbosity = openAI?.verbosity
|
||||
|
||||
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
|
||||
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||
// Use user's verbosity if supported, otherwise use the first supported option
|
||||
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
|
||||
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
textVerbosity: verbosity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 支持配置是否在服务端持久化
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
serviceTier,
|
||||
textVerbosity,
|
||||
store: false
|
||||
}
|
||||
|
||||
return {
|
||||
openai: providerOptions
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Anthropic 特定的 providerOptions
|
||||
*/
|
||||
function buildAnthropicProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): Record<string, AnthropicProviderOptions> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: AnthropicProviderOptions = {}
|
||||
|
||||
// Anthropic 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getAnthropicReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
anthropic: {
|
||||
...providerOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Gemini 特定的 providerOptions
|
||||
*/
|
||||
function buildGeminiProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): Record<string, GoogleGenerativeAIProviderOptions> {
|
||||
const { enableReasoning, enableGenerateImage } = capabilities
|
||||
let providerOptions: GoogleGenerativeAIProviderOptions = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getGeminiReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...buildGeminiGenerateImageParams()
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
google: {
|
||||
...providerOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function buildXAIProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): Record<string, XaiProviderOptions> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getXAIReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
xai: {
|
||||
...providerOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function buildCherryInProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>,
|
||||
actualProvider: Provider,
|
||||
serviceTier: OpenAIServiceTier,
|
||||
textVerbosity: OpenAIVerbosity
|
||||
): Record<string, OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions> {
|
||||
switch (actualProvider.type) {
|
||||
case 'openai':
|
||||
return buildGenericProviderOptions('cherryin', assistant, model, capabilities)
|
||||
case 'openai-response':
|
||||
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
|
||||
case 'anthropic':
|
||||
return buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
case 'gemini':
|
||||
return buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
|
||||
default:
|
||||
return buildGenericProviderOptions('cherryin', assistant, model, capabilities)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build Bedrock providerOptions
|
||||
*/
|
||||
function buildBedrockProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): Record<string, BedrockProviderOptions> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: BedrockProviderOptions = {}
|
||||
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getBedrockReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
const betaHeaders = addAnthropicHeaders(assistant, model)
|
||||
if (betaHeaders.length > 0) {
|
||||
providerOptions.anthropicBeta = betaHeaders
|
||||
}
|
||||
|
||||
return {
|
||||
bedrock: providerOptions
|
||||
}
|
||||
}
|
||||
|
||||
function buildOllamaProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): Record<string, OllamaProviderOptions> {
|
||||
const { enableReasoning } = capabilities
|
||||
let options = {}
|
||||
|
||||
if (enableReasoning) {
|
||||
options = {
|
||||
...options,
|
||||
...getOllamaReasoningParams(assistant, model)
|
||||
}
|
||||
}
|
||||
|
||||
return { ollama: options }
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建通用的 providerOptions(用于其他 provider)
|
||||
*/
|
||||
function buildGenericProviderOptions(
|
||||
providerId: string,
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>
|
||||
): Record<string, any> {
|
||||
const { enableWebSearch } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
const reasoningParams = getReasoningEffort(assistant, model)
|
||||
logger.debug('reasoningParams', reasoningParams)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
const webSearchParams = getWebSearchParams(model)
|
||||
providerOptions = merge({}, providerOptions, webSearchParams)
|
||||
}
|
||||
|
||||
// 特殊处理 Qwen MT
|
||||
if (isQwenMTModel(model)) {
|
||||
if (isTranslateAssistant(assistant)) {
|
||||
const targetLanguage = assistant.targetLanguage
|
||||
const translationOptions = {
|
||||
source_lang: 'auto',
|
||||
target_lang: mapLanguageToQwenMTModel(targetLanguage)
|
||||
} as const
|
||||
if (!translationOptions.target_lang) {
|
||||
throw new Error(t('translate.error.not_supported', { language: targetLanguage.value }))
|
||||
}
|
||||
providerOptions.translation_options = translationOptions
|
||||
} else {
|
||||
throw new Error(t('translate.error.chat_qwen_mt'))
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
[providerId]: providerOptions
|
||||
}
|
||||
}
|
||||
|
||||
function buildAIGatewayOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: Pick<ProviderCapabilities, 'enableReasoning' | 'enableWebSearch' | 'enableGenerateImage'>,
|
||||
serviceTier: OpenAIServiceTier,
|
||||
textVerbosity?: OpenAIVerbosity
|
||||
): Record<
|
||||
string,
|
||||
| OpenAIResponsesProviderOptions
|
||||
| AnthropicProviderOptions
|
||||
| GoogleGenerativeAIProviderOptions
|
||||
| Record<string, unknown>
|
||||
> {
|
||||
if (isAnthropicModel(model)) {
|
||||
return buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
} else if (isOpenAIModel(model)) {
|
||||
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
|
||||
} else if (isGeminiModel(model)) {
|
||||
return buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
} else if (isGrokModel(model)) {
|
||||
return buildXAIProviderOptions(assistant, model, capabilities)
|
||||
} else {
|
||||
return buildGenericProviderOptions('openai-compatible', assistant, model, capabilities)
|
||||
}
|
||||
}
|
||||
@@ -1,999 +0,0 @@
|
||||
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
|
||||
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import type { XaiProviderOptions } from '@ai-sdk/xai'
|
||||
import type OpenAI from '@cherrystudio/openai'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
getModelSupportedReasoningEffortOptions,
|
||||
isClaude46SeriesModel,
|
||||
isDeepSeekHybridInferenceModel,
|
||||
isDoubaoSeed18Model,
|
||||
isDoubaoSeedAfter251015,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isGemini3ThinkingTokenModel,
|
||||
isGrok4FastReasoningModel,
|
||||
isOpenAIDeepResearchModel,
|
||||
isOpenAIModel,
|
||||
isOpenAIOpenWeightModel,
|
||||
isOpenAIReasoningModel,
|
||||
isQwen35to39Model,
|
||||
isQwenAlwaysThinkModel,
|
||||
isQwenReasoningModel,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortGrokModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenDoubaoModel,
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isSupportedThinkingTokenHunyuanModel,
|
||||
isSupportedThinkingTokenKimiModel,
|
||||
isSupportedThinkingTokenMiMoModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel,
|
||||
isSupportedThinkingTokenZhipuModel,
|
||||
isSupportNoneReasoningEffortModel
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
|
||||
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
|
||||
import type { OpenAIReasoningEffort, OpenAIReasoningSummary } from '@renderer/types/aiCoreTypes'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||
import { toInteger } from 'lodash'
|
||||
import type { OllamaProviderOptions } from 'ollama-ai-provider-v2'
|
||||
|
||||
const logger = loggerService.withContext('reasoning')
|
||||
|
||||
type ReasoningEffortOptionalParams = {
|
||||
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
|
||||
reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string; enabled?: boolean } | OpenAI.Reasoning
|
||||
reasoningEffort?: OpenAIReasoningEffort
|
||||
// WARN: This field will be overwrite to undefined by aisdk if the provider is openai-compatible. Use reasoningEffort instead.
|
||||
reasoning_effort?: OpenAIReasoningEffort
|
||||
enable_thinking?: boolean
|
||||
thinking_budget?: number
|
||||
incremental_output?: boolean
|
||||
enable_reasoning?: boolean
|
||||
// nvidia, etc.
|
||||
chat_template_kwargs?: {
|
||||
thinking?: boolean
|
||||
enable_thinking?: boolean
|
||||
thinking_budget?: number
|
||||
}
|
||||
extra_body?: {
|
||||
google?: {
|
||||
thinking_config: {
|
||||
thinking_budget: number
|
||||
include_thoughts?: boolean
|
||||
}
|
||||
}
|
||||
thinking?: {
|
||||
type: 'enabled' | 'disabled'
|
||||
}
|
||||
thinking_budget?: number
|
||||
reasoning_effort?: OpenAIReasoningEffort
|
||||
}
|
||||
disable_reasoning?: boolean
|
||||
// Add any other potential reasoning-related keys here if they exist
|
||||
}
|
||||
|
||||
// The function is only for generic provider. May extract some logics to independent provider
|
||||
export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
const provider = getProviderByModel(model)
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
if (provider.id === 'groq') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isOpenAIDeepResearchModel(model)) {
|
||||
return {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
}
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
// reasoningEffort is not set, no extra reasoning setting
|
||||
// Generally, for every model which supports reasoning control, the reasoning effort won't be undefined.
|
||||
// It's for some reasoning models that don't support reasoning control, such as deepseek reasoner.
|
||||
if (!reasoningEffort || reasoningEffort === 'default') {
|
||||
return {}
|
||||
}
|
||||
|
||||
// Handle 'none' reasoningEffort. It's explicitly off.
|
||||
if (reasoningEffort === 'none') {
|
||||
// openrouter: use reasoning
|
||||
if (model.provider === SystemProviderIds.openrouter) {
|
||||
if (isSupportNoneReasoningEffortModel(model) && reasoningEffort === 'none') {
|
||||
return { reasoning: { effort: 'none' } }
|
||||
}
|
||||
return { reasoning: { enabled: false, exclude: true } }
|
||||
}
|
||||
|
||||
// nvidia: must use chat_template_kwargs
|
||||
// Since limited documentation, it's hard to find what parameters should be set
|
||||
// only part of mainstream oss model covered, all verified by nvidia api
|
||||
if (model.provider === SystemProviderIds.nvidia) {
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return { chat_template_kwargs: { enable_thinking: false } }
|
||||
} else if (isDeepSeekHybridInferenceModel(model)) {
|
||||
return { chat_template_kwargs: { thinking: false } }
|
||||
} else if (isSupportedThinkingTokenKimiModel(model)) {
|
||||
return { chat_template_kwargs: { thinking: false } }
|
||||
} else if (isSupportedThinkingTokenZhipuModel(model)) {
|
||||
return { chat_template_kwargs: { enable_thinking: false } }
|
||||
}
|
||||
}
|
||||
|
||||
// providers that use enable_thinking
|
||||
if (
|
||||
(isSupportEnableThinkingProvider(provider) &&
|
||||
(isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model))) ||
|
||||
(provider.id === SystemProviderIds.dashscope &&
|
||||
(isDeepSeekHybridInferenceModel(model) ||
|
||||
isSupportedThinkingTokenZhipuModel(model) ||
|
||||
isSupportedThinkingTokenKimiModel(model)))
|
||||
) {
|
||||
return { enable_thinking: false }
|
||||
}
|
||||
|
||||
// together
|
||||
if (provider.id === SystemProviderIds.together) {
|
||||
return { reasoning: { enabled: false } }
|
||||
}
|
||||
|
||||
// gemini
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||
return {
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.warn(`Model ${model.id} cannot disable reasoning. Fallback to empty reasoning param.`)
|
||||
return {}
|
||||
}
|
||||
}
|
||||
|
||||
// use thinking, doubao, zhipu, etc.
|
||||
if (
|
||||
isSupportedThinkingTokenDoubaoModel(model) ||
|
||||
isSupportedThinkingTokenZhipuModel(model) ||
|
||||
isSupportedThinkingTokenMiMoModel(model) ||
|
||||
isSupportedThinkingTokenKimiModel(model)
|
||||
) {
|
||||
if (provider.id === SystemProviderIds.cerebras) {
|
||||
return {
|
||||
disable_reasoning: true
|
||||
}
|
||||
}
|
||||
return { thinking: { type: 'disabled' } }
|
||||
}
|
||||
|
||||
// Deepseek, default behavior is non-thinking
|
||||
if (isDeepSeekHybridInferenceModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// GPT 5.1, GPT 5.2, or newer
|
||||
if (isSupportNoneReasoningEffortModel(model)) {
|
||||
return {
|
||||
reasoningEffort: 'none'
|
||||
}
|
||||
}
|
||||
|
||||
// Qwen 3.5 without direct enable_thinking
|
||||
// https://huggingface.co/Qwen/Qwen3.5-397B-A17B#instruct-or-non-thinking-mode
|
||||
if (isQwen35to39Model(model)) {
|
||||
return {
|
||||
chat_template_kwargs: {
|
||||
enable_thinking: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.warn(`Model ${model.id} doesn't match any disable reasoning behavior. Fallback to empty reasoning param.`)
|
||||
return {}
|
||||
}
|
||||
|
||||
// reasoningEffort有效的情况
|
||||
// https://creator.poe.com/docs/external-applications/openai-compatible-api#additional-considerations
|
||||
// Poe provider - supports custom bot parameters via extra_body
|
||||
if (provider.id === SystemProviderIds.poe) {
|
||||
if (isOpenAIReasoningModel(model)) {
|
||||
return {
|
||||
extra_body: {
|
||||
reasoning_effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude models use thinking_budget parameter in extra_body
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const tokenLimit = findTokenLimit(model.id)
|
||||
const maxTokens = assistant.settings?.maxTokens
|
||||
|
||||
if (!tokenLimit) {
|
||||
logger.warn(
|
||||
`No token limit configuration found for Claude model "${model.id}" on Poe provider. ` +
|
||||
`Reasoning effort setting "${reasoningEffort}" will not be applied.`
|
||||
)
|
||||
return {}
|
||||
}
|
||||
|
||||
let budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
|
||||
budgetTokens = Math.floor(Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)))
|
||||
|
||||
return {
|
||||
extra_body: {
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini models use thinking_budget parameter in extra_body
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const tokenLimit = findTokenLimit(model.id)
|
||||
let budgetTokens: number | undefined
|
||||
if (tokenLimit && reasoningEffort !== 'auto') {
|
||||
budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
|
||||
} else if (!tokenLimit && reasoningEffort !== 'auto') {
|
||||
logger.warn(
|
||||
`No token limit configuration found for Gemini model "${model.id}" on Poe provider. ` +
|
||||
`Using auto (-1) instead of requested effort "${reasoningEffort}".`
|
||||
)
|
||||
}
|
||||
return {
|
||||
extra_body: {
|
||||
thinking_budget: budgetTokens ?? -1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Poe reasoning model not in known categories (GPT-5, Claude, Gemini)
|
||||
logger.warn(
|
||||
`Poe provider reasoning model "${model.id}" does not match known categories ` +
|
||||
`(GPT-5, Claude, Gemini). Reasoning effort setting "${reasoningEffort}" will not be applied.`
|
||||
)
|
||||
return {}
|
||||
}
|
||||
|
||||
// OpenRouter models
|
||||
if (model.provider === SystemProviderIds.openrouter) {
|
||||
// Grok 4 Fast doesn't support effort levels, always use enabled: true
|
||||
if (isGrok4FastReasoningModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
enabled: true // Ignore effort level, just enable reasoning
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Other OpenRouter models that support effort levels
|
||||
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const tokenLimit = findTokenLimit(modelId)
|
||||
let budgetTokens: number | undefined
|
||||
if (tokenLimit) {
|
||||
budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
|
||||
}
|
||||
|
||||
// nvidia: must use chat_template_kwargs
|
||||
// Since limited documentation, it's hard to find what parameters should be set
|
||||
// only part of mainstream oss model covered, all verified by nvidia api
|
||||
if (model.provider === SystemProviderIds.nvidia) {
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
const enableThinkingConfig = isQwenAlwaysThinkModel(model) ? {} : { enable_thinking: true }
|
||||
return {
|
||||
chat_template_kwargs: {
|
||||
...enableThinkingConfig,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
} else if (isDeepSeekHybridInferenceModel(model)) {
|
||||
return { chat_template_kwargs: { thinking: true } }
|
||||
} else if (isSupportedThinkingTokenKimiModel(model)) {
|
||||
return { chat_template_kwargs: { thinking: true } }
|
||||
} else if (isSupportedThinkingTokenZhipuModel(model)) {
|
||||
return { chat_template_kwargs: { enable_thinking: true } }
|
||||
}
|
||||
}
|
||||
|
||||
// See https://docs.siliconflow.cn/cn/api-reference/chat-completions/chat-completions
|
||||
if (model.provider === SystemProviderIds.silicon) {
|
||||
if (
|
||||
isDeepSeekHybridInferenceModel(model) ||
|
||||
isSupportedThinkingTokenZhipuModel(model) ||
|
||||
isSupportedThinkingTokenQwenModel(model) ||
|
||||
isSupportedThinkingTokenHunyuanModel(model)
|
||||
) {
|
||||
return {
|
||||
enable_thinking: true,
|
||||
// Hard-encoded maximum, only for silicon
|
||||
thinking_budget: budgetTokens ? toInteger(Math.max(budgetTokens, 32768)) : undefined
|
||||
}
|
||||
}
|
||||
return {}
|
||||
}
|
||||
|
||||
// DeepSeek hybrid inference models, v3.1 and maybe more in the future
|
||||
// 不同的 provider 有不同的思考控制方式,在这里统一解决
|
||||
if (isDeepSeekHybridInferenceModel(model)) {
|
||||
if (isSystemProvider(provider)) {
|
||||
switch (provider.id) {
|
||||
case SystemProviderIds.dashscope:
|
||||
return {
|
||||
enable_thinking: true,
|
||||
incremental_output: true
|
||||
}
|
||||
// TODO: 支持 new-api类型
|
||||
case SystemProviderIds['new-api']:
|
||||
case SystemProviderIds.cherryin: {
|
||||
return {
|
||||
extra_body: {
|
||||
thinking: {
|
||||
type: 'enabled' // auto is invalid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case SystemProviderIds.hunyuan:
|
||||
case SystemProviderIds['tencent-cloud-ti']:
|
||||
case SystemProviderIds.doubao:
|
||||
case SystemProviderIds.deepseek:
|
||||
case SystemProviderIds.aihubmix:
|
||||
case SystemProviderIds.sophnet:
|
||||
case SystemProviderIds.ppio:
|
||||
case SystemProviderIds.dmxapi:
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled' // auto is invalid
|
||||
}
|
||||
}
|
||||
case SystemProviderIds.openrouter:
|
||||
case SystemProviderIds.together:
|
||||
return {
|
||||
reasoning: {
|
||||
enabled: true
|
||||
}
|
||||
}
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.warn(
|
||||
`Use default thinking options for provider ${provider.name} as DeepSeek v3.1+ thinking control method is unknown`
|
||||
)
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenRouter models, use reasoning
|
||||
// FIXME: duplicated openrouter handling. remove one
|
||||
if (model.provider === SystemProviderIds.openrouter) {
|
||||
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://help.aliyun.com/zh/model-studio/deep-thinking
|
||||
if (provider.id === SystemProviderIds.dashscope) {
|
||||
// For dashscope: Qwen, DeepSeek, and GLM models use enable_thinking to control thinking
|
||||
// No effort, only on/off
|
||||
if (
|
||||
isQwenReasoningModel(model) ||
|
||||
isSupportedThinkingTokenZhipuModel(model) ||
|
||||
isSupportedThinkingTokenKimiModel(model)
|
||||
) {
|
||||
return {
|
||||
enable_thinking: true,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.together.ai/reference/chat-completions-1#body-reasoning-effort
|
||||
if (provider.id === SystemProviderIds.together) {
|
||||
let adjustedReasoningEffort: 'low' | 'medium' | 'high' = 'medium'
|
||||
switch (reasoningEffort) {
|
||||
case 'minimal':
|
||||
adjustedReasoningEffort = 'low'
|
||||
break
|
||||
case 'xhigh':
|
||||
adjustedReasoningEffort = 'high'
|
||||
break
|
||||
case 'auto':
|
||||
adjustedReasoningEffort = 'medium'
|
||||
break
|
||||
default:
|
||||
adjustedReasoningEffort = reasoningEffort
|
||||
break
|
||||
}
|
||||
return {
|
||||
// Only low, medium, high
|
||||
reasoningEffort: adjustedReasoningEffort,
|
||||
reasoning: { enabled: true }
|
||||
}
|
||||
}
|
||||
|
||||
// Qwen models, use enable_thinking
|
||||
if (isQwenReasoningModel(model)) {
|
||||
const supportEnableThinking = isSupportEnableThinkingProvider(provider)
|
||||
const enableThinkingConfig = isQwenAlwaysThinkModel(model) ? {} : { enable_thinking: true }
|
||||
if (supportEnableThinking) {
|
||||
return {
|
||||
...enableThinkingConfig,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
chat_template_kwargs: {
|
||||
...enableThinkingConfig,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hunyuan models, use enable_thinking
|
||||
if (isSupportedThinkingTokenHunyuanModel(model) && isSupportEnableThinkingProvider(provider)) {
|
||||
return {
|
||||
enable_thinking: true
|
||||
}
|
||||
}
|
||||
|
||||
// Grok models/Perplexity models/OpenAI models, use reasoning_effort
|
||||
if (isSupportedReasoningEffortModel(model)) {
|
||||
// 检查模型是否支持所选选项
|
||||
const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default')
|
||||
if (supportedOptions?.includes(reasoningEffort)) {
|
||||
return {
|
||||
reasoningEffort
|
||||
}
|
||||
} else {
|
||||
// 如果不支持,fallback到第一个支持的值
|
||||
return {
|
||||
reasoningEffort: supportedOptions?.[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gemini series, openai compatible api
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
|
||||
if (isGemini3ThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoningEffort
|
||||
}
|
||||
}
|
||||
if (reasoningEffort === 'auto') {
|
||||
return {
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: -1,
|
||||
include_thoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return {
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: budgetTokens ?? -1,
|
||||
include_thoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude models, openai compatible api
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const maxTokens = assistant.settings?.maxTokens
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
? Math.floor(Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio)))
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use thinking, doubao, zhipu, etc.
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
if (isDoubaoSeedAfter251015(model) || isDoubaoSeed18Model(model)) {
|
||||
return { reasoningEffort }
|
||||
}
|
||||
if (reasoningEffort === 'high') {
|
||||
return { thinking: { type: 'enabled' } }
|
||||
}
|
||||
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
|
||||
return { thinking: { type: 'auto' } }
|
||||
}
|
||||
// 其他情况不带 thinking 字段
|
||||
return {}
|
||||
}
|
||||
if (isSupportedThinkingTokenZhipuModel(model)) {
|
||||
if (provider.id === SystemProviderIds.cerebras) {
|
||||
return {}
|
||||
}
|
||||
return { thinking: { type: 'enabled' } }
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenMiMoModel(model) || isSupportedThinkingTokenKimiModel(model)) {
|
||||
return {
|
||||
thinking: { type: 'enabled' }
|
||||
}
|
||||
}
|
||||
|
||||
// Default case: no special thinking settings
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get OpenAI reasoning parameters
|
||||
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
|
||||
* For official OpenAI provider only
|
||||
*/
|
||||
export function getOpenAIReasoningParams(
|
||||
assistant: Assistant,
|
||||
model: Model
|
||||
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
let reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (!reasoningEffort || reasoningEffort === 'default') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
|
||||
reasoningEffort = 'medium'
|
||||
}
|
||||
|
||||
// 非OpenAI模型,但是Provider类型是responses/azure openai的情况
|
||||
if (!isOpenAIModel(model)) {
|
||||
return {
|
||||
reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
const openAI = getStoreSetting('openAI')
|
||||
const summaryText = openAI.summaryText
|
||||
|
||||
let reasoningSummary: OpenAIReasoningSummary = undefined
|
||||
|
||||
if (model.id.includes('o1-pro')) {
|
||||
reasoningSummary = undefined
|
||||
} else {
|
||||
reasoningSummary = summaryText
|
||||
}
|
||||
|
||||
// OpenAI 推理参数
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoningEffort,
|
||||
reasoningSummary
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
// Conservative fallback token limit for models not in THINKING_TOKEN_MAP.
|
||||
const FALLBACK_TOKEN_LIMIT = { min: 1024, max: 16384 }
|
||||
|
||||
function computeBudgetTokens(
|
||||
tokenLimit: { min: number; max: number },
|
||||
effortRatio: number,
|
||||
maxTokens?: number
|
||||
): number {
|
||||
const budget = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
|
||||
const capped = maxTokens !== undefined ? Math.min(budget, maxTokens) : budget
|
||||
return Math.max(1024, capped)
|
||||
}
|
||||
|
||||
export function getThinkingBudget(
|
||||
maxTokens: number | undefined,
|
||||
reasoningEffort: string | undefined,
|
||||
modelId: string
|
||||
): number | undefined {
|
||||
if (reasoningEffort === undefined || reasoningEffort === 'none') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const tokenLimit = findTokenLimit(modelId)
|
||||
if (!tokenLimit) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return computeBudgetTokens(tokenLimit, EFFORT_RATIO[reasoningEffort], maxTokens)
|
||||
}
|
||||
|
||||
// Compute a fallback budgetTokens using a conservative token limit when
|
||||
// findTokenLimit() cannot determine the model's actual limit. This ensures
|
||||
// { type: 'enabled' } always carries a valid budget, which is required by
|
||||
// the Claude Agent SDK and the Anthropic Messages API.
|
||||
function getFallbackBudgetTokens(reasoningEffort: string | undefined): number {
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort ?? 'high'] ?? EFFORT_RATIO.high
|
||||
return computeBudgetTokens(FALLBACK_TOKEN_LIMIT, effortRatio)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Anthropic reasoning parameters.
|
||||
* Extracted from AnthropicAPIClient logic.
|
||||
*
|
||||
* Returns different parameter shapes depending on the model:
|
||||
* - **Claude 4.6**: `{ thinking: { type: 'adaptive' }, effort: 'low' | 'medium' | 'high' | 'max' }`
|
||||
* Uses the new adaptive thinking API with effort-based control.
|
||||
* - **Other Claude models** (4.0, 4.1, 4.5, etc.): `{ thinking: { type: 'enabled', budgetTokens: number } }`
|
||||
* Uses the classic thinking API with explicit token budget.
|
||||
*/
|
||||
export function getAnthropicReasoningParams(
|
||||
assistant: Assistant,
|
||||
model: Model
|
||||
): Pick<AnthropicProviderOptions, 'thinking' | 'effort'> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (!reasoningEffort || reasoningEffort === 'default') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (reasoningEffort === 'none') {
|
||||
return {
|
||||
thinking: {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude reasoning parameters
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
// Claude 4.6 uses adaptive thinking + effort parameters
|
||||
// Map reasoningEffort to Claude 4.6 supported effort values
|
||||
if (isClaude46SeriesModel(model)) {
|
||||
// Claude 4.6 supports: low, medium, high, max
|
||||
// Mapping rules: default/none -> no effort (uses default high)
|
||||
// minimal/low -> low
|
||||
// medium -> medium
|
||||
// high -> high
|
||||
// xhigh -> max
|
||||
const effortMap = {
|
||||
default: undefined,
|
||||
auto: undefined,
|
||||
minimal: 'low',
|
||||
low: 'low',
|
||||
medium: 'medium',
|
||||
high: 'high',
|
||||
xhigh: 'max'
|
||||
} as const satisfies Record<Exclude<ReasoningEffortOption, 'none'>, AnthropicProviderOptions['effort']>
|
||||
const effort = effortMap[reasoningEffort]
|
||||
return effort ? { thinking: { type: 'adaptive' }, effort } : { thinking: { type: 'adaptive' } }
|
||||
}
|
||||
|
||||
// Other Claude models continue using enabled + budgetTokens
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const budgetTokens = getThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budgetTokens: budgetTokens ?? getFallbackBudgetTokens(reasoningEffort)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 其他使用claude端點的模型,比如Kimi,Minimax等等
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const budgetTokens = getThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
// Always include budgetTokens to prevent Claude Agent SDK from converting
|
||||
// { type: 'enabled' } into '--thinking adaptive', which non-Anthropic
|
||||
// upstream providers do not support (they only accept 'enabled'/'disabled').
|
||||
return { thinking: { type: 'enabled', budgetTokens: budgetTokens ?? getFallbackBudgetTokens(reasoningEffort) } }
|
||||
}
|
||||
}
|
||||
|
||||
type GoogleThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
|
||||
|
||||
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel {
|
||||
switch (reasoningEffort) {
|
||||
case 'auto':
|
||||
case 'default':
|
||||
return undefined
|
||||
case 'none':
|
||||
return 'minimal'
|
||||
case 'minimal':
|
||||
return 'minimal'
|
||||
case 'low':
|
||||
return 'low'
|
||||
case 'medium':
|
||||
return 'medium'
|
||||
case 'high':
|
||||
case 'xhigh':
|
||||
return 'high'
|
||||
default:
|
||||
// Enforce all possible values are handled
|
||||
reasoningEffort satisfies never
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Gemini 推理参数
|
||||
* 从 GeminiAPIClient 中提取的逻辑
|
||||
* 注意:Gemini/GCP 端点所使用的 thinkingBudget 等参数应该按照驼峰命名法传递
|
||||
* 而在 Google 官方提供的 OpenAI 兼容端点中则使用蛇形命名法 thinking_budget
|
||||
*/
|
||||
export function getGeminiReasoningParams(
|
||||
assistant: Assistant,
|
||||
model: Model
|
||||
): Pick<GoogleGenerativeAIProviderOptions, 'thinkingConfig'> {
|
||||
if (!isReasoningModel(model) || !isSupportedThinkingTokenGeminiModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (!reasoningEffort || reasoningEffort === 'default') {
|
||||
return {}
|
||||
}
|
||||
|
||||
let thinkingLevel: GoogleThinkingLevel | null = null
|
||||
const includeThoughts = reasoningEffort !== 'none'
|
||||
|
||||
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
|
||||
if (isGemini3ThinkingTokenModel(model)) {
|
||||
thinkingLevel = mapToGeminiThinkingLevel(reasoningEffort)
|
||||
if (thinkingLevel === 'minimal' && getLowerBaseModelName(model.id).includes('pro')) {
|
||||
thinkingLevel = 'low'
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingLevel !== null) {
|
||||
// Gemini 3 branch. thinkingLevel can be undefined (auto) or a specific level.
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts,
|
||||
thinkingLevel
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Old models
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
if (reasoningEffort === 'auto') {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts,
|
||||
thinkingBudget: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (reasoningEffort === 'none') {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts,
|
||||
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
|
||||
const budget = Math.floor((max - min) * effortRatio + min)
|
||||
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts,
|
||||
...(budget > 0 ? { thinkingBudget: budget } : {})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get XAI-specific reasoning parameters
|
||||
* This function should only be called for XAI provider models
|
||||
* @param assistant - The assistant configuration
|
||||
* @param model - The model being used
|
||||
* @returns XAI-specific reasoning parameters
|
||||
*/
|
||||
export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick<XaiProviderOptions, 'reasoningEffort'> {
|
||||
if (!isSupportedReasoningEffortGrokModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||
|
||||
switch (reasoningEffort) {
|
||||
case 'auto':
|
||||
case 'minimal':
|
||||
case 'medium':
|
||||
return { reasoningEffort: 'low' }
|
||||
case 'low':
|
||||
case 'high':
|
||||
return { reasoningEffort }
|
||||
case 'xhigh':
|
||||
return { reasoningEffort: 'high' }
|
||||
case 'default':
|
||||
case 'none':
|
||||
default:
|
||||
return {}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Bedrock reasoning parameters
|
||||
*/
|
||||
export function getBedrockReasoningParams(
|
||||
assistant: Assistant,
|
||||
model: Model
|
||||
): Pick<BedrockProviderOptions, 'reasoningConfig'> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined || reasoningEffort === 'default') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (reasoningEffort === 'none') {
|
||||
return {
|
||||
reasoningConfig: {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only apply thinking budget for Claude reasoning models
|
||||
if (!isSupportedThinkingTokenClaudeModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// Claude 4.6 uses adaptive thinking + maxReasoningEffort
|
||||
if (isClaude46SeriesModel(model)) {
|
||||
const effortMap = {
|
||||
auto: undefined,
|
||||
minimal: 'low',
|
||||
low: 'low',
|
||||
medium: 'medium',
|
||||
high: 'high',
|
||||
xhigh: 'max'
|
||||
} as const satisfies Record<
|
||||
Exclude<ReasoningEffortOption, 'none' | 'default'>,
|
||||
NonNullable<BedrockProviderOptions['reasoningConfig']>['maxReasoningEffort']
|
||||
>
|
||||
const maxReasoningEffort = effortMap[reasoningEffort]
|
||||
return maxReasoningEffort
|
||||
? { reasoningConfig: { type: 'adaptive', maxReasoningEffort } }
|
||||
: { reasoningConfig: { type: 'adaptive' } }
|
||||
}
|
||||
|
||||
// Other Claude models use enabled + budgetTokens
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const budgetTokens = getThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
return {
|
||||
reasoningConfig: {
|
||||
type: 'enabled',
|
||||
budgetTokens: budgetTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Ollama reasoning parameters
|
||||
* Handles the `think` parameter for Ollama models
|
||||
*
|
||||
* - GPT-OSS models: accept 'low' | 'medium' | 'high' string values
|
||||
* - Other models: boolean only (true/false)
|
||||
*/
|
||||
export function getOllamaReasoningParams(assistant: Assistant, model: Model): Pick<OllamaProviderOptions, 'think'> {
|
||||
const reasoningEffort = assistant.settings?.reasoning_effort
|
||||
|
||||
if (isOpenAIOpenWeightModel(model)) {
|
||||
// gpt-oss models accept 'low' | 'medium' | 'high' string values
|
||||
if (reasoningEffort === 'low' || reasoningEffort === 'medium' || reasoningEffort === 'high') {
|
||||
return { think: reasoningEffort }
|
||||
} else if (reasoningEffort === 'none') {
|
||||
return { think: false }
|
||||
}
|
||||
return { think: true }
|
||||
}
|
||||
|
||||
// Other models: boolean only. undefined defaults to true (user enabled reasoning)
|
||||
return { think: reasoningEffort !== 'none' }
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取自定义参数
|
||||
* 从 assistant 设置中提取自定义参数
|
||||
*/
|
||||
export function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
return (
|
||||
assistant?.settings?.customParameters?.reduce((acc, param) => {
|
||||
if (!param.name?.trim()) {
|
||||
return acc
|
||||
}
|
||||
// Parse JSON type parameters
|
||||
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
|
||||
// The UI stores JSON type params as strings (e.g., '{"key":"value"}')
|
||||
// This function parses them into objects before sending to the API
|
||||
if (param.type === 'json') {
|
||||
const value = param.value as string
|
||||
if (value === 'undefined') {
|
||||
return { ...acc, [param.name]: undefined }
|
||||
}
|
||||
try {
|
||||
return { ...acc, [param.name]: JSON.parse(value) }
|
||||
} catch {
|
||||
return { ...acc, [param.name]: value }
|
||||
}
|
||||
}
|
||||
return {
|
||||
...acc,
|
||||
[param.name]: param.value
|
||||
}
|
||||
}, {}) || {}
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get reasoning tag name based on model ID
|
||||
* Used for extractReasoningMiddleware configuration
|
||||
*/
|
||||
export function getReasoningTagName(modelId: string | undefined): string {
|
||||
const tagName = {
|
||||
reasoning: 'reasoning',
|
||||
think: 'think',
|
||||
thought: 'thought',
|
||||
seedThink: 'seed:think'
|
||||
}
|
||||
|
||||
if (modelId?.includes('gpt-oss')) return tagName.reasoning
|
||||
if (modelId?.includes('gemini')) return tagName.thought
|
||||
if (modelId?.includes('seed-oss-36b')) return tagName.seedThink
|
||||
return tagName.think
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
import type { AppProviderId } from '@renderer/aiCore/types'
|
||||
import { isOpenAIDeepResearchModel, isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
||||
|
||||
export function getWebSearchParams(model: Model): Record<string, any> {
|
||||
if (model.provider === 'hunyuan') {
|
||||
return { enable_enhancement: true, citation: true, search_info: true }
|
||||
}
|
||||
|
||||
if (model.provider === 'dashscope') {
|
||||
return {
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://creator.poe.com/docs/external-applications/openai-compatible-api#using-custom-parameters-with-extra_body
|
||||
if (model.provider === 'poe') {
|
||||
return {
|
||||
extra_body: {
|
||||
web_search: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
|
||||
return {
|
||||
web_search_options: {}
|
||||
}
|
||||
}
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* range in [0, 100]
|
||||
* @param maxResults
|
||||
*/
|
||||
function mapMaxResultToOpenAIContextSize(
|
||||
maxResults: number
|
||||
): NonNullable<WebSearchPluginConfig['openai']>['searchContextSize'] {
|
||||
if (maxResults <= 33) return 'low'
|
||||
if (maxResults <= 66) return 'medium'
|
||||
return 'high'
|
||||
}
|
||||
|
||||
export function buildProviderBuiltinWebSearchConfig(
|
||||
providerId: AppProviderId,
|
||||
webSearchConfig: CherryWebSearchConfig,
|
||||
model?: Model
|
||||
): WebSearchPluginConfig | undefined {
|
||||
switch (providerId) {
|
||||
case 'azure-responses':
|
||||
case 'openai': {
|
||||
const searchContextSize = isOpenAIDeepResearchModel(model)
|
||||
? 'medium'
|
||||
: mapMaxResultToOpenAIContextSize(webSearchConfig.maxResults)
|
||||
return {
|
||||
openai: {
|
||||
searchContextSize
|
||||
}
|
||||
}
|
||||
}
|
||||
case 'openai-chat': {
|
||||
const searchContextSize = isOpenAIDeepResearchModel(model)
|
||||
? 'medium'
|
||||
: mapMaxResultToOpenAIContextSize(webSearchConfig.maxResults)
|
||||
return {
|
||||
'openai-chat': {
|
||||
searchContextSize
|
||||
}
|
||||
}
|
||||
}
|
||||
case 'anthropic': {
|
||||
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
|
||||
const anthropicSearchOptions: NonNullable<WebSearchPluginConfig['anthropic']> = {
|
||||
maxUses: webSearchConfig.maxResults,
|
||||
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||
}
|
||||
return {
|
||||
anthropic: anthropicSearchOptions
|
||||
}
|
||||
}
|
||||
case 'xai':
|
||||
case 'xai-responses': {
|
||||
const excludeDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
|
||||
const xaiWebConfig: NonNullable<NonNullable<WebSearchPluginConfig['xai-responses']>['webSearch']> = {
|
||||
enableImageUnderstanding: true
|
||||
}
|
||||
if (excludeDomains.length > 0) {
|
||||
xaiWebConfig.excludedDomains = excludeDomains.slice(0, 5)
|
||||
}
|
||||
return {
|
||||
'xai-responses': {
|
||||
webSearch: xaiWebConfig,
|
||||
xSearch: { enableImageUnderstanding: true }
|
||||
}
|
||||
}
|
||||
}
|
||||
case 'openrouter': {
|
||||
return {
|
||||
openrouter: {
|
||||
plugins: [
|
||||
{
|
||||
id: 'web',
|
||||
max_results: webSearchConfig.maxResults
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
case 'cherryin': {
|
||||
const _providerId =
|
||||
{ 'openai-response': 'openai', openai: 'openai-chat' }[model?.endpoint_type ?? ''] ?? model?.endpoint_type
|
||||
return buildProviderBuiltinWebSearchConfig(_providerId, webSearchConfig, model)
|
||||
}
|
||||
default: {
|
||||
return {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -179,6 +179,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
|
||||
|
||||
try {
|
||||
if (mode === 'aihubmix_image_generate') {
|
||||
// TODO(renderer/aiCore-cleanup): the remaining Gemini/Ideogram/custom fetch branches should move behind main AI/image IPC so this page no longer owns provider-specific transport logic.
|
||||
if (painting.model.startsWith('imagen-')) {
|
||||
const requestId = uuid()
|
||||
activeImageRequestIdRef.current = requestId
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
/** TODO(renderer/aiCore-cleanup): replace these temporary mirrored tool response types with shared/main-owned contracts once knowledge/web/memory tools stop depending on legacy aiCore definitions. */
|
||||
export interface KnowledgeSearchToolInput {
|
||||
additionalContext?: string
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
isVertexProvider
|
||||
} from '@renderer/utils/provider'
|
||||
|
||||
/** TODO(renderer/aiCore-cleanup): converge this host normalization helper with the remaining main/provider config formatter so we can delete the old aiCore provider config copy safely. */
|
||||
interface HostFormatter {
|
||||
match: (provider: Provider) => boolean
|
||||
format: (provider: Provider, appendApiVersion: boolean) => string
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { findTokenLimit } from '@renderer/config/models'
|
||||
import { EFFORT_RATIO } from '@renderer/types'
|
||||
|
||||
/** TODO(renderer/aiCore-cleanup): only `getThinkingBudget` is extracted here. Migrate or delete the remaining reasoning helpers from the old renderer aiCore module after code/CLI flows are fully decoupled. */
|
||||
const FALLBACK_TOKEN_LIMIT = { min: 1024, max: 16384 }
|
||||
|
||||
function computeBudgetTokens(
|
||||
|
||||
Reference in New Issue
Block a user