refactor: migrate to ai sdk v6 Phase 3 (#12235)

Continued from #12227

## Summary

Phase 3 of AI SDK v6 migration:

- **Image generation**: Migrate to native AI SDK
`generateImage`/`editImage`, remove legacy image middleware
- **Embedding**: Migrate to AI SDK `embedMany`, remove legacy embedding
clients
- **Model listing**: Refactor `ModelListService` to Strategy Registry
pattern (`listModels.ts`), consolidate 7 schema files into one
`schemas.ts`
- **OpenRouter image**: Bump `@openrouter/ai-sdk-provider` to 2.3.3 with
native image endpoint support, remove `isNativeImageGenerationProvider`
guard
- **GitHub Copilot**: Simplify extension by removing `ProviderV2` cast
and `wrapProvider`
- **Legacy removal**: Delete all legacy API clients, middleware
pipeline, and barrel `index.ts`
- **Rename**: `index_new.ts` → `AiProvider.ts`, `ModelListService.ts` →
`listModels.ts`

## Manual Test Plan

### P0 — Core paths + PR change focus

#### 1. Core Chat
- [ ] OpenAI model (e.g. GPT-4o): send text, verify streaming output
- [ ] Anthropic model (e.g. Claude Sonnet): verify chat works
- [ ] Gemini model: verify chat works
- [ ] DeepSeek model: verify chat works
- [ ] Reasoning mode (o3-mini, DeepSeek R1): verify thinking process
displays
- [ ] Send message with image (multimodal): verify model recognizes
image
- [ ] Abort mid-stream: verify clean termination, no errors

#### 2. Image Generation (key change area)
- [ ] DALL-E 3 / GPT-Image-1: verify image generation works
- [ ] OpenRouter image model: verify **native image endpoint** is used
(no longer chat completions)
- [ ] Image editing: upload image + text prompt, verify editImage path
- [ ] Verify generated images display correctly (both base64 and URL
formats)

#### 3. Model Listing (refactored area)
- [ ] Sync OpenAI models: verify names and groups
- [ ] Sync Gemini models: verify `models/` prefix stripped
- [ ] Sync Ollama models (local): verify display
- [ ] Sync OpenRouter models: verify chat + embedding models both appear
- [ ] Sync SiliconFlow models: verify grouping (e.g. `deepseek-ai/`)
- [ ] Sync GitHub Models
- [ ] Sync Together models
- [ ] Sync NewAPI service (e.g. CherryIn)
- [ ] Sync PPIO models: verify chat + embedding + reranker endpoints
merged

### P1 — Affected by refactor

#### 4. Provider Extensions
- [ ] GitHub Copilot: verify chat works (simplified wrapProvider logic)
- [ ] Ollama: verify chat works
- [ ] Vertex AI / Bedrock: verify chat works (if configured)

#### 5. MCP Tool Calling
- [ ] Enable MCP server, send tool-requiring message, verify tool
execution
- [ ] Verify both prompt tool use and function calling modes

#### 6. Auxiliary Functions
- [ ] Auto topic title generation (fetchMessagesSummary)
- [ ] Note summary (fetchNoteSummary)
- [ ] Translation/generation (fetchGenerate)
- [ ] API check (click "Check" button in settings)

### P2 — Indirect impact

#### 7. Knowledge Base
- [ ] Create knowledge base, associate with assistant, verify RAG
retrieval and embedding

#### 8. Web Search
- [ ] Enable native web search (Gemini/Perplexity), verify search
results in response

#### 9. Tracing
- [ ] Enable developer mode, send message, verify trace spans recorded
and displayed

---------

Signed-off-by: suyao <sy20010504@gmail.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: icarus <eurfelux@gmail.com>
Co-authored-by: fullex <106392080+0xfullex@users.noreply.github.com>
This commit is contained in:
SuYao
2026-04-02 15:38:23 +08:00
committed by kangfenmao
parent 865a1be504
commit a58bbf52fd
205 changed files with 13915 additions and 30671 deletions

View File

@@ -0,0 +1,11 @@
---
'@cherrystudio/ai-core': minor
---
Remove unused exports, dead types, and over-engineered abstractions from aiCore
- Remove unused public exports: `createOpenAICompatibleExecutor`, `create*Options`, `mergeProviderOptions`, `PluginManager`, `createContext`, `AI_CORE_VERSION`, `AI_CORE_NAME`, `BUILT_IN_PLUGIN_PREFIX`, `registeredProviderIds`, `ProviderInitializationError`, `ProviderExtensionBuilder`, `createProviderExtension`
- Delete dead type definitions: `HookResult`, `PluginManagerConfig`, `AiRequestMetadata`, `ExtractProviderOptions`, `ProviderOptions`, `CoreProviderSettingsMap` (re-added as internal), `ExtractExtensionIds`, `ExtractExtensionSettings`
- Remove over-engineered `ExtensionStorage` system: delete `ExtensionStorage`, `StorageAccessor`, `ExtensionContext`, `ExtensionHook`, `LifecycleHooks` types; remove `TStorage` generic parameter from `ProviderExtension` (4 → 3 type params); remove `_storage`, `storage` getter, `createContext`, `executeHook`, `initialStorage`, `hooks` from class and config
- Delete `create*Options` convenience functions and inline `createOpenRouterOptions` at its only call site
- Delete `DEFAULT_WEB_SEARCH_CONFIG` and plugins `README.md`

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -38,7 +38,7 @@
"agents:drop": "NODE_ENV='development' drizzle-kit drop --config src/main/services/agents/drizzle.config.ts",
"analyze:renderer": "VISUALIZER_RENDERER=true pnpm build",
"analyze:main": "VISUALIZER_MAIN=true pnpm build",
"typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"",
"typecheck": "pnpm --filter @cherrystudio/ai-sdk-provider build && concurrently -n \"node,web,aicore\" -c \"cyan,magenta,yellow\" \"npm run typecheck:node\" \"npm run typecheck:web\" \"pnpm --filter @cherrystudio/ai-core typecheck\"",
"typecheck:node": "tsgo --noEmit -p tsconfig.node.json --composite false",
"typecheck:web": "tsgo --noEmit -p tsconfig.web.json --composite false",
"i18n:check": "dotenv -e .env -- tsx scripts/check-i18n.ts",
@@ -110,28 +110,28 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^4.0.67",
"@ai-sdk/anthropic": "^3.0.48",
"@ai-sdk/azure": "^3.0.37",
"@ai-sdk/cerebras": "^2.0.34",
"@ai-sdk/gateway": "^3.0.57",
"@ai-sdk/google": "^3.0.33",
"@ai-sdk/google-vertex": "^4.0.66",
"@ai-sdk/huggingface": "^1.0.32",
"@ai-sdk/mistral": "^3.0.20",
"@ai-sdk/openai": "^3.0.36",
"@ai-sdk/perplexity": "^3.0.19",
"@ai-sdk/amazon-bedrock": "^4.0.77",
"@ai-sdk/anthropic": "^3.0.58",
"@ai-sdk/azure": "^3.0.42",
"@ai-sdk/cerebras": "^2.0.39",
"@ai-sdk/cohere": "^3.0.25",
"@ai-sdk/gateway": "^3.0.66",
"@ai-sdk/google": "^3.0.43",
"@ai-sdk/google-vertex": "^4.0.80",
"@ai-sdk/huggingface": "^1.0.37",
"@ai-sdk/mistral": "^3.0.24",
"@ai-sdk/openai": "^3.0.41",
"@ai-sdk/openai-compatible": "^2.0.35",
"@ai-sdk/perplexity": "^3.0.23",
"@ai-sdk/provider": "^3.0.8",
"@ai-sdk/provider-utils": "^4.0.15",
"@ai-sdk/provider-utils": "^4.0.19",
"@ai-sdk/test-server": "^1.0.3",
"@ai-sdk/xai": "^3.0.59",
"@ai-sdk/togetherai": "^2.0.39",
"@ai-sdk/xai": "^3.0.67",
"@ant-design/cssinjs": "1.23.0",
"@ant-design/icons": "5.6.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "0.11.4",
"@aws-sdk/client-bedrock": "^3.998.0",
"@aws-sdk/client-bedrock-runtime": "^3.998.0",
"@aws-sdk/client-s3": "^3.998.0",
"@biomejs/biome": "2.2.4",
"@changesets/changelog-github": "^0.5.2",
@@ -171,7 +171,7 @@
"@eslint-react/eslint-plugin": "^1.36.1",
"@eslint/js": "^9.22.0",
"@floating-ui/dom": "1.7.3",
"@google/genai": "1.0.1",
"@google/genai": "^1.46.0",
"@hello-pangea/dnd": "^18.0.1",
"@iconify-json/material-icon-theme": "^1.2.56",
"@iconify/react": "^6.0.2",
@@ -185,7 +185,7 @@
"@modelcontextprotocol/sdk": "1.27.1",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^2.2.3",
"@openrouter/ai-sdk-provider": "^2.3.3",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/context-async-hooks": "2.0.1",
"@opentelemetry/core": "2.0.0",
@@ -273,7 +273,7 @@
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"adm-zip": "0.4.16",
"ai": "^6.0.103",
"ai": "^6.0.116",
"ansi-to-react": "^6.2.6",
"antd": "5.27.0",
"archiver": "^7.0.1",
@@ -430,6 +430,7 @@
"uuid": "^13.0.0",
"vite": "npm:rolldown-vite@7.3.0",
"vitest": "^3.2.4",
"voyage-ai-provider": "^3.0.0",
"webdav": "^5.9.0",
"winston": "^3.17.0",
"winston-daily-rotate-file": "^5.0.0",
@@ -473,8 +474,6 @@
"patchedDependencies": {
"@napi-rs/system-ocr@1.0.2": "patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
"tesseract.js@6.0.1": "patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
"@anthropic-ai/vertex-sdk@0.11.4": "patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
"@google/genai@1.0.1": "patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
"@langchain/core@1.0.2": "patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
"@langchain/openai@1.0.0": "patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@tiptap/extension-drag-handle@3.2.0": "patches/@tiptap-extension-drag-handle-npm-3.2.0-5a9ebff7c9.patch",
@@ -484,9 +483,9 @@
"file-stream-rotator@0.6.1": "patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
"libsql@0.4.7": "patches/libsql-npm-0.4.7-444e260fb1.patch",
"pdf-parse@1.1.1": "patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
"@ai-sdk/openai-compatible@2.0.30": "patches/@ai-sdk__openai-compatible@2.0.30.patch",
"@openrouter/ai-sdk-provider": "patches/@openrouter__ai-sdk-provider.patch",
"ollama-ai-provider-v2@3.3.1": "patches/ollama-ai-provider-v2@3.3.1.patch",
"@ai-sdk/openai-compatible@2.0.35": "patches/@ai-sdk__openai-compatible@2.0.35.patch",
"@openrouter/ai-sdk-provider": "patches/@openrouter__ai-sdk-provider.patch",
"@opeoginni/github-copilot-openai-compatible@1.0.0": "patches/@opeoginni__github-copilot-openai-compatible@1.0.0.patch"
},
"onlyBuiltDependencies": [

View File

@@ -16,8 +16,8 @@
},
"type": "module",
"main": "dist/index.cjs",
"module": "dist/index.js",
"types": "dist/index.d.ts",
"module": "dist/index.mjs",
"types": "dist/index.d.mts",
"files": ["dist"],
"scripts": {
"build": "tsdown",
@@ -28,14 +28,14 @@
"prepublishOnly": "pnpm build"
},
"peerDependencies": {
"@ai-sdk/anthropic": "^3.0.48",
"@ai-sdk/google": "^3.0.33",
"@ai-sdk/openai": "^3.0.36",
"@ai-sdk/openai-compatible": "^2.0.30"
"@ai-sdk/anthropic": "^3.0.58",
"@ai-sdk/google": "^3.0.43",
"@ai-sdk/openai": "^3.0.41",
"@ai-sdk/openai-compatible": "^2.0.35"
},
"dependencies": {
"@ai-sdk/provider": "^3.0.8",
"@ai-sdk/provider-utils": "^4.0.15"
"@ai-sdk/provider-utils": "^4.0.19"
},
"devDependencies": {
"tsdown": "^0.20.3",
@@ -48,10 +48,10 @@
},
"exports": {
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.js",
"types": "./dist/index.d.mts",
"import": "./dist/index.mjs",
"require": "./dist/index.cjs",
"default": "./dist/index.js"
"default": "./dist/index.mjs"
}
}
}

View File

@@ -1,514 +0,0 @@
# AI Core 基于 Vercel AI SDK 的技术架构
## 1. 架构设计理念
### 1.1 设计目标
- **简化分层**`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
- **动态导入**:通过动态导入实现按需加载,减少打包体积
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
- **轻量级**:专注核心功能,保持包的轻量和高效
- **包级独立**:作为独立包管理,便于复用和维护
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
### 1.2 核心优势
- **标准化**AI SDK 提供统一的模型接口,减少适配工作
- **简化设计**函数式API避免过度抽象
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
- **性能优化**AI SDK 内置优化和最佳实践
- **模块化设计**:独立包结构,支持跨项目复用
- **可扩展插件**:通用的流转换和参数处理插件系统
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
## 2. 整体架构图
```mermaid
graph TD
subgraph "用户应用 (如 Cherry Studio)"
UI["用户界面"]
Components["应用组件"]
end
subgraph "packages/aiCore (AI Core 包)"
subgraph "Runtime Layer (运行时层)"
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
PluginEngine["PluginEngine (插件引擎)"]
RuntimeAPI["Runtime API (便捷函数)"]
end
subgraph "Models Layer (模型层)"
ModelFactory["createModel() (模型工厂)"]
ProviderCreator["ProviderCreator (提供商创建器)"]
end
subgraph "Core Systems (核心系统)"
subgraph "Plugins (插件)"
PluginManager["PluginManager (插件管理)"]
BuiltInPlugins["Built-in Plugins (内置插件)"]
StreamTransforms["Stream Transforms (流转换)"]
end
subgraph "Middleware (中间件)"
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
end
subgraph "Providers (提供商)"
Registry["Provider Registry (注册表)"]
Factory["Provider Factory (工厂)"]
end
end
end
subgraph "Vercel AI SDK"
AICore["ai (核心库)"]
OpenAI["@ai-sdk/openai"]
Anthropic["@ai-sdk/anthropic"]
Google["@ai-sdk/google"]
XAI["@ai-sdk/xai"]
Others["其他 19+ Providers"]
end
subgraph "Future: OpenAI Agents SDK"
AgentSDK["@openai/agents (未来集成)"]
AgentExtensions["Agent Extensions (预留)"]
end
UI --> RuntimeAPI
Components --> RuntimeExecutor
RuntimeAPI --> RuntimeExecutor
RuntimeExecutor --> PluginEngine
RuntimeExecutor --> ModelFactory
PluginEngine --> PluginManager
ModelFactory --> ProviderCreator
ModelFactory --> MiddlewareWrapper
ProviderCreator --> Registry
Registry --> Factory
Factory --> OpenAI
Factory --> Anthropic
Factory --> Google
Factory --> XAI
Factory --> Others
RuntimeExecutor --> AICore
AICore --> streamText
AICore --> generateText
AICore --> streamObject
AICore --> generateObject
PluginManager --> StreamTransforms
PluginManager --> BuiltInPlugins
%% 未来集成路径
RuntimeExecutor -.-> AgentSDK
AgentSDK -.-> AgentExtensions
```
## 3. 包结构设计
### 3.1 新架构文件结构
```
packages/aiCore/
├── src/
│ ├── core/ # 核心层 - 内部实现
│ │ ├── models/ # 模型层 - 模型创建和配置
│ │ │ ├── factory.ts # 模型工厂函数 ✅
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
│ │ │ ├── types.ts # 模型类型定义 ✅
│ │ │ └── index.ts # 模型层导出 ✅
│ │ ├── runtime/ # 运行时层 - 执行和用户API
│ │ │ ├── executor.ts # 运行时执行器 ✅
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
│ │ │ ├── types.ts # 运行时类型定义 ✅
│ │ │ └── index.ts # 运行时导出 ✅
│ │ ├── middleware/ # 中间件系统
│ │ │ ├── wrapper.ts # 模型包装器 ✅
│ │ │ ├── manager.ts # 中间件管理器 ✅
│ │ │ ├── types.ts # 中间件类型 ✅
│ │ │ └── index.ts # 中间件导出 ✅
│ │ ├── plugins/ # 插件系统
│ │ │ ├── types.ts # 插件类型定义 ✅
│ │ │ ├── manager.ts # 插件管理器 ✅
│ │ │ ├── built-in/ # 内置插件 ✅
│ │ │ │ ├── logging.ts # 日志插件 ✅
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
│ │ │ │ └── index.ts # 内置插件导出 ✅
│ │ │ ├── README.md # 插件文档 ✅
│ │ │ └── index.ts # 插件导出 ✅
│ │ ├── providers/ # 提供商管理
│ │ │ ├── registry.ts # 提供商注册表 ✅
│ │ │ ├── factory.ts # 提供商工厂 ✅
│ │ │ ├── creator.ts # 提供商创建器 ✅
│ │ │ ├── types.ts # 提供商类型 ✅
│ │ │ ├── utils.ts # 工具函数 ✅
│ │ │ └── index.ts # 提供商导出 ✅
│ │ ├── options/ # 配置选项
│ │ │ ├── factory.ts # 选项工厂 ✅
│ │ │ ├── types.ts # 选项类型 ✅
│ │ │ ├── xai.ts # xAI 选项 ✅
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
│ │ │ ├── examples.ts # 示例配置 ✅
│ │ │ └── index.ts # 选项导出 ✅
│ │ └── index.ts # 核心层导出 ✅
│ ├── types.ts # 全局类型定义 ✅
│ └── index.ts # 包主入口文件 ✅
├── package.json # 包配置文件 ✅
├── tsconfig.json # TypeScript 配置 ✅
├── README.md # 包说明文档 ✅
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
```
## 4. 架构分层详解
### 4.1 Models Layer (模型层)
**职责**:统一的模型创建和配置管理
**核心文件**
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
- `types.ts`: 模型配置类型定义
**设计特点**
- 函数式设计,避免不必要的类抽象
- 统一的模型配置接口
- 自动处理中间件应用
- 支持批量模型创建
**核心API**
```typescript
// 模型配置接口
export interface ModelConfig {
providerId: ProviderId
modelId: string
options: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV1Middleware[]
}
// 核心模型创建函数
export async function createModel(config: ModelConfig): Promise<LanguageModel>
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
```
### 4.2 Runtime Layer (运行时层)
**职责**运行时执行器和用户面向的API接口
**核心组件**
- `executor.ts`: 运行时执行器类
- `plugin-engine.ts`: 插件引擎原PluginEnabledAiClient
- `index.ts`: 便捷函数和工厂方法
**设计特点**
- 提供三种使用方式:类实例、静态工厂、函数式调用
- 自动集成模型创建和插件处理
- 完整的类型安全支持
- 为 OpenAI Agents SDK 预留扩展接口
**核心API**
```typescript
// 运行时执行器
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
static create<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): RuntimeExecutor<T>
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
}
// 便捷函数式API
export async function streamText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: StreamTextParams,
plugins?: AiPlugin[]
): Promise<StreamTextResult>
```
### 4.3 Plugin System (插件系统)
**职责**:可扩展的插件架构
**核心组件**
- `PluginManager`: 插件生命周期管理
- `built-in/`: 内置插件集合
- 流转换收集和应用
**设计特点**
- 借鉴 Rollup 的钩子分类设计
- 支持流转换 (`experimental_transform`)
- 内置常用插件(日志、计数等)
- 完整的生命周期钩子
**插件接口**
```typescript
export interface AiPlugin {
name: string
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
// 【Stream】流处理
transformStream?: () => TransformStream
}
```
### 4.4 Middleware System (中间件系统)
**职责**AI SDK原生中间件支持
**核心组件**
- `ModelWrapper.ts`: 模型包装函数
**设计哲学**
- 直接使用AI SDK的 `wrapLanguageModel`
- 与插件系统分离,职责明确
- 函数式设计,简化使用
```typescript
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
```
### 4.5 Provider System (提供商系统)
**职责**AI Provider注册表和动态导入
**核心组件**
- `registry.ts`: 19+ Provider配置和类型
- `factory.ts`: Provider配置工厂
**支持的Providers**
- OpenAI, Anthropic, Google, XAI
- Azure OpenAI, Amazon Bedrock, Google Vertex
- Groq, Together.ai, Fireworks, DeepSeek
- 等19+ AI SDK官方支持的providers
## 5. 使用方式
### 5.1 函数式调用 (推荐 - 简单场景)
```typescript
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
// 直接函数调用
const stream = await streamText(
'anthropic',
{ apiKey: 'your-api-key' },
'claude-3',
{ messages: [{ role: 'user', content: 'Hello!' }] },
[loggingPlugin]
)
```
### 5.2 执行器实例 (推荐 - 复杂场景)
```typescript
import { createExecutor } from '@cherrystudio/ai-core/runtime'
// 创建可复用的执行器
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
// 多次使用
const stream = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
const result = await executor.generateText('gpt-4', {
messages: [{ role: 'user', content: 'How are you?' }]
})
```
### 5.3 静态工厂方法
```typescript
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
// 静态创建
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
await executor.streamText('claude-3', { messages: [...] })
```
### 5.4 直接模型创建 (高级用法)
```typescript
import { createModel } from '@cherrystudio/ai-core/models'
import { streamText } from 'ai'
// 直接创建模型使用
const model = await createModel({
providerId: 'openai',
modelId: 'gpt-4',
options: { apiKey: 'your-api-key' },
middlewares: [middleware1, middleware2]
})
// 直接使用 AI SDK
const result = await streamText({ model, messages: [...] })
```
## 6. 为 OpenAI Agents SDK 预留的设计
### 6.1 架构兼容性
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
```typescript
// 当前的模型创建
const model = await createModel({
providerId: 'anthropic',
modelId: 'claude-3',
options: { apiKey: 'xxx' }
})
// 将来可以直接用于 OpenAI Agents SDK
import { Agent, run } from '@openai/agents'
const agent = new Agent({
model, // ✅ 直接兼容 LanguageModel 接口
name: 'Assistant',
instructions: '...',
tools: [tool1, tool2]
})
const result = await run(agent, 'user input')
```
### 6.2 预留的扩展点
1. **runtime/agents/** 目录预留
2. **AgentExecutor** 类预留
3. **Agent工具转换插件** 预留
4. **多Agent编排** 预留
### 6.3 未来架构扩展
```
packages/aiCore/src/core/
├── runtime/
│ ├── agents/ # 🚀 未来添加
│ │ ├── AgentExecutor.ts
│ │ ├── WorkflowManager.ts
│ │ └── ConversationManager.ts
│ ├── executor.ts
│ └── index.ts
```
## 7. 架构优势
### 7.1 简化设计
- **移除过度抽象**删除了orchestration层和creation层的复杂包装
- **函数式优先**models层使用函数而非类
- **直接明了**runtime层直接提供用户API
### 7.2 职责清晰
- **Models**: 专注模型创建和配置
- **Runtime**: 专注执行和用户API
- **Plugins**: 专注扩展功能
- **Providers**: 专注AI Provider管理
### 7.3 类型安全
- 完整的 TypeScript 支持
- AI SDK 类型的直接复用
- 避免类型重复定义
### 7.4 灵活使用
- 三种使用模式满足不同需求
- 从简单函数调用到复杂执行器
- 支持直接AI SDK使用
### 7.5 面向未来
- 为 OpenAI Agents SDK 集成做好准备
- 清晰的扩展点和架构边界
- 模块化设计便于功能添加
## 8. 技术决策记录
### 8.1 为什么选择简化的两层架构?
- **职责分离**models专注创建runtime专注执行
- **模块化**:每层都有清晰的边界和职责
- **扩展性**为Agent功能预留了清晰的扩展空间
### 8.2 为什么选择函数式设计?
- **简洁性**:避免不必要的类设计
- **性能**:减少对象创建开销
- **易用性**:函数调用更直观
### 8.3 为什么分离插件和中间件?
- **职责明确**: 插件处理应用特定需求
- **原生支持**: 中间件使用AI SDK原生功能
- **灵活性**: 两套系统可以独立演进
## 9. 总结
AI Core架构实现了
### 9.1 核心特点
-**简化架构**: 2层核心架构职责清晰
-**函数式设计**: models层完全函数化
-**类型安全**: 统一的类型定义和AI SDK类型复用
-**插件扩展**: 强大的插件系统
-**多种使用方式**: 满足不同复杂度需求
-**Agent就绪**: 为OpenAI Agents SDK集成做好准备
### 9.2 核心价值
- **统一接口**: 一套API支持19+ AI providers
- **灵活使用**: 函数式、实例式、静态工厂式
- **强类型**: 完整的TypeScript支持
- **可扩展**: 插件和中间件双重扩展能力
- **高性能**: 最小化包装直接使用AI SDK
- **面向未来**: Agent SDK集成架构就绪
### 9.3 未来发展
这个架构提供了:
- **优秀的开发体验**: 简洁的API和清晰的使用模式
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
- **良好的维护性**: 职责分离明确,代码易于维护
- **广泛的适用性**: 既适合简单调用也适合复杂应用

View File

@@ -10,6 +10,10 @@
"build": "tsdown",
"dev": "tsc -w",
"typecheck": "tsc --noEmit",
"lint": "oxlint --fix && eslint . --fix --cache",
"lint:check": "oxlint && eslint . --cache",
"format": "biome format --write && biome lint --write",
"format:check": "biome format && biome lint",
"clean": "rm -rf dist",
"test": "vitest run",
"test:watch": "vitest",
@@ -27,23 +31,27 @@
},
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
"peerDependencies": {
"@ai-sdk/google": "^3.0.33",
"@ai-sdk/openai": "^3.0.36",
"ai": "^6.0.103"
"@ai-sdk/google": "^3.0.43",
"@ai-sdk/openai": "^3.0.41",
"ai": "^6.0.116"
},
"dependencies": {
"@cherrystudio/ai-sdk-provider": "workspace:*",
"@ai-sdk/anthropic": "^3.0.48",
"@ai-sdk/azure": "^3.0.37",
"@ai-sdk/deepseek": "^2.0.20",
"@ai-sdk/openai-compatible": "^2.0.30",
"@ai-sdk/anthropic": "^3.0.58",
"@ai-sdk/azure": "^3.0.42",
"@ai-sdk/deepseek": "^2.0.24",
"@ai-sdk/openai-compatible": "^2.0.35",
"@openrouter/ai-sdk-provider": "^2.3.3",
"@ai-sdk/provider": "^3.0.8",
"@ai-sdk/provider-utils": "^4.0.15",
"@ai-sdk/xai": "^3.0.59",
"@openrouter/ai-sdk-provider": "^2.2.3",
"@ai-sdk/provider-utils": "^4.0.19",
"@ai-sdk/xai": "^3.0.67",
"lru-cache": "^11.2.4",
"zod": "^4.1.5"
},
"devDependencies": {
"@cherrystudio/ai-sdk-provider": "workspace:*",
"biome": "^0.3.3",
"oxlint": "^1.36.0",
"tsdown": "^0.20.3",
"typescript": "^5.0.0",
"vitest": "^3.2.4"

View File

@@ -1,13 +0,0 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Fixtures
export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/model-test-utils'
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@@ -1,3 +0,0 @@
# @cherryStudio-aiCore
Core

View File

@@ -2,13 +2,7 @@
* Core 模块导出
* 内部核心功能,供其他模块使用,不直接面向最终调用者
*/
// 中间件系统
export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// 创建管理
export { globalModelResolver, ModelResolver } from './models'
// 模型类型
export type { ModelConfig as ModelConfigType } from './models/types'
// 执行管理

View File

@@ -1,8 +0,0 @@
/**
* Middleware 模块导出
* 提供通用的中间件管理能力
*/
export { createMiddlewares } from './manager'
export type { NamedMiddleware } from './types'
export { wrapModelWithMiddlewares } from './wrapper'

View File

@@ -1,16 +0,0 @@
/**
* 中间件管理器
* 专注于 AI SDK 中间件的管理,与插件系统分离
*/
import type { LanguageModelV3Middleware } from '@ai-sdk/provider'
/**
* 创建中间件列表
* 合并用户提供的中间件
*/
export function createMiddlewares(userMiddlewares: LanguageModelV3Middleware[] = []): LanguageModelV3Middleware[] {
// 未来可以在这里添加默认的中间件
const defaultMiddlewares: LanguageModelV3Middleware[] = []
return [...defaultMiddlewares, ...userMiddlewares]
}

View File

@@ -1,12 +0,0 @@
/**
* 中间件系统类型定义
*/
import type { LanguageModelV3Middleware } from '@ai-sdk/provider'
/**
* 具名中间件接口
*/
export interface NamedMiddleware {
name: string
middleware: LanguageModelV3Middleware
}

View File

@@ -1,23 +0,0 @@
/**
* 模型包装工具函数
* 用于将中间件应用到LanguageModel上
*/
import type { LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'
/**
* 使用中间件包装模型
*/
export function wrapModelWithMiddlewares(
model: LanguageModelV3,
middlewares: LanguageModelV3Middleware[]
): LanguageModelV3 {
if (middlewares.length === 0) {
return model
}
return wrapLanguageModel({
model,
middleware: middlewares
})
}

View File

@@ -1,114 +0,0 @@
/**
* 模型解析器 - models模块的核心
* 负责将modelId解析为AI SDK的LanguageModel实例
* 支持传统格式和命名空间格式
* 集成了来自 ModelCreator 的特殊处理逻辑
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
export class ModelResolver {
/**
* 核心方法解析任意格式的modelId为语言模型
*
* @param modelId 模型ID支持 'gpt-4' 和 'anthropic>claude-3' 两种格式
* @param fallbackProviderId 当modelId为传统格式时使用的providerId
* @param providerOptions provider配置选项用于OpenAI模式选择等
*/
async resolveLanguageModel(
modelId: string,
fallbackProviderId: string,
providerOptions?: any
): Promise<LanguageModelV3> {
let finalProviderId = fallbackProviderId
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
finalProviderId = `${fallbackProviderId}-chat`
}
// 检查是否是命名空间格式
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedModel(modelId)
} else {
// 传统格式:使用处理后的 providerId + modelId
return this.resolveTraditionalModel(finalProviderId, modelId)
}
}
/**
* 解析文本嵌入模型
*/
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV3> {
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedEmbeddingModel(modelId)
}
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
}
/**
* 解析图像模型
*/
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV3> {
if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedImageModel(modelId)
}
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
}
/**
* 解析命名空间格式的语言模型
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
*/
private resolveNamespacedModel(modelId: string): LanguageModelV3 {
return globalRegistryManagement.languageModel(modelId as any)
}
/**
* 解析传统格式的语言模型
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
*/
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV3 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.languageModel(fullModelId as any)
}
/**
* 解析命名空间格式的嵌入模型
*/
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV3 {
return globalRegistryManagement.embeddingModel(modelId as any)
}
/**
* 解析传统格式的嵌入模型
*/
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV3 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.embeddingModel(fullModelId as any)
}
/**
* 解析命名空间格式的图像模型
*/
private resolveNamespacedImageModel(modelId: string): ImageModelV3 {
return globalRegistryManagement.imageModel(modelId as any)
}
/**
* 解析传统格式的图像模型
*/
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV3 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.imageModel(fullModelId as any)
}
}
/**
* 全局模型解析器实例
*/
export const globalModelResolver = new ModelResolver()

View File

@@ -1,47 +1,29 @@
/**
* ModelResolver Comprehensive Tests
* Tests model resolution logic for language, embedding, and image models
* Covers both traditional and namespaced format resolution
* ProviderRegistry Model Resolution Tests
* Tests model resolution via AI SDK's createProviderRegistry
* The registry routes 'providerId:modelId' to the correct provider
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import {
createMockEmbeddingModel,
createMockImageModel,
createMockLanguageModel,
createMockProviderV3
} from '@test-utils'
import { createProviderRegistry } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../../providers/RegistryManagement'
import { ModelResolver } from '../ModelResolver'
// Mock the dependencies
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn(),
embeddingModel: vi.fn(),
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
vi.mock('../../middleware/wrapper', () => ({
wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => {
// Return a wrapped model with a marker
return {
...model,
_wrapped: true
} as LanguageModelV3
})
}))
describe('ModelResolver', () => {
let resolver: ModelResolver
describe('ProviderRegistry Model Resolution', () => {
let registry: ReturnType<typeof createProviderRegistry>
let mockLanguageModel: LanguageModelV3
let mockEmbeddingModel: EmbeddingModelV3
let mockImageModel: ImageModelV3
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
resolver = new ModelResolver()
// Create properly typed mock models using global utilities
mockLanguageModel = createMockLanguageModel({
provider: 'test-provider',
modelId: 'test-model'
@@ -57,361 +39,168 @@ describe('ModelResolver', () => {
modelId: 'test-image'
})
// Setup default mock implementations
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
vi.mocked(globalRegistryManagement.embeddingModel).mockReturnValue(mockEmbeddingModel)
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
})
describe('resolveLanguageModel', () => {
describe('Traditional Format Resolution', () => {
it('should resolve traditional format modelId without separator', async () => {
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(result).toBe(mockLanguageModel)
})
it('should resolve with different provider and modelId combinations', async () => {
const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [
{ modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' },
{ modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' },
{ modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' },
{ modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' }
]
for (const testCase of testCases) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId)
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(testCase.expected)
}
})
it('should handle modelIds with special characters', async () => {
const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free']
for (const modelId of modelIds) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(modelId, 'provider')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`provider${DEFAULT_SEPARATOR}${modelId}`)
}
})
mockProvider = createMockProviderV3({
provider: 'test-provider',
languageModel: vi.fn(() => mockLanguageModel),
embeddingModel: vi.fn(() => mockEmbeddingModel),
imageModel: vi.fn(() => mockImageModel)
})
describe('Namespaced Format Resolution', () => {
it('should resolve namespaced format with hub', async () => {
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet`
const result = await resolver.resolveLanguageModel(namespacedId, 'openai')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
expect(result).toBe(mockLanguageModel)
})
it('should resolve simple namespaced format', async () => {
const namespacedId = `provider${DEFAULT_SEPARATOR}model-id`
await resolver.resolveLanguageModel(namespacedId, 'fallback-provider')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
})
it('should handle complex namespaced IDs', async () => {
const complexIds = [
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`,
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`,
`custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo`
]
for (const id of complexIds) {
vi.clearAllMocks()
await resolver.resolveLanguageModel(id, 'fallback')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(id)
}
})
})
describe('OpenAI Mode Selection', () => {
it('should append "-chat" suffix for OpenAI provider with chat mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' })
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4')
})
it('should append "-chat" suffix for Azure provider with chat mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' })
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4')
})
it('should not append suffix for OpenAI with responses mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' })
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
it('should not append suffix for OpenAI without mode', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
it('should not append suffix for other providers with chat mode', async () => {
await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' })
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3')
})
it('should handle namespaced IDs with OpenAI chat mode', async () => {
const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4`
await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' })
// Should use the namespaced ID directly, not apply mode logic
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
})
})
describe('Provider Options Handling', () => {
it('should pass provider options correctly', async () => {
const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' }
await resolver.resolveLanguageModel('gpt-4', 'openai', options)
// Provider options are used for mode selection logic
expect(globalRegistryManagement.languageModel).toHaveBeenCalled()
})
it('should handle empty provider options', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', {})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
it('should handle undefined provider options', async () => {
await resolver.resolveLanguageModel('gpt-4', 'openai', undefined)
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
})
registry = createProviderRegistry({
'test-provider': mockProvider
})
})
describe('resolveTextEmbeddingModel', () => {
describe('Traditional Format', () => {
it('should resolve traditional embedding model ID', async () => {
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
describe('languageModel', () => {
it('should resolve modelId via provider registry', () => {
const result = registry.languageModel('test-provider:gpt-4')
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002')
expect(result).toBe(mockEmbeddingModel)
})
it('should resolve different embedding models', async () => {
const testCases = [
{ modelId: 'text-embedding-3-small', providerId: 'openai' },
{ modelId: 'text-embedding-3-large', providerId: 'openai' },
{ modelId: 'embed-english-v3.0', providerId: 'cohere' },
{ modelId: 'voyage-2', providerId: 'voyage' }
]
for (const { modelId, providerId } of testCases) {
vi.clearAllMocks()
await resolver.resolveTextEmbeddingModel(modelId, providerId)
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
}
})
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
describe('Namespaced Format', () => {
it('should resolve namespaced embedding model ID', async () => {
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small`
it('should pass various modelIds to the correct provider', () => {
const modelIds = [
'claude-3-5-sonnet',
'gemini-2.0-flash',
'grok-2-latest',
'deepseek-chat',
'model-v1.0',
'model_v2',
'model.2024'
]
const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai')
for (const modelId of modelIds) {
vi.clearAllMocks()
registry.languageModel(`test-provider:${modelId}`)
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(namespacedId)
expect(result).toBe(mockEmbeddingModel)
})
it('should handle complex namespaced embedding IDs', async () => {
const complexIds = [
`hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`,
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model`
]
for (const id of complexIds) {
vi.clearAllMocks()
await resolver.resolveTextEmbeddingModel(id, 'fallback')
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(id)
}
})
})
})
describe('resolveImageModel', () => {
describe('Traditional Format', () => {
it('should resolve traditional image model ID', async () => {
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should resolve different image models', async () => {
const testCases = [
{ modelId: 'dall-e-2', providerId: 'openai' },
{ modelId: 'stable-diffusion-xl', providerId: 'stability' },
{ modelId: 'imagen-2', providerId: 'google' },
{ modelId: 'midjourney-v6', providerId: 'midjourney' }
]
for (const { modelId, providerId } of testCases) {
vi.clearAllMocks()
await resolver.resolveImageModel(modelId, providerId)
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
}
})
expect(mockProvider.languageModel).toHaveBeenCalledWith(modelId)
}
})
describe('Namespaced Format', () => {
it('should resolve namespaced image model ID', async () => {
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3`
const result = await resolver.resolveImageModel(namespacedId, 'openai')
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(namespacedId)
expect(result).toBe(mockImageModel)
})
it('should handle complex namespaced image IDs', async () => {
const complexIds = [
`hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`,
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2`
]
for (const id of complexIds) {
vi.clearAllMocks()
await resolver.resolveImageModel(id, 'fallback')
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(id)
}
})
})
})
describe('Edge Cases and Error Scenarios', () => {
it('should handle empty model IDs', async () => {
await resolver.resolveLanguageModel('', 'openai')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|')
})
it('should handle model IDs with multiple separators', async () => {
const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`
await resolver.resolveLanguageModel(multiSeparatorId, 'fallback')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(multiSeparatorId)
})
it('should handle model IDs with only separator', async () => {
const onlySeparator = DEFAULT_SEPARATOR
await resolver.resolveLanguageModel(onlySeparator, 'provider')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(onlySeparator)
})
it('should throw if globalRegistryManagement throws', async () => {
const error = new Error('Model not found in registry')
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
it('should throw if provider throws', () => {
const error = new Error('Model not found')
vi.mocked(mockProvider.languageModel).mockImplementation(() => {
throw error
})
await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow(
'Model not found in registry'
)
expect(() => registry.languageModel('test-provider:invalid-model')).toThrow('Model not found')
})
it('should handle concurrent resolution requests', async () => {
const promises = [
resolver.resolveLanguageModel('gpt-4', 'openai'),
resolver.resolveLanguageModel('claude-3', 'anthropic'),
resolver.resolveLanguageModel('gemini-2.0', 'google')
it('should handle concurrent resolution requests', () => {
const results = [
registry.languageModel('test-provider:gpt-4'),
registry.languageModel('test-provider:claude-3'),
registry.languageModel('test-provider:gemini-2.0')
]
const results = await Promise.all(promises)
expect(results).toHaveLength(3)
expect(globalRegistryManagement.languageModel).toHaveBeenCalledTimes(3)
expect(mockProvider.languageModel).toHaveBeenCalledTimes(3)
})
it('should throw for unknown provider', () => {
expect(() => registry.languageModel('unknown:gpt-4' as `${string}:${string}`)).toThrow()
})
})
describe('embeddingModel', () => {
it('should resolve embedding model ID', () => {
const result = registry.embeddingModel('test-provider:text-embedding-ada-002')
expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-ada-002')
expect(result).toBe(mockEmbeddingModel)
})
it('should resolve different embedding models', () => {
const modelIds = ['text-embedding-3-small', 'text-embedding-3-large', 'embed-english-v3.0', 'voyage-2']
for (const modelId of modelIds) {
vi.clearAllMocks()
registry.embeddingModel(`test-provider:${modelId}`)
expect(mockProvider.embeddingModel).toHaveBeenCalledWith(modelId)
}
})
})
describe('imageModel', () => {
it('should resolve image model ID', () => {
const result = registry.imageModel('test-provider:dall-e-3')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should resolve different image models', () => {
const modelIds = ['dall-e-2', 'stable-diffusion-xl', 'imagen-2', 'grok-2-image']
for (const modelId of modelIds) {
vi.clearAllMocks()
registry.imageModel(`test-provider:${modelId}`)
expect(mockProvider.imageModel).toHaveBeenCalledWith(modelId)
}
})
})
describe('Type Safety', () => {
it('should return properly typed LanguageModelV3', async () => {
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
it('should return properly typed LanguageModelV3', () => {
const result = registry.languageModel('test-provider:gpt-4')
// Type assertions
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
expect(result).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', async () => {
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
it('should return properly typed EmbeddingModelV3', () => {
const result = registry.embeddingModel('test-provider:text-embedding-ada-002')
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', async () => {
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
it('should return properly typed ImageModelV3', () => {
const result = registry.imageModel('test-provider:dall-e-3')
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
})
})
describe('Global ModelResolver Instance', () => {
it('should have a global instance available', async () => {
const { globalModelResolver } = await import('../ModelResolver')
describe('Multi-provider registry', () => {
it('should route to correct provider in multi-provider registry', () => {
const mockProvider2 = createMockProviderV3({
provider: 'second-provider',
languageModel: vi.fn(() =>
createMockLanguageModel({
provider: 'second-provider',
modelId: 'other-model'
})
)
})
expect(globalModelResolver).toBeInstanceOf(ModelResolver)
const multiRegistry = createProviderRegistry({
first: mockProvider,
second: mockProvider2
})
multiRegistry.languageModel('first:gpt-4')
multiRegistry.languageModel('second:other-model')
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockProvider2.languageModel).toHaveBeenCalledWith('other-model')
})
})
describe('Integration with Different Provider Types', () => {
it('should work with OpenAI compatible providers', async () => {
await resolver.resolveLanguageModel('custom-model', 'openai-compatible')
describe('All model types for same provider', () => {
it('should handle all model types correctly', () => {
registry.languageModel('test-provider:gpt-4')
registry.embeddingModel('test-provider:text-embedding-3-small')
registry.imageModel('test-provider:dall-e-3')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model')
})
it('should work with hub providers', async () => {
const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1`
await resolver.resolveLanguageModel(hubId, 'aihubmix')
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(hubId)
})
it('should handle all model types for same provider', async () => {
const providerId = 'openai'
const languageModel = 'gpt-4'
const embeddingModel = 'text-embedding-3-small'
const imageModel = 'dall-e-3'
await resolver.resolveLanguageModel(languageModel, providerId)
await resolver.resolveTextEmbeddingModel(embeddingModel, providerId)
await resolver.resolveImageModel(imageModel, providerId)
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`)
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`)
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`)
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
})
})
})

View File

@@ -2,9 +2,6 @@
* Models 模块统一导出 - 简化版
*/
// 核心模型解析器
export { globalModelResolver, ModelResolver } from './ModelResolver'
// 保留的类型定义(可能被其他地方使用)
export type { ModelConfig as ModelConfigType } from './types'

View File

@@ -3,12 +3,21 @@
*/
import type { JSONObject, LanguageModelV3Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
import type { CoreProviderSettingsMap, ProviderId } from '../providers/types'
export interface ModelConfig<T extends ProviderId = ProviderId> {
/**
* 模型配置
*
* @typeParam T - Provider ID 类型
* @typeParam TSettingsMap - Provider Settings Map默认 CoreProviderSettingsMap
*/
export interface ModelConfig<
T extends ProviderId = ProviderId,
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap
> {
providerId: T
modelId: string
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
providerSettings: TSettingsMap[T & keyof TSettingsMap]
middlewares?: LanguageModelV3Middleware[]
extraModelConfig?: JSONObject
}

View File

@@ -1,109 +1,87 @@
import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
import { describe, expect, it } from 'vitest'
import { createOpenAIOptions, createOpenRouterOptions, mergeProviderOptions } from '../factory'
import { mergeProviderOptions } from '../factory'
import type { TypedProviderOptions } from '../types'
// Helper to build typed options for tests without verbose casts at each call site
const opts = (o: Record<string, Record<string, unknown>>): Partial<TypedProviderOptions> =>
o as Partial<TypedProviderOptions>
describe('mergeProviderOptions', () => {
it('deep merges provider options for the same provider', () => {
const reasoningOptions = createOpenRouterOptions({
reasoning: {
enabled: true,
effort: 'medium'
}
})
const webSearchOptions = createOpenRouterOptions({
plugins: [{ id: 'web', max_results: 5 }]
})
const reasoningOptions: Partial<TypedProviderOptions> = {
openrouter: { reasoning: { enabled: true, effort: 'medium' } } as OpenRouterProviderOptions
}
const webSearchOptions = opts({ openrouter: { plugins: [{ id: 'web', max_results: 5 }] } })
const merged = mergeProviderOptions(reasoningOptions, webSearchOptions)
expect(merged.openrouter).toEqual({
reasoning: {
enabled: true,
effort: 'medium'
},
reasoning: { enabled: true, effort: 'medium' },
plugins: [{ id: 'web', max_results: 5 }]
})
})
it('preserves options from other providers while merging', () => {
const openRouter = createOpenRouterOptions({
reasoning: { enabled: true }
})
const openAI = createOpenAIOptions({
reasoningEffort: 'low'
})
const openRouter: Partial<TypedProviderOptions> = {
openrouter: { reasoning: { enabled: true, effort: 'medium' } } as OpenRouterProviderOptions
}
const openAI: Partial<TypedProviderOptions> = { openai: { reasoningEffort: 'low' } }
const merged = mergeProviderOptions(openRouter, openAI)
expect(merged.openrouter).toEqual({ reasoning: { enabled: true } })
expect(merged.openrouter).toEqual({ reasoning: { enabled: true, effort: 'medium' } })
expect(merged.openai).toEqual({ reasoningEffort: 'low' })
})
it('overwrites primitive values with later values', () => {
const first = createOpenAIOptions({
reasoningEffort: 'low',
user: 'user-123'
})
const second = createOpenAIOptions({
reasoningEffort: 'high',
maxToolCalls: 5
})
const first: Partial<TypedProviderOptions> = { openai: { reasoningEffort: 'low', user: 'user-123' } }
const second: Partial<TypedProviderOptions> = { openai: { reasoningEffort: 'high', maxToolCalls: 5 } }
const merged = mergeProviderOptions(first, second)
expect(merged.openai).toEqual({
reasoningEffort: 'high', // overwritten by second
user: 'user-123', // preserved from first
maxToolCalls: 5 // added from second
reasoningEffort: 'high',
user: 'user-123',
maxToolCalls: 5
})
})
it('overwrites arrays with later values instead of merging', () => {
const first = createOpenRouterOptions({
models: ['gpt-4', 'gpt-3.5-turbo']
})
const second = createOpenRouterOptions({
models: ['claude-3-opus', 'claude-3-sonnet']
})
const first = opts({ openrouter: { models: ['gpt-4', 'gpt-3.5-turbo'] } })
const second = opts({ openrouter: { models: ['claude-3-opus', 'claude-3-sonnet'] } })
const merged = mergeProviderOptions(first, second)
// Array is completely replaced, not merged
expect(merged.openrouter?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
expect((merged.openrouter as Record<string, unknown>)?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
})
it('deeply merges nested objects while overwriting primitives', () => {
const first = createOpenRouterOptions({
reasoning: {
enabled: true,
effort: 'low'
},
user: 'user-123'
const first = opts({
openrouter: {
reasoning: { enabled: true, effort: 'low' },
user: 'user-123'
}
})
const second = createOpenRouterOptions({
reasoning: {
effort: 'high',
max_tokens: 500
},
user: 'user-456'
const second = opts({
openrouter: {
reasoning: { effort: 'high', max_tokens: 500 },
user: 'user-456'
}
})
const merged = mergeProviderOptions(first, second)
expect(merged.openrouter).toEqual({
reasoning: {
enabled: true, // preserved from first
effort: 'high', // overwritten by second
max_tokens: 500 // added from second
},
user: 'user-456' // overwritten by second
reasoning: { enabled: true, effort: 'high', max_tokens: 500 },
user: 'user-456'
})
})
it('replaces arrays instead of merging them', () => {
const first = createOpenRouterOptions({ plugins: [{ id: 'old' }] })
const second = createOpenRouterOptions({ plugins: [{ id: 'new' }] })
const first = opts({ openrouter: { plugins: [{ id: 'old' }] } })
const second = opts({ openrouter: { plugins: [{ id: 'new' }] } })
const merged = mergeProviderOptions(first, second)
// @ts-expect-error type-check for openrouter options is skipped. see function signature of createOpenRouterOptions
expect(merged.openrouter?.plugins).toEqual([{ id: 'new' }])
expect((merged.openrouter as Record<string, unknown>)?.plugins).toEqual([{ id: 'new' }])
})
})

View File

@@ -1,30 +1,4 @@
import type { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
/**
* 创建特定供应商的选项
* @param provider 供应商名称
* @param options 供应商特定的选项
* @returns 格式化的provider options
*/
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
provider: T,
options: ExtractProviderOptions<T>
): Record<T, ExtractProviderOptions<T>> {
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
}
/**
* 创建任意供应商的选项(包括未知供应商)
* @param provider 供应商名称
* @param options 供应商选项
* @returns 格式化的provider options
*/
export function createGenericProviderOptions<T extends string>(
provider: T,
options: Record<string, any>
): Record<T, Record<string, any>> {
return { [provider]: options } as Record<T, Record<string, any>>
}
import type { TypedProviderOptions } from './types'
type PlainObject = Record<string, any>
@@ -86,31 +60,3 @@ export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions
return acc
}, {} as TypedProviderOptions)
}
/**
* 创建OpenAI供应商选项的便捷函数
*/
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
return createProviderOptions('openai', options)
}
/**
* 创建Anthropic供应商选项的便捷函数
*/
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
return createProviderOptions('anthropic', options)
}
/**
* 创建Google供应商选项的便捷函数
*/
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
return createProviderOptions('google', options)
}
/**
* 创建OpenRouter供应商选项的便捷函数
*/
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'> | Record<string, any>) {
return createProviderOptions('openrouter', options)
}

View File

@@ -5,12 +5,11 @@ import { type SharedV3ProviderMetadata } from '@ai-sdk/provider'
import { type XaiProviderOptions } from '@ai-sdk/xai'
import { type OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
export type ProviderOptions<T extends keyof SharedV3ProviderMetadata> = SharedV3ProviderMetadata[T]
/**
* 供应商选项类型如果map中没有说明没有约束
* Known provider options map for type-safe providerOptions construction.
* Providers not listed here accept arbitrary Record<string, any>.
*/
export type ProviderOptionsMap = {
type ProviderOptionsMap = {
openai: OpenAIResponsesProviderOptions
anthropic: AnthropicProviderOptions
google: GoogleGenerativeAIProviderOptions
@@ -18,12 +17,9 @@ export type ProviderOptionsMap = {
xai: XaiProviderOptions
}
// 工具类型用于从ProviderOptionsMap中提取特定供应商的选项类型
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
/**
* 类型安全的ProviderOptions
* 对于已知供应商使用严格类型,对于未知供应商允许任意Record<string, JSONValue>
* Type-safe ProviderOptions.
* Known providers use strict types; unknown providers allow Record<string, any>.
*/
export type TypedProviderOptions = {
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]

View File

@@ -1,257 +0,0 @@
# AI Core 插件系统
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**
## 🎯 设计理念
- **语义清晰**:不同钩子有不同的执行语义
- **类型安全**TypeScript 完整支持
- **性能优化**First 短路、Parallel 并发、Sequential 链式
- **易于扩展**`enforce` 排序 + 功能分类
## 📋 钩子类型
### 🥇 First 钩子 - 首个有效结果
```typescript
// 只执行第一个返回值的插件,用于解析和查找
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
```
### 🔄 Sequential 钩子 - 链式数据转换
```typescript
// 按顺序链式执行,每个插件可以修改数据
transformParams?: (params: any, context: AiRequestContext) => any
transformResult?: (result: any, context: AiRequestContext) => any
```
### ⚡ Parallel 钩子 - 并行副作用
```typescript
// 并发执行,用于日志、监控等副作用
onRequestStart?: (context: AiRequestContext) => void
onRequestEnd?: (context: AiRequestContext, result: any) => void
onError?: (error: Error, context: AiRequestContext) => void
```
### 🌊 Stream 钩子 - 流处理
```typescript
// 直接使用 AI SDK 的 TransformStream
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
```
## 🚀 快速开始
### 基础用法
```typescript
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const pluginManager = new PluginManager()
// 添加插件
pluginManager.use({
name: 'my-plugin',
async transformParams(params, context) {
return { ...params, temperature: 0.7 }
}
})
// 使用插件
const context = createContext('openai', 'gpt-4', { messages: [] })
const transformedParams = await pluginManager.executeSequential(
'transformParams',
{ messages: [{ role: 'user', content: 'Hello' }] },
context
)
```
### 完整示例
```typescript
import {
PluginManager,
ModelAliasPlugin,
LoggingPlugin,
ParamsValidationPlugin,
createContext
} from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const manager = new PluginManager([
ModelAliasPlugin, // 模型别名解析
ParamsValidationPlugin, // 参数验证
LoggingPlugin // 日志记录
])
// AI 请求流程
async function aiRequest(providerId: string, modelId: string, params: any) {
const context = createContext(providerId, modelId, params)
try {
// 1. 【并行】触发请求开始事件
await manager.executeParallel('onRequestStart', context)
// 2. 【首个】解析模型别名
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
context.modelId = resolvedModel || modelId
// 3. 【串行】转换请求参数
const transformedParams = await manager.executeSequential('transformParams', params, context)
// 4. 【流处理】收集流转换器AI SDK 原生支持数组)
const streamTransforms = manager.collectStreamTransforms()
// 5. 调用 AI SDK这里省略具体实现
const result = await callAiSdk(transformedParams, streamTransforms)
// 6. 【串行】转换响应结果
const transformedResult = await manager.executeSequential('transformResult', result, context)
// 7. 【并行】触发请求完成事件
await manager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 8. 【并行】触发错误事件
await manager.executeParallel('onError', context, undefined, error)
throw error
}
}
```
## 🔧 自定义插件
### 模型别名插件
```typescript
const ModelAliasPlugin = definePlugin({
name: 'model-alias',
enforce: 'pre', // 最先执行
async resolveModel(modelId) {
const aliases = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229'
}
return aliases[modelId] || null
}
})
```
### 参数验证插件
```typescript
const ValidationPlugin = definePlugin({
name: 'validation',
async transformParams(params) {
if (!params.messages) {
throw new Error('messages is required')
}
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096
}
}
})
```
### 监控插件
```typescript
const MonitoringPlugin = definePlugin({
name: 'monitoring',
enforce: 'post', // 最后执行
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`请求耗时: ${duration}ms`)
}
})
```
### 内容过滤插件
```typescript
const FilterPlugin = definePlugin({
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
controller.enqueue({ ...chunk, textDelta: filtered })
} else {
controller.enqueue(chunk)
}
}
})
}
})
```
## 📊 执行顺序
### 插件排序
```
enforce: 'pre' → normal → enforce: 'post'
```
### 钩子执行流程
```mermaid
graph TD
A[请求开始] --> B[onRequestStart 并行执行]
B --> C[resolveModel 首个有效]
C --> D[loadTemplate 首个有效]
D --> E[transformParams 串行执行]
E --> F[collectStreamTransforms]
F --> G[AI SDK 调用]
G --> H[transformResult 串行执行]
H --> I[onRequestEnd 并行执行]
G --> J[异常处理]
J --> K[onError 并行执行]
```
## 💡 最佳实践
1. **功能单一**:每个插件专注一个功能
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
3. **错误处理**:插件内部处理异常,不要让异常向上传播
4. **性能优化**使用合适的钩子类型First vs Sequential vs Parallel
5. **命名规范**:使用语义化的插件名称
## 🔍 调试工具
```typescript
// 查看插件统计信息
const stats = manager.getStats()
console.log('插件统计:', stats)
// 查看所有插件
const plugins = manager.getPlugins()
console.log(
'已注册插件:',
plugins.map((p) => p.name)
)
```
## ⚡ 性能优势
- **First 钩子**:一旦找到结果立即停止,避免无效计算
- **Parallel 钩子**:真正并发执行,不阻塞主流程
- **Sequential 钩子**:保证数据转换的顺序性
- **Stream 钩子**:直接集成 AI SDK零开销
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。

View File

@@ -1,51 +0,0 @@
import { google } from '@ai-sdk/google'
import type { ToolSet } from 'ai'
import { type AiPlugin, definePlugin, type StreamTextParams, type StreamTextResult } from '../../'
const toolNameMap = {
googleSearch: 'google_search',
urlContext: 'url_context',
codeExecution: 'code_execution'
} as const
type ToolConfigKey = keyof typeof toolNameMap
type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean }
export const googleToolsPlugin = (config?: ToolConfig): AiPlugin<StreamTextParams, StreamTextResult> =>
definePlugin<StreamTextParams, StreamTextResult>({
name: 'googleToolsPlugin',
transformParams: (params, context): Partial<StreamTextParams> => {
const { providerId } = context
// 只在 Google provider 且有配置时才修改参数
if (providerId !== 'google' || !config) {
return {} // 返回空 Partial表示不修改
}
if (typeof params !== 'object' || params === null) {
return {}
}
// 构建 tools 对象,确保类型兼容
const hasTools = (Object.keys(config) as ToolConfigKey[]).some(
(key) => config[key] && key in toolNameMap && key in google.tools
)
if (!hasTools) {
return {} // 返回空 Partial表示不修改
}
// 构建符合 AI SDK 的 tools 对象
const tools: ToolSet = {}
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
if (config[key] && key in toolNameMap && key in google.tools) {
const toolName = toolNameMap[key]
tools[toolName] = google.tools[key]({}) as ToolSet[string]
}
})
return { tools: { ...params.tools, ...tools } }
}
})

View File

@@ -1,10 +1,4 @@
/**
* 内置插件命名空间
* 所有内置插件都以 'built-in:' 为前缀
*/
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
export * from './googleToolsPlugin'
export * from './providerToolPlugin'
export * from './toolUsePlugin/promptToolUsePlugin'
export * from './toolUsePlugin/type'
export * from './webSearchPlugin'

View File

@@ -1,86 +0,0 @@
/**
* 内置插件:日志记录
* 记录AI调用的关键信息支持性能监控和调试
*/
import { definePlugin } from '../index'
import type { AiRequestContext } from '../types'
export interface LoggingConfig {
// 日志级别
level?: 'debug' | 'info' | 'warn' | 'error'
// 是否记录参数
logParams?: boolean
// 是否记录结果
logResult?: boolean
// 是否记录性能数据
logPerformance?: boolean
// 自定义日志函数
logger?: (level: string, message: string, data?: any) => void
}
/**
* 创建日志插件
*/
export function createLoggingPlugin(config: LoggingConfig = {}) {
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
const startTimes = new Map<string, number>()
return definePlugin({
name: 'built-in:logging',
onRequestStart: (context: AiRequestContext) => {
const requestId = context.requestId
startTimes.set(requestId, Date.now())
logger(level, `🚀 AI Request Started`, {
requestId,
providerId: context.providerId,
modelId: context.modelId,
originalParams: logParams ? context.originalParams : '[hidden]'
})
},
onRequestEnd: (context: AiRequestContext, result: any) => {
const requestId = context.requestId
const startTime = startTimes.get(requestId)
const duration = startTime ? Date.now() - startTime : undefined
startTimes.delete(requestId)
const logData: any = {
requestId,
providerId: context.providerId,
modelId: context.modelId
}
if (logPerformance && duration) {
logData.duration = `${duration}ms`
}
if (logResult) {
logData.result = result
}
logger(level, `✅ AI Request Completed`, logData)
},
onError: (error: Error, context: AiRequestContext) => {
const requestId = context.requestId
const startTime = startTimes.get(requestId)
const duration = startTime ? Date.now() - startTime : undefined
startTimes.delete(requestId)
logger('error', `❌ AI Request Failed`, {
requestId,
providerId: context.providerId,
modelId: context.modelId,
duration: duration ? `${duration}ms` : undefined,
error: {
name: error.name,
message: error.message,
stack: error.stack
}
})
}
})
}

View File

@@ -0,0 +1,40 @@
/**
* 通用 provider 工具注入插件
*
* 查找 extensionRegistry 中声明的 toolFactory
* 将返回的 ToolFactoryPatchtools / providerOptions合并到 params。
*/
import { mergeProviderOptions } from '../../options'
import { extensionRegistry } from '../../providers'
import type { ToolCapability } from '../../providers/types/toolFactory'
import { definePlugin } from '../'
export const providerToolPlugin = (capability: ToolCapability, config: Record<string, any> = {}) =>
definePlugin({
name: capability,
enforce: 'pre',
transformParams: async (params: any, context) => {
const { providerId } = context
const modelProvider =
context.model && typeof context.model !== 'string' && 'provider' in context.model
? context.model.provider
: undefined
const resolved = await extensionRegistry.resolveToolCapability(providerId, capability, modelProvider)
if (!resolved) return params
const userConfig = config[providerId] ?? {}
const patch = resolved.factory(resolved.provider)(userConfig)
if (patch.tools) {
params.tools = { ...params.tools, ...patch.tools }
}
if (patch.providerOptions) {
params.providerOptions = mergeProviderOptions(params.providerOptions, patch.providerOptions)
}
return params
}
})

View File

@@ -1,4 +1,5 @@
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
import { createMockContext, createMockTool } from '@test-utils'
import type {
EmbeddingModelUsage,
ImageModelUsage,
@@ -10,7 +11,6 @@ import type {
import { simulateReadableStream } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockContext, createMockTool } from '../../../../../__tests__'
import { StreamEventManager } from '../StreamEventManager'
import type { StreamController } from '../ToolExecutor'

View File

@@ -1,9 +1,9 @@
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '@test-utils'
import type { TextStreamPart, ToolSet } from 'ai'
import { simulateReadableStream } from 'ai'
import { convertReadableStreamToArray } from 'ai/test'
import { describe, expect, it, vi } from 'vitest'
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__'
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
describe('promptToolUsePlugin', () => {

View File

@@ -348,7 +348,7 @@ export const createPromptToolUsePlugin = (
return new TransformStream()
}
// 从 context 中获取或初始化 usage 累加器
// 初始化 usage 累加器和工具执行状态
if (!context.accumulatedUsage) {
context.accumulatedUsage = {
inputTokens: 0,
@@ -358,17 +358,15 @@ export const createPromptToolUsePlugin = (
cachedInputTokens: 0
}
}
if (context.hasExecutedToolsInCurrentStep === undefined) {
context.hasExecutedToolsInCurrentStep = false
}
// 创建工具执行器、流事件管理器和标签提取器
const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager()
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
// 在context中初始化工具执行状态避免递归调用时状态丢失
if (!context.hasExecutedToolsInCurrentStep) {
context.hasExecutedToolsInCurrentStep = false
}
// 用于hold text-start事件直到确认有非工具标签内容
let pendingTextStart: TextStreamPart<TOOLS> | null = null
let hasStartedText = false

View File

@@ -1,189 +0,0 @@
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { openai } from '@ai-sdk/openai'
import { xai } from '@ai-sdk/xai'
import type { InferToolInput, InferToolOutput } from 'ai'
import { type Tool } from 'ai'
import { createOpenRouterOptions, mergeProviderOptions } from '../../../options'
import type { AiRequestContext } from '../../'
import type { OpenRouterSearchConfig } from './openrouter'
/**
* 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。
*/
export type OpenAISearchConfig = NonNullable<Parameters<typeof openai.tools.webSearch>[0]>
export type OpenAISearchPreviewConfig = NonNullable<Parameters<typeof openai.tools.webSearchPreview>[0]>
export type AnthropicSearchConfig = NonNullable<Parameters<typeof anthropic.tools.webSearch_20250305>[0]>
export type GoogleSearchConfig = NonNullable<Parameters<typeof google.tools.googleSearch>[0]>
export type XAIWebSearchConfig = NonNullable<Parameters<typeof xai.tools.webSearch>[0]>
export type XAIXSearchConfig = NonNullable<Parameters<typeof xai.tools.xSearch>[0]>
type NormalizeTool<T> = T extends Tool<infer INPUT, infer OUTPUT> ? Tool<INPUT, OUTPUT> : Tool<any, any>
type AnthropicWebSearchTool = NormalizeTool<ReturnType<typeof anthropic.tools.webSearch_20250305>>
type OpenAIWebSearchTool = NormalizeTool<ReturnType<typeof openai.tools.webSearch>>
type OpenAIChatWebSearchTool = NormalizeTool<ReturnType<typeof openai.tools.webSearchPreview>>
type GoogleWebSearchTool = NormalizeTool<ReturnType<typeof google.tools.googleSearch>>
type XAIWebSearchTool = NormalizeTool<ReturnType<typeof xai.tools.webSearch>>
type XAIXSearchTool = NormalizeTool<ReturnType<typeof xai.tools.xSearch>>
/**
* 插件初始化时接收的完整配置对象
*
* 其结构与 ProviderOptions 保持一致,方便上游统一管理配置
*/
export interface WebSearchPluginConfig {
openai?: OpenAISearchConfig
'openai-chat'?: OpenAISearchPreviewConfig
anthropic?: AnthropicSearchConfig
xai?: XAIWebSearchConfig
'xai-xsearch'?: XAIXSearchConfig
google?: GoogleSearchConfig
openrouter?: OpenRouterSearchConfig
}
/**
* 插件的默认配置
*/
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
google: {},
openai: {},
'openai-chat': {},
xai: {
enableImageUnderstanding: true
},
'xai-xsearch': {
enableImageUnderstanding: true
},
anthropic: {
maxUses: 5
},
openrouter: {
plugins: [
{
id: 'web',
max_results: 5
}
]
}
}
export type WebSearchToolOutputSchema = {
// Anthropic 工具 - 手动定义
anthropic: InferToolOutput<AnthropicWebSearchTool>
// OpenAI 工具 - 基于实际输出
// TODO: 上游定义不规范,是unknown
// openai: InferToolOutput<ReturnType<typeof openai.tools.webSearch>>
openai: {
status: 'completed' | 'failed'
}
'openai-chat': {
status: 'completed' | 'failed'
}
// Google 工具
// TODO: 上游定义不规范,是unknown
// google: InferToolOutput<ReturnType<typeof google.tools.googleSearch>>
google: {
webSearchQueries?: string[]
groundingChunks?: Array<{
web?: { uri: string; title: string }
}>
}
// xAI 工具
xai: InferToolOutput<XAIWebSearchTool>
'xai-xsearch': InferToolOutput<XAIXSearchTool>
}
export type WebSearchToolInputSchema = {
anthropic: InferToolInput<AnthropicWebSearchTool>
openai: InferToolInput<OpenAIWebSearchTool>
google: InferToolInput<GoogleWebSearchTool>
'openai-chat': InferToolInput<OpenAIChatWebSearchTool>
xai: InferToolInput<XAIWebSearchTool>
'xai-xsearch': InferToolInput<XAIXSearchTool>
}
/**
* Helper function to ensure params.tools object exists
*/
const ensureToolsObject = (params: any) => {
if (!params.tools) params.tools = {}
}
/**
* Helper function to apply tool-based web search configuration
*/
const applyToolBasedSearch = (params: any, toolName: string, toolInstance: any) => {
ensureToolsObject(params)
params.tools[toolName] = toolInstance
}
/**
* Helper function to apply provider options-based web search configuration
*/
const applyProviderOptionsSearch = (params: any, searchOptions: any) => {
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any, context?: AiRequestContext) => {
const providerId = context?.providerId
// Provider-specific configuration map
const providerHandlers: Record<string, () => void> = {
openai: () => {
const cfg = config.openai ?? DEFAULT_WEB_SEARCH_CONFIG.openai
applyToolBasedSearch(params, 'web_search', openai.tools.webSearch(cfg))
},
'openai-chat': () => {
const cfg = (config['openai-chat'] ?? DEFAULT_WEB_SEARCH_CONFIG['openai-chat']) as OpenAISearchPreviewConfig
applyToolBasedSearch(params, 'web_search_preview', openai.tools.webSearchPreview(cfg))
},
anthropic: () => {
const cfg = config.anthropic ?? DEFAULT_WEB_SEARCH_CONFIG.anthropic
applyToolBasedSearch(params, 'web_search', anthropic.tools.webSearch_20250305(cfg))
},
google: () => {
const cfg = (config.google ?? DEFAULT_WEB_SEARCH_CONFIG.google) as GoogleSearchConfig
applyToolBasedSearch(params, 'web_search', google.tools.googleSearch(cfg))
},
xai: () => {
const cfg = config.xai ?? DEFAULT_WEB_SEARCH_CONFIG.xai
applyToolBasedSearch(params, 'web_search', xai.tools.webSearch(cfg))
const xSearchCfg = config['xai-xsearch'] ?? DEFAULT_WEB_SEARCH_CONFIG['xai-xsearch']
applyToolBasedSearch(params, 'x_search', xai.tools.xSearch(xSearchCfg))
},
openrouter: () => {
const cfg = (config.openrouter ?? DEFAULT_WEB_SEARCH_CONFIG.openrouter) as OpenRouterSearchConfig
const searchOptions = createOpenRouterOptions(cfg)
applyProviderOptionsSearch(params, searchOptions)
}
}
// Try provider-specific handler first
const handler = providerId && providerHandlers[providerId]
if (handler) {
handler()
return params
}
// Fallback: apply based on available config keys (prioritized order)
const fallbackOrder: Array<keyof WebSearchPluginConfig> = [
'openai',
'openai-chat',
'anthropic',
'google',
'xai',
'openrouter'
]
for (const key of fallbackOrder) {
if (config[key]) {
providerHandlers[key]()
break
}
}
return params
}

View File

@@ -1,44 +1,40 @@
/**
* Web Search Plugin
* 提供统一的网络搜索能力,支持多个 AI Provider
*/
import type { WebSearchToolConfigMap } from '../../../providers'
import { definePlugin } from '../../'
import type { WebSearchPluginConfig } from './helper'
import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper'
export type OpenRouterSearchConfig = {
plugins?: Array<{
id: 'web'
/**
* Maximum number of search results to include (default: 5)
*/
max_results?: number
/**
* Custom search prompt to guide the search query
*/
search_prompt?: string
}>
/**
* Built-in web search options for models that support native web search
*/
web_search_options?: {
/**
* Maximum number of search results to include
*/
max_results?: number
/**
* Custom search prompt to guide the search query
*/
search_prompt?: string
}
}
/**
* 网络搜索插件
* 插件初始化时接收的完整配置对象
*
* @param config - 在插件初始化时传入的静态配置
* key = provider IDvalue = 该 provider 的搜索配置
*
* - 大部分类型从 coreExtensions 的 toolFactories 声明中自动提取WebSearchToolConfigMap
* - OpenRouter 使用自定义配置(非 SDK .tools 模式),从 openrouter.ts 导入
*/
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
definePlugin({
name: 'webSearch',
enforce: 'pre',
transformParams: async (params: any, context) => {
let { providerId } = context
// For cherryin providers, extract the actual provider from the model's provider string
// Expected format: "cherryin.{actualProvider}" (e.g., "cherryin.gemini")
if (providerId === 'cherryin' || providerId === 'cherryin-chat') {
const provider = params.model?.provider
if (provider && typeof provider === 'string' && provider.includes('.')) {
const extractedProviderId = provider.split('.')[1]
if (extractedProviderId) {
providerId = extractedProviderId
}
}
}
switchWebSearchTool(config, params, { ...context, providerId })
return params
}
})
// 导出类型定义供开发者使用
export * from './helper'
// 默认导出
export default webSearchPlugin
export type WebSearchPluginConfig = WebSearchToolConfigMap & {
openrouter?: OpenRouterSearchConfig
}

View File

@@ -1,26 +0,0 @@
export type OpenRouterSearchConfig = {
plugins?: Array<{
id: 'web'
/**
* Maximum number of search results to include (default: 5)
*/
max_results?: number
/**
* Custom search prompt to guide the search query
*/
search_prompt?: string
}>
/**
* Built-in web search options for models that support native web search
*/
web_search_options?: {
/**
* Maximum number of search results to include
*/
max_results?: number
/**
* Custom search prompt to guide the search query
*/
search_prompt?: string
}
}

View File

@@ -2,12 +2,8 @@
export type {
AiPlugin,
AiRequestContext,
AiRequestMetadata,
GenerateTextParams,
GenerateTextResult,
HookResult,
PluginManagerConfig,
RecursiveCallFn,
StreamTextParams,
StreamTextResult
} from './types'

View File

@@ -23,7 +23,6 @@ export interface AiRequestMetadata {
enableGenerateImage?: boolean
isPromptToolUse?: boolean
isSupportedToolUse?: boolean
isImageGenerationEndpoint?: boolean
// 自定义元数据,使用 JSONValue 确保类型安全
custom?: JSONObject
}
@@ -32,7 +31,7 @@ export interface AiRequestMetadata {
* 递归调用函数类型
* 泛型化以保持类型推导
*/
export type RecursiveCallFn<TParams = unknown, TResult = unknown> = (newParams: Partial<TParams>) => Promise<TResult>
type RecursiveCallFn<TParams = unknown, TResult = unknown> = (newParams: Partial<TParams>) => Promise<TResult>
/**
* AI 请求上下文
@@ -107,20 +106,3 @@ export interface AiPlugin<TParams = unknown, TResult = unknown> {
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
}
/**
* 插件管理器配置
*/
export interface PluginManagerConfig<TParams = unknown, TResult = unknown> {
plugins: AiPlugin<TParams, TResult>[]
context: Partial<AiRequestContext<TParams, TResult>>
}
/**
* 钩子执行结果
* 泛型参数指定返回值类型
*/
export interface HookResult<T = unknown> {
value: T
stop?: boolean
}

View File

@@ -1,143 +0,0 @@
/**
* Hub Provider - 支持路由到多个底层provider
*
* 支持格式: hubId:providerId:modelId
* 例如: aihubmix:anthropic:claude-3.5-sonnet
*/
import type {
EmbeddingModelV3,
ImageModelV3,
LanguageModelV3,
ProviderV3,
RerankingModelV3,
SpeechModelV3,
TranscriptionModelV3
} from '@ai-sdk/provider'
import { customProvider, wrapProvider } from 'ai'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from './RegistryManagement'
import type { AiSdkProvider } from './types'
export interface HubProviderConfig {
/** Hub的唯一标识符 */
hubId: string
/** 是否启用调试日志 */
debug?: boolean
}
export class HubProviderError extends Error {
constructor(
message: string,
public readonly hubId: string,
public readonly providerId?: string,
public readonly originalError?: Error
) {
super(message)
this.name = 'HubProviderError'
}
}
/**
* 解析Hub模型ID
*/
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
const parts = modelId.split(DEFAULT_SEPARATOR)
if (parts.length !== 2) {
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
}
return {
provider: parts[0],
actualModelId: parts[1]
}
}
/**
* 创建Hub Provider
*/
export function createHubProvider(config: HubProviderConfig): AiSdkProvider {
const { hubId } = config
function getTargetProvider(providerId: string): ProviderV3 {
// 从全局注册表获取provider实例
try {
const provider = globalRegistryManagement.getProvider(providerId)
if (!provider) {
throw new HubProviderError(
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
hubId,
providerId
)
}
// 使用 wrapProvider 确保返回的是 V3 provider
// 这样可以自动处理 V2 provider 到 V3 的转换
return wrapProvider({ provider, languageModelMiddleware: [] })
} catch (error) {
throw new HubProviderError(
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
hubId,
providerId,
error instanceof Error ? error : undefined
)
}
}
// 创建符合 ProviderV3 规范的 fallback provider
const hubFallbackProvider = {
specificationVersion: 'v3' as const,
languageModel: (modelId: string): LanguageModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
return targetProvider.languageModel(actualModelId)
},
embeddingModel: (modelId: string): EmbeddingModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
return targetProvider.embeddingModel(actualModelId)
},
imageModel: (modelId: string): ImageModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
return targetProvider.imageModel(actualModelId)
},
transcriptionModel: (modelId: string): TranscriptionModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.transcriptionModel) {
throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider)
}
return targetProvider.transcriptionModel(actualModelId)
},
speechModel: (modelId: string): SpeechModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.speechModel) {
throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider)
}
return targetProvider.speechModel(actualModelId)
},
rerankingModel: (modelId: string): RerankingModelV3 => {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.rerankingModel) {
throw new HubProviderError(`Provider "${provider}" does not support reranking models`, hubId, provider)
}
return targetProvider.rerankingModel(actualModelId)
}
}
return customProvider({
fallbackProvider: hubFallbackProvider
})
}

View File

@@ -1,219 +0,0 @@
/**
* Provider 注册表管理器
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
* 基于 AI SDK 原生的 createProviderRegistry
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
type PROVIDERS = Record<string, ProviderV3>
export const DEFAULT_SEPARATOR = '|'
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
private providers: PROVIDERS = {}
private aliases: Set<string> = new Set() // 记录哪些key是别名
private separator: SEPARATOR
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
this.separator = options.separator
}
/**
* 注册已配置好的 provider 实例
*/
registerProvider(id: string, provider: ProviderV3, aliases?: string[]): this {
// 注册主provider
this.providers[id] = provider
// 注册别名都指向同一个provider实例
if (aliases) {
aliases.forEach((alias) => {
this.providers[alias] = provider // 直接存储引用
this.aliases.add(alias) // 标记为别名
})
}
this.rebuildRegistry()
return this
}
/**
* 获取已注册的provider实例
*/
getProvider(id: string): ProviderV3 | undefined {
return this.providers[id]
}
/**
* 批量注册 providers
*/
registerProviders(providers: Record<string, ProviderV3>): this {
Object.assign(this.providers, providers)
this.rebuildRegistry()
return this
}
/**
* 移除 provider同时清理相关别名
*/
unregisterProvider(id: string): this {
const provider = this.providers[id]
if (!provider) return this
// 如果移除的是真实ID需要清理所有指向它的别名
if (!this.aliases.has(id)) {
// 找到所有指向此provider的别名并删除
const aliasesToRemove: string[] = []
this.aliases.forEach((alias) => {
if (this.providers[alias] === provider) {
aliasesToRemove.push(alias)
}
})
aliasesToRemove.forEach((alias) => {
delete this.providers[alias]
this.aliases.delete(alias)
})
} else {
// 如果移除的是别名,只删除别名记录
this.aliases.delete(id)
}
delete this.providers[id]
this.rebuildRegistry()
return this
}
/**
* 立即重建 registry - 每次变更都重建
*/
private rebuildRegistry(): void {
if (Object.keys(this.providers).length === 0) {
this.registry = null
return
}
this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
separator: this.separator
})
}
/**
* 获取语言模型 - AI SDK 原生方法
*/
languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV3 {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.languageModel(id)
}
/**
* 获取文本嵌入模型 - AI SDK 原生方法
*/
embeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV3 {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.embeddingModel(id)
}
/**
* 获取图像模型 - AI SDK 原生方法
*/
imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV3 {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.imageModel(id)
}
/**
* 获取转录模型 - AI SDK 原生方法
*/
transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.transcriptionModel(id)
}
/**
* 获取语音模型 - AI SDK 原生方法
*/
speechModel(id: `${string}${SEPARATOR}${string}`): any {
if (!this.registry) {
throw new Error('No providers registered')
}
return this.registry.speechModel(id)
}
/**
* 获取已注册的 provider 列表
*/
getRegisteredProviders(): string[] {
return Object.keys(this.providers)
}
/**
* 检查是否有已注册的 providers
*/
hasProviders(): boolean {
return Object.keys(this.providers).length > 0
}
/**
* 清除所有 providers
*/
clear(): this {
this.providers = {}
this.aliases.clear()
this.registry = null
return this
}
/**
* 解析真实的Provider ID供getAiSdkProviderId使用
* 如果传入的是别名返回真实的Provider ID
* 如果传入的是真实ID直接返回
*/
resolveProviderId(id: string): string {
if (!this.aliases.has(id)) return id // 不是别名,直接返回
// 是别名找到真实ID
const targetProvider = this.providers[id]
for (const [realId, provider] of Object.entries(this.providers)) {
if (provider === targetProvider && !this.aliases.has(realId)) {
return realId
}
}
return id
}
/**
* 检查是否为别名
*/
isAlias(id: string): boolean {
return this.aliases.has(id)
}
/**
* 获取所有别名映射关系
*/
getAllAliases(): Record<string, string> {
const result: Record<string, string> = {}
this.aliases.forEach((alias) => {
result[alias] = this.resolveProviderId(alias)
})
return result
}
}
/**
* 全局注册表管理器实例
* 使用 | 作为分隔符,因为 : 会和 :free 等suffix冲突
*/
export const globalRegistryManagement = new RegistryManagement()

View File

@@ -0,0 +1,925 @@
/**
* ExtensionRegistry 单元测试
*/
import type { ProviderV3 } from '@ai-sdk/provider'
import { createMockProviderV3 } from '@test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ExtensionRegistry } from '../core/ExtensionRegistry'
import { ProviderExtension } from '../core/ProviderExtension'
import { ProviderCreationError } from '../core/utils'
describe('ExtensionRegistry', () => {
let registry: ExtensionRegistry
beforeEach(() => {
registry = new ExtensionRegistry()
})
describe('register', () => {
it('should register an extension', () => {
const extension = new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
registry.register(extension)
expect(registry.has('test-provider')).toBe(true)
expect(registry.get('test-provider')).toBe(extension)
})
it('should register aliases', () => {
const extension = new ProviderExtension({
name: 'openrouter',
aliases: ['or', 'open-router'],
create: createMockProviderV3
})
registry.register(extension)
expect(registry.has('openrouter')).toBe(true)
expect(registry.has('or')).toBe(true)
expect(registry.has('open-router')).toBe(true)
// 别名应该指向同一个 extension
expect(registry.get('or')).toBe(extension)
expect(registry.get('open-router')).toBe(extension)
})
it('should be idempotent when name already registered', () => {
const ext1 = new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
const ext2 = new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
registry.register(ext1)
registry.register(ext2) // should not throw
// original extension is preserved
expect(registry.get('test-provider')).toBe(ext1)
})
it('should throw error if alias already registered', () => {
const ext1 = new ProviderExtension({
name: 'provider1',
aliases: ['shared-alias'],
create: createMockProviderV3
})
const ext2 = new ProviderExtension({
name: 'provider2',
aliases: ['shared-alias'],
create: createMockProviderV3
})
registry.register(ext1)
expect(() => registry.register(ext2)).toThrow('already registered')
})
it('should support method chaining', () => {
const ext1 = new ProviderExtension({
name: 'provider1',
create: createMockProviderV3
})
const ext2 = new ProviderExtension({
name: 'provider2',
create: createMockProviderV3
})
const result = registry.register(ext1).register(ext2)
expect(result).toBe(registry)
expect(registry.has('provider1')).toBe(true)
expect(registry.has('provider2')).toBe(true)
})
})
describe('registerAll', () => {
it('should register multiple extensions', () => {
const extensions = [
new ProviderExtension({ name: 'provider1', create: createMockProviderV3 }),
new ProviderExtension({ name: 'provider2', create: createMockProviderV3 }),
new ProviderExtension({ name: 'provider3', create: createMockProviderV3 })
]
registry.registerAll(extensions)
expect(registry.has('provider1')).toBe(true)
expect(registry.has('provider2')).toBe(true)
expect(registry.has('provider3')).toBe(true)
})
it('should support method chaining', () => {
const extensions = [new ProviderExtension({ name: 'test', create: createMockProviderV3 })]
const result = registry.registerAll(extensions)
expect(result).toBe(registry)
})
})
describe('unregister', () => {
it('should remove extension', () => {
const extension = new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
registry.register(extension)
expect(registry.has('test-provider')).toBe(true)
const result = registry.unregister('test-provider')
expect(result).toBe(true)
expect(registry.has('test-provider')).toBe(false)
})
it('should remove aliases', () => {
const extension = new ProviderExtension({
name: 'test-provider',
aliases: ['alias1', 'alias2'],
create: createMockProviderV3
})
registry.register(extension)
registry.unregister('test-provider')
expect(registry.has('alias1')).toBe(false)
expect(registry.has('alias2')).toBe(false)
})
it('should return false if extension not found', () => {
const result = registry.unregister('non-existent')
expect(result).toBe(false)
})
})
describe('get', () => {
it('should get extension by name', () => {
const extension = new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
registry.register(extension)
expect(registry.get('test-provider')).toBe(extension)
})
it('should get extension by alias', () => {
const extension = new ProviderExtension({
name: 'test-provider',
aliases: ['test-alias'],
create: createMockProviderV3
})
registry.register(extension)
expect(registry.get('test-alias')).toBe(extension)
})
it('should return undefined for non-existent ID', () => {
expect(registry.get('non-existent')).toBeUndefined()
})
})
describe('getAll', () => {
it('should return all registered extensions', () => {
const ext1 = new ProviderExtension({ name: 'provider1', create: createMockProviderV3 })
const ext2 = new ProviderExtension({ name: 'provider2', create: createMockProviderV3 })
registry.register(ext1).register(ext2)
const all = registry.getAll()
expect(all).toHaveLength(2)
expect(all).toContain(ext1)
expect(all).toContain(ext2)
})
it('should return empty array when no extensions registered', () => {
expect(registry.getAll()).toEqual([])
})
})
describe('getAllProviderIds', () => {
it('should return all provider IDs including variants', () => {
const ext1 = new ProviderExtension({
name: 'openai',
aliases: ['oai'],
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'Chat',
transform: (provider: ProviderV3) => provider
}
]
})
const ext2 = new ProviderExtension({
name: 'azure',
create: createMockProviderV3
})
registry.register(ext1).register(ext2)
const ids = registry.getAllProviderIds()
expect(ids).toContain('openai')
expect(ids).toContain('oai')
expect(ids).toContain('openai-chat')
expect(ids).toContain('azure')
})
})
describe('clear', () => {
it('should remove all extensions', () => {
registry.register(new ProviderExtension({ name: 'provider1', create: createMockProviderV3 }))
registry.register(new ProviderExtension({ name: 'provider2', create: createMockProviderV3 }))
registry.clear()
expect(registry.getAll()).toEqual([])
expect(registry.getAllProviderIds()).toEqual([])
})
})
describe('createProvider', () => {
it('should create provider using create function', async () => {
const mockProvider = createMockProviderV3()
const extension = new ProviderExtension({
name: 'test-provider',
create: () => mockProvider
})
registry.register(extension)
const provider = await registry.createProvider('test-provider')
expect(provider).toBe(mockProvider)
})
it('should merge default options with user settings', async () => {
let receivedSettings: any
const extension = new ProviderExtension<any>({
name: 'test-provider',
defaultOptions: { apiKey: 'default-key', timeout: 5000 },
create: ((settings: any) => {
receivedSettings = settings
return createMockProviderV3()
}) as any
})
registry.register(extension)
await registry.createProvider('test-provider', { baseURL: 'https://api.test.com' })
expect(receivedSettings).toEqual({
apiKey: 'default-key',
timeout: 5000,
baseURL: 'https://api.test.com'
})
})
it('should create provider using dynamic import', async () => {
const mockProvider = createMockProviderV3()
const extension = new ProviderExtension({
name: 'lazy-provider',
import: async () => ({
createLazyProvider: () => mockProvider
}),
creatorFunctionName: 'createLazyProvider'
})
registry.register(extension)
const provider = await registry.createProvider('lazy-provider')
expect(provider).toBe(mockProvider)
})
it('should throw error if extension not found', async () => {
await expect(registry.createProvider('non-existent')).rejects.toThrow('not found')
})
it('should throw error if creator function not found in imported module', async () => {
const extension = new ProviderExtension({
name: 'bad-import',
import: async () => ({}),
creatorFunctionName: 'nonExistentFunction'
})
registry.register(extension)
try {
await registry.createProvider('bad-import')
expect.fail('Should have thrown')
} catch (error) {
expect(error).toBeInstanceOf(ProviderCreationError)
expect((error as ProviderCreationError).cause.message).toContain('not found in imported module')
}
})
})
describe('Provider Caching', () => {
it('should cache provider instances based on settings', async () => {
const createSpy = vi.fn(createMockProviderV3)
registry.register(
new ProviderExtension({
name: 'test-provider',
create: createSpy
})
)
// First call - should create
const provider1 = await registry.createProvider('test-provider', { apiKey: 'same-key' })
expect(createSpy).toHaveBeenCalledTimes(1)
// Second call with same settings - should use cache
const provider2 = await registry.createProvider('test-provider', { apiKey: 'same-key' })
expect(createSpy).toHaveBeenCalledTimes(1) // Still 1
expect(provider2).toBe(provider1) // Same instance
// Third call with different settings - should create new
const provider3 = await registry.createProvider('test-provider', { apiKey: 'different-key' })
expect(createSpy).toHaveBeenCalledTimes(2)
expect(provider3).not.toBe(provider1)
})
it('should deep merge settings before generating cache key', async () => {
let firstSettings: any
let secondSettings: any
const extension = new ProviderExtension({
name: 'test-provider',
defaultOptions: {
apiKey: 'default-key',
headers: { 'X-Default': 'value' }
},
create: (settings) => {
if (!firstSettings) {
firstSettings = settings
} else {
secondSettings = settings
}
return createMockProviderV3()
}
})
registry.register(extension)
await registry.createProvider('test-provider', { headers: { 'X-Custom': 'custom' } })
await registry.createProvider('test-provider', { headers: { 'X-Custom': 'custom' } })
// Should use cache - only created once
expect(secondSettings).toBeUndefined()
// Verify deep merge happened
expect(firstSettings).toEqual({
apiKey: 'default-key',
headers: {
'X-Default': 'value',
'X-Custom': 'custom'
}
})
})
})
describe('ProviderCreationError', () => {
it('should wrap errors in ProviderCreationError', async () => {
registry.register(
new ProviderExtension({
name: 'test-provider',
create: () => {
throw new Error('Creation failed')
}
})
)
try {
await registry.createProvider('test-provider', { apiKey: 'key' })
expect.fail('Should have thrown')
} catch (error) {
expect(error).toBeInstanceOf(ProviderCreationError)
expect((error as ProviderCreationError).providerId).toBe('test-provider')
expect((error as ProviderCreationError).cause.message).toBe('Creation failed')
}
})
})
describe('resolveProviderIdWithMode', () => {
beforeEach(() => {
// 注册带变体的 extension
registry.register(
new ProviderExtension({
name: 'openai',
aliases: ['oai'],
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'azure',
aliases: ['azure-openai'],
create: createMockProviderV3,
variants: [
{
suffix: 'responses',
name: 'Azure Responses',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'google',
aliases: ['gemini'],
create: createMockProviderV3
// 没有 variants
})
)
})
it('should resolve base ID + mode to variant ID', () => {
expect(registry.resolveProviderIdWithMode('openai', 'chat')).toBe('openai-chat')
expect(registry.resolveProviderIdWithMode('azure', 'responses')).toBe('azure-responses')
})
it('should support aliases in base ID', () => {
expect(registry.resolveProviderIdWithMode('oai', 'chat')).toBe('openai-chat')
expect(registry.resolveProviderIdWithMode('azure-openai', 'responses')).toBe('azure-responses')
})
it('should return null if extension has no matching variant', () => {
expect(registry.resolveProviderIdWithMode('openai', 'responses')).toBeNull()
expect(registry.resolveProviderIdWithMode('azure', 'chat')).toBeNull()
})
it('should return null if extension has no variants at all', () => {
expect(registry.resolveProviderIdWithMode('google', 'chat')).toBeNull()
})
it('should return null if extension not found', () => {
expect(registry.resolveProviderIdWithMode('non-existent', 'chat')).toBeNull()
})
it('should return resolved base ID when mode is not provided', () => {
expect(registry.resolveProviderIdWithMode('openai')).toBe('openai')
expect(registry.resolveProviderIdWithMode('oai')).toBe('openai')
expect(registry.resolveProviderIdWithMode('gemini')).toBe('google')
})
it('should return null when mode is not provided and extension not found', () => {
expect(registry.resolveProviderIdWithMode('non-existent')).toBeNull()
})
})
describe('parseProviderId', () => {
beforeEach(() => {
registry.register(
new ProviderExtension({
name: 'openai',
aliases: ['oai'],
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'azure',
create: createMockProviderV3,
variants: [
{
suffix: 'responses',
name: 'Azure Responses',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'google',
aliases: ['gemini'],
create: createMockProviderV3
})
)
})
it('should parse variant ID to base ID + mode', () => {
expect(registry.parseProviderId('openai-chat')).toEqual({
baseId: 'openai',
mode: 'chat',
isVariant: true
})
expect(registry.parseProviderId('azure-responses')).toEqual({
baseId: 'azure',
mode: 'responses',
isVariant: true
})
})
it('should parse base ID without mode', () => {
expect(registry.parseProviderId('openai')).toEqual({
baseId: 'openai',
isVariant: false
})
expect(registry.parseProviderId('azure')).toEqual({
baseId: 'azure',
isVariant: false
})
expect(registry.parseProviderId('google')).toEqual({
baseId: 'google',
isVariant: false
})
})
it('should resolve aliases to base ID', () => {
expect(registry.parseProviderId('oai')).toEqual({
baseId: 'openai',
isVariant: false
})
expect(registry.parseProviderId('gemini')).toEqual({
baseId: 'google',
isVariant: false
})
})
it('should return null for unknown provider ID', () => {
expect(registry.parseProviderId('non-existent')).toBeNull()
expect(registry.parseProviderId('unknown-variant')).toBeNull()
})
it('should handle multiple variants of same extension', () => {
registry.register(
new ProviderExtension({
name: 'multi-variant',
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'Chat',
transform: (provider: ProviderV3) => provider
},
{
suffix: 'responses',
name: 'Responses',
transform: (provider: ProviderV3) => provider
},
{
suffix: 'completions',
name: 'Completions',
transform: (provider: ProviderV3) => provider
}
]
})
)
expect(registry.parseProviderId('multi-variant-chat')).toEqual({
baseId: 'multi-variant',
mode: 'chat',
isVariant: true
})
expect(registry.parseProviderId('multi-variant-responses')).toEqual({
baseId: 'multi-variant',
mode: 'responses',
isVariant: true
})
expect(registry.parseProviderId('multi-variant-completions')).toEqual({
baseId: 'multi-variant',
mode: 'completions',
isVariant: true
})
})
})
describe('Variant Query Methods', () => {
beforeEach(() => {
// 注册带变体的 extensions
registry.register(
new ProviderExtension({
name: 'openai',
aliases: ['oai'],
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'azure',
aliases: ['azure-openai'],
create: createMockProviderV3,
variants: [
{
suffix: 'responses',
name: 'Azure Responses',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'google',
aliases: ['gemini'],
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'Google Chat',
transform: (provider: ProviderV3) => provider
}
]
})
)
registry.register(
new ProviderExtension({
name: 'xai',
create: createMockProviderV3
// 没有 variants
})
)
})
describe('isVariant', () => {
it('should return true for variant IDs', () => {
expect(registry.isVariant('openai-chat')).toBe(true)
expect(registry.isVariant('azure-responses')).toBe(true)
expect(registry.isVariant('google-chat')).toBe(true)
})
it('should return false for base provider IDs', () => {
expect(registry.isVariant('openai')).toBe(false)
expect(registry.isVariant('azure')).toBe(false)
expect(registry.isVariant('google')).toBe(false)
expect(registry.isVariant('xai')).toBe(false)
})
it('should return false for aliases', () => {
expect(registry.isVariant('oai')).toBe(false)
expect(registry.isVariant('gemini')).toBe(false)
expect(registry.isVariant('azure-openai')).toBe(false)
})
it('should return false for unknown IDs', () => {
expect(registry.isVariant('unknown')).toBe(false)
expect(registry.isVariant('non-existent-variant')).toBe(false)
})
})
describe('getBaseProviderId', () => {
it('should return base ID for variant IDs', () => {
expect(registry.getBaseProviderId('openai-chat')).toBe('openai')
expect(registry.getBaseProviderId('azure-responses')).toBe('azure')
expect(registry.getBaseProviderId('google-chat')).toBe('google')
})
it('should return base ID for base provider IDs (identity)', () => {
expect(registry.getBaseProviderId('openai')).toBe('openai')
expect(registry.getBaseProviderId('azure')).toBe('azure')
expect(registry.getBaseProviderId('google')).toBe('google')
expect(registry.getBaseProviderId('xai')).toBe('xai')
})
it('should return base ID for aliases', () => {
expect(registry.getBaseProviderId('oai')).toBe('openai')
expect(registry.getBaseProviderId('gemini')).toBe('google')
expect(registry.getBaseProviderId('azure-openai')).toBe('azure')
})
it('should return null for unknown IDs', () => {
expect(registry.getBaseProviderId('unknown')).toBeNull()
expect(registry.getBaseProviderId('non-existent')).toBeNull()
})
})
describe('getVariantMode', () => {
it('should return mode/suffix for variant IDs', () => {
expect(registry.getVariantMode('openai-chat')).toBe('chat')
expect(registry.getVariantMode('azure-responses')).toBe('responses')
expect(registry.getVariantMode('google-chat')).toBe('chat')
})
it('should return null for base provider IDs', () => {
expect(registry.getVariantMode('openai')).toBeNull()
expect(registry.getVariantMode('azure')).toBeNull()
expect(registry.getVariantMode('google')).toBeNull()
expect(registry.getVariantMode('xai')).toBeNull()
})
it('should return null for aliases', () => {
expect(registry.getVariantMode('oai')).toBeNull()
expect(registry.getVariantMode('gemini')).toBeNull()
expect(registry.getVariantMode('azure-openai')).toBeNull()
})
it('should return null for unknown IDs', () => {
expect(registry.getVariantMode('unknown')).toBeNull()
expect(registry.getVariantMode('non-existent-variant')).toBeNull()
})
})
describe('getVariants', () => {
it('should return variant IDs for providers with variants', () => {
expect(registry.getVariants('openai')).toEqual(['openai-chat'])
expect(registry.getVariants('azure')).toEqual(['azure-responses'])
expect(registry.getVariants('google')).toEqual(['google-chat'])
})
it('should return empty array for providers without variants', () => {
expect(registry.getVariants('xai')).toEqual([])
})
it('should support aliases in base ID', () => {
expect(registry.getVariants('oai')).toEqual(['openai-chat'])
expect(registry.getVariants('gemini')).toEqual(['google-chat'])
expect(registry.getVariants('azure-openai')).toEqual(['azure-responses'])
})
it('should return empty array for unknown IDs', () => {
expect(registry.getVariants('unknown')).toEqual([])
expect(registry.getVariants('non-existent')).toEqual([])
})
it('should return all variants for providers with multiple variants', () => {
registry.register(
new ProviderExtension({
name: 'multi-variant',
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'Chat',
transform: (provider: ProviderV3) => provider
},
{
suffix: 'responses',
name: 'Responses',
transform: (provider: ProviderV3) => provider
},
{
suffix: 'completions',
name: 'Completions',
transform: (provider: ProviderV3) => provider
}
]
})
)
const variants = registry.getVariants('multi-variant')
expect(variants).toHaveLength(3)
expect(variants).toContain('multi-variant-chat')
expect(variants).toContain('multi-variant-responses')
expect(variants).toContain('multi-variant-completions')
})
})
describe('Integration: All methods working together', () => {
it('should provide consistent information about a variant', () => {
const variantId = 'openai-chat'
// isVariant should confirm it's a variant
expect(registry.isVariant(variantId)).toBe(true)
// getBaseProviderId should extract base ID
expect(registry.getBaseProviderId(variantId)).toBe('openai')
// getVariantMode should extract mode
expect(registry.getVariantMode(variantId)).toBe('chat')
// getVariants should include this variant when querying base ID
const baseId = registry.getBaseProviderId(variantId)!
expect(registry.getVariants(baseId)).toContain(variantId)
})
it('should provide consistent information about a base provider', () => {
const baseId = 'openai'
// isVariant should return false
expect(registry.isVariant(baseId)).toBe(false)
// getBaseProviderId should return itself
expect(registry.getBaseProviderId(baseId)).toBe(baseId)
// getVariantMode should return null
expect(registry.getVariantMode(baseId)).toBeNull()
// getVariants should return its variants
expect(registry.getVariants(baseId)).toEqual(['openai-chat'])
})
it('should provide consistent information about an alias', () => {
const aliasId = 'oai'
// isVariant should return false
expect(registry.isVariant(aliasId)).toBe(false)
// getBaseProviderId should resolve to base ID
expect(registry.getBaseProviderId(aliasId)).toBe('openai')
// getVariantMode should return null
expect(registry.getVariantMode(aliasId)).toBeNull()
// getVariants should work with alias
expect(registry.getVariants(aliasId)).toEqual(['openai-chat'])
})
})
})
describe('getTyped()', () => {
it('should return typed extension for registered providers', () => {
// Register extensions
registry.register(
new ProviderExtension({
name: 'openai',
create: createMockProviderV3
})
)
const ext = registry.getTyped('openai')
expect(ext).toBeDefined()
expect(ext?.config.name).toBe('openai')
})
it('should return undefined for unregistered providers', () => {
const ext = registry.getTyped('unknown' as any)
expect(ext).toBeUndefined()
})
it('should preserve type information (compile-time check)', () => {
registry.register(
new ProviderExtension({
name: 'openai',
create: createMockProviderV3
})
)
// This test primarily validates compile-time type inference
// Runtime behavior is the same as get()
const ext = registry.getTyped('openai')
expect(ext).toBeDefined()
// Type should be inferred as ProviderExtension<OpenAIProviderSettings, any, any>
// but we can't test types at runtime, only compile-time
})
it('should work with aliases', () => {
registry.register(
new ProviderExtension({
name: 'openai',
aliases: ['oai'],
create: createMockProviderV3
})
)
const ext = registry.getTyped('oai' as any)
expect(ext).toBeDefined()
expect(ext?.config.name).toBe('openai')
})
})
})

View File

@@ -1,526 +0,0 @@
/**
* HubProvider Comprehensive Tests
* Tests hub provider routing, model resolution, and error handling
* Covers multi-provider routing with namespaced model IDs
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import { customProvider, wrapProvider } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { createHubProvider, type HubProviderConfig, HubProviderError } from '../HubProvider'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../RegistryManagement'
// Mock dependencies
vi.mock('../RegistryManagement', () => ({
globalRegistryManagement: {
getProvider: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
vi.mock('ai', () => ({
customProvider: vi.fn((config) => config.fallbackProvider),
wrapProvider: vi.fn((config) => config.provider),
jsonSchema: vi.fn((schema) => schema)
}))
describe('HubProvider', () => {
let mockOpenAIProvider: ProviderV3
let mockAnthropicProvider: ProviderV3
let mockLanguageModel: LanguageModelV3
let mockEmbeddingModel: EmbeddingModelV3
let mockImageModel: ImageModelV3
beforeEach(() => {
vi.clearAllMocks()
// Create mock models using global utilities
mockLanguageModel = createMockLanguageModel({
provider: 'test',
modelId: 'test-model'
})
mockEmbeddingModel = createMockEmbeddingModel({
provider: 'test',
modelId: 'test-embedding'
})
mockImageModel = createMockImageModel({
provider: 'test',
modelId: 'test-image'
})
// Create mock providers
mockOpenAIProvider = {
specificationVersion: 'v3',
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
imageModel: vi.fn().mockReturnValue(mockImageModel)
} as ProviderV3
mockAnthropicProvider = {
specificationVersion: 'v3',
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
imageModel: vi.fn().mockReturnValue(mockImageModel)
} as ProviderV3
// Setup default mock implementation
vi.mocked(globalRegistryManagement.getProvider).mockImplementation((id) => {
if (id === 'openai') return mockOpenAIProvider
if (id === 'anthropic') return mockAnthropicProvider
return undefined
})
})
describe('Provider Creation', () => {
it('should create hub provider with basic config', () => {
const config: HubProviderConfig = {
hubId: 'test-hub'
}
const provider = createHubProvider(config)
expect(provider).toBeDefined()
expect(customProvider).toHaveBeenCalled()
})
it('should create provider with debug flag', () => {
const config: HubProviderConfig = {
hubId: 'test-hub',
debug: true
}
const provider = createHubProvider(config)
expect(provider).toBeDefined()
})
it('should return ProviderV3 specification', () => {
const config: HubProviderConfig = {
hubId: 'aihubmix'
}
const provider = createHubProvider(config)
expect(provider).toHaveProperty('specificationVersion', 'v3')
expect(provider).toHaveProperty('languageModel')
expect(provider).toHaveProperty('embeddingModel')
expect(provider).toHaveProperty('imageModel')
})
})
describe('Model ID Parsing', () => {
it('should parse valid hub model ID format', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4`
const result = provider.languageModel(modelId)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
it('should throw error for invalid model ID format', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const invalidId = 'invalid-id-without-separator'
expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError)
})
it('should throw error for model ID with multiple separators', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`
expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError)
})
it('should throw error for empty model ID', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel('')).toThrow(HubProviderError)
})
it('should throw error for model ID with only separator', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError)
})
})
describe('Language Model Resolution', () => {
it('should route to correct provider for language model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(result).toBe(mockLanguageModel)
})
it('should route different providers correctly', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('anthropic')
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3')
})
it('should wrap provider with wrapProvider', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(wrapProvider).toHaveBeenCalledWith({
provider: mockOpenAIProvider,
languageModelMiddleware: []
})
})
it('should throw HubProviderError if provider not initialized', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/)
})
it('should include provider ID in error message', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
try {
provider.languageModel(`missing${DEFAULT_SEPARATOR}model`)
expect.fail('Should have thrown HubProviderError')
} catch (error) {
expect(error).toBeInstanceOf(HubProviderError)
const hubError = error as HubProviderError
expect(hubError.providerId).toBe('missing')
expect(hubError.hubId).toBe('test-hub')
}
})
})
describe('Embedding Model Resolution', () => {
it('should route to correct provider for embedding model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
expect(result).toBe(mockEmbeddingModel)
})
it('should handle different embedding providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`)
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1')
})
it('should throw error for uninitialized embedding provider', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError)
})
})
describe('Image Model Resolution', () => {
it('should route to correct provider for image model', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(result).toBe(mockImageModel)
})
it('should handle different image providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`)
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(mockAnthropicProvider.imageModel).toHaveBeenCalledWith('image-gen')
})
})
describe('Special Model Types', () => {
it('should support transcription models', () => {
const mockTranscriptionModel = {
specificationVersion: 'v3',
doTranscribe: vi.fn()
}
const providerWithTranscription = {
...mockOpenAIProvider,
transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel)
} as ProviderV3
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithTranscription)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`)
expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1')
expect(result).toBe(mockTranscriptionModel)
})
it('should throw error if provider does not support transcription', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError)
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(
/does not support transcription/
)
})
it('should support speech models', () => {
const mockSpeechModel = {
specificationVersion: 'v3',
doGenerate: vi.fn()
}
const providerWithSpeech = {
...mockOpenAIProvider,
speechModel: vi.fn().mockReturnValue(mockSpeechModel)
} as ProviderV3
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithSpeech)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)
expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1')
expect(result).toBe(mockSpeechModel)
})
it('should throw error if provider does not support speech', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError)
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/)
})
it('should support reranking models', () => {
const mockRerankingModel = {
specificationVersion: 'v3',
doRerank: vi.fn()
}
const providerWithReranking = {
...mockOpenAIProvider,
rerankingModel: vi.fn().mockReturnValue(mockRerankingModel)
} as ProviderV3
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithReranking)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`)
expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1')
expect(result).toBe(mockRerankingModel)
})
it('should throw error if provider does not support reranking', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError)
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/)
})
})
describe('Error Handling', () => {
it('should create HubProviderError with all properties', () => {
const originalError = new Error('Original error')
const error = new HubProviderError('Test message', 'test-hub', 'test-provider', originalError)
expect(error.message).toBe('Test message')
expect(error.hubId).toBe('test-hub')
expect(error.providerId).toBe('test-provider')
expect(error.originalError).toBe(originalError)
expect(error.name).toBe('HubProviderError')
})
it('should create HubProviderError without optional parameters', () => {
const error = new HubProviderError('Test message', 'test-hub')
expect(error.message).toBe('Test message')
expect(error.hubId).toBe('test-hub')
expect(error.providerId).toBeUndefined()
expect(error.originalError).toBeUndefined()
})
it('should wrap provider errors in HubProviderError', () => {
const providerError = new Error('Provider failed')
vi.mocked(globalRegistryManagement.getProvider).mockImplementation(() => {
throw providerError
})
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
try {
provider.languageModel(`failing${DEFAULT_SEPARATOR}model`)
expect.fail('Should have thrown HubProviderError')
} catch (error) {
expect(error).toBeInstanceOf(HubProviderError)
const hubError = error as HubProviderError
expect(hubError.originalError).toBe(providerError)
expect(hubError.message).toContain('Failed to get provider')
}
})
it('should handle null provider from registry', () => {
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(null as any)
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
})
})
describe('Multi-Provider Scenarios', () => {
it('should handle sequential calls to different providers', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2)
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1)
})
it('should handle mixed model types from same provider', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
})
it('should cache provider lookups', () => {
const config: HubProviderConfig = { hubId: 'aihubmix' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
// Should call getProvider twice (once per model call)
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(2)
})
})
describe('Provider Wrapping', () => {
it('should wrap all providers with empty middleware', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(wrapProvider).toHaveBeenCalledWith({
provider: mockOpenAIProvider,
languageModelMiddleware: []
})
})
it('should wrap providers for all model types', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
expect(wrapProvider).toHaveBeenCalledTimes(3)
})
})
describe('Type Safety', () => {
it('should return properly typed LanguageModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
expect(result).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', () => {
const config: HubProviderConfig = { hubId: 'test-hub' }
const provider = createHubProvider(config) as ProviderV3
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
expect(result.specificationVersion).toBe('v3')
expect(result).toHaveProperty('doGenerate')
})
})
})

View File

@@ -0,0 +1,805 @@
/**
* ProviderExtension 单元测试
*/
import type { ProviderV3 } from '@ai-sdk/provider'
import { createMockProviderV3 } from '@test-utils'
import { describe, expect, it, vi } from 'vitest'
import { ProviderExtension } from '../core/ProviderExtension'
describe('ProviderExtension', () => {
describe('Static create() Method', () => {
it('should create extension with config object', () => {
const extension = ProviderExtension.create({
name: 'test-provider',
create: createMockProviderV3
})
expect(extension).toBeInstanceOf(ProviderExtension)
expect(extension.config.name).toBe('test-provider')
})
it('should create extension with config function', () => {
const configFn = vi.fn(() => ({
name: 'test-provider',
create: createMockProviderV3,
defaultOptions: { apiKey: 'test-key' }
}))
const extension = ProviderExtension.create(configFn)
expect(configFn).toHaveBeenCalledOnce()
expect(extension).toBeInstanceOf(ProviderExtension)
expect(extension.config.name).toBe('test-provider')
expect(extension.config.defaultOptions).toEqual({ apiKey: 'test-key' })
})
it('should support type inference with generics', () => {
interface TestSettings {
apiKey: string
baseURL?: string
name: string
}
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createMockProviderV3 as any, // Type assertion needed as mock has different signature
defaultOptions: {
apiKey: 'test-key'
}
})
expect(extension.config.name).toBe('test-provider')
})
it('should allow delayed config resolution with function', () => {
let envVariable = 'initial-key'
const extension = ProviderExtension.create(() => ({
name: 'dynamic-provider',
create: createMockProviderV3,
defaultOptions: {
apiKey: envVariable // Captured at creation time
}
}))
expect(extension.config.defaultOptions).toEqual({ apiKey: 'initial-key' })
// Changing variable doesn't affect already created extension
envVariable = 'changed-key'
expect(extension.config.defaultOptions).toEqual({ apiKey: 'initial-key' })
})
it('should validate config from function same as from object', async () => {
expect(() => {
ProviderExtension.create(() => ({
name: '', // Invalid
create: createMockProviderV3
}))
}).toThrow('name is required')
// Note: create/import validation happens at runtime in createProvider(), not in constructor
// Extension can be created without create/import, but createProvider() will throw
const extension = ProviderExtension.create(
() =>
({
name: 'test-provider'
// Missing create
}) as any
)
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
})
})
describe('Constructor Validation', () => {
it('should throw error if name is missing', () => {
expect(() => {
new ProviderExtension({
name: '',
create: createMockProviderV3
})
}).toThrow('name is required')
})
it('should throw error at runtime if neither create nor import is provided', async () => {
// Constructor doesn't validate create/import - validation happens at runtime
const extension = new ProviderExtension({
name: 'test-provider'
} as any)
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
})
it('should throw error at runtime if import is provided without creatorFunctionName', async () => {
// Constructor doesn't validate creatorFunctionName - validation happens at runtime
const extension = new ProviderExtension({
name: 'test-provider',
import: async () => ({})
} as any)
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
})
it('should create extension with valid config', () => {
const extension = new ProviderExtension({
name: 'test-provider',
create: createMockProviderV3
})
expect(extension.config.name).toBe('test-provider')
})
})
describe('Configure Method', () => {
it('should return new instance with merged settings', () => {
const original = new ProviderExtension<any>({
name: 'test-provider',
create: createMockProviderV3 as any,
defaultOptions: { apiKey: 'original-key' }
})
const configured = original.configure({ baseURL: 'https://api.test.com' })
// 原实例不变
expect(original.config.defaultOptions).toEqual({ apiKey: 'original-key' })
// 新实例合并配置
expect(configured.config.defaultOptions).toEqual({
apiKey: 'original-key',
baseURL: 'https://api.test.com'
})
// 是新实例
expect(configured).not.toBe(original)
})
it('should override existing options', () => {
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createMockProviderV3 as any,
defaultOptions: { apiKey: 'old-key', timeout: 5000 }
})
const configured = extension.configure({ apiKey: 'new-key' })
expect(configured.config.defaultOptions).toEqual({
apiKey: 'new-key',
timeout: 5000
})
})
})
describe('getProviderIds', () => {
it('should return only main ID when no aliases or variants', () => {
const extension = new ProviderExtension({
name: 'openai',
create: createMockProviderV3
})
expect(extension.getProviderIds()).toEqual(['openai'])
})
it('should include aliases', () => {
const extension = new ProviderExtension({
name: 'openrouter',
aliases: ['or', 'open-router'],
create: createMockProviderV3
})
expect(extension.getProviderIds()).toEqual(['openrouter', 'or', 'open-router'])
})
it('should include variant IDs', () => {
const extension = new ProviderExtension({
name: 'openai',
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
transform: (provider) => provider
}
]
})
expect(extension.getProviderIds()).toEqual(['openai', 'openai-chat'])
})
it('should include both aliases and variant IDs', () => {
const extension = new ProviderExtension({
name: 'azure',
aliases: ['az'],
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'Azure Chat',
transform: (provider) => provider
},
{
suffix: 'responses',
name: 'Azure Responses',
transform: (provider) => provider
}
]
})
expect(extension.getProviderIds()).toEqual(['azure', 'az', 'azure-chat', 'azure-responses'])
})
})
describe('hasProviderId', () => {
it('should return true for main ID', () => {
const extension = new ProviderExtension({
name: 'openai',
create: createMockProviderV3
})
expect(extension.hasProviderId('openai')).toBe(true)
})
it('should return true for alias', () => {
const extension = new ProviderExtension({
name: 'openrouter',
aliases: ['or'],
create: createMockProviderV3
})
expect(extension.hasProviderId('or')).toBe(true)
})
it('should return true for variant ID', () => {
const extension = new ProviderExtension({
name: 'openai',
create: createMockProviderV3,
variants: [
{
suffix: 'chat',
name: 'Chat',
transform: (provider) => provider
}
]
})
expect(extension.hasProviderId('openai-chat')).toBe(true)
})
it('should return false for non-existent ID', () => {
const extension = new ProviderExtension({
name: 'openai',
create: createMockProviderV3
})
expect(extension.hasProviderId('anthropic')).toBe(false)
})
})
describe('getVariant', () => {
it('should return variant by suffix', () => {
const chatVariant = {
suffix: 'chat',
name: 'Chat Mode',
transform: (provider: ProviderV3) => provider
}
const extension = new ProviderExtension({
name: 'test',
create: createMockProviderV3,
variants: [chatVariant]
})
expect(extension.getVariant('chat')).toEqual(chatVariant)
})
it('should return undefined for non-existent variant', () => {
const extension = new ProviderExtension({
name: 'test',
create: createMockProviderV3,
variants: []
})
expect(extension.getVariant('chat')).toBeUndefined()
})
it('should return undefined when no variants configured', () => {
const extension = new ProviderExtension({
name: 'test',
create: createMockProviderV3
})
expect(extension.getVariant('chat')).toBeUndefined()
})
})
describe('Type Safety', () => {
interface TestSettings {
apiKey: string
baseURL?: string
timeout?: number
}
it('should maintain type safety with generics', () => {
const extension = new ProviderExtension<TestSettings>({
name: 'typed-provider',
create: ((settings: any) => {
// TypeScript should infer settings as TestSettings
expect(settings?.apiKey).toBeDefined()
return createMockProviderV3()
}) as any,
defaultOptions: {
apiKey: 'test-key',
timeout: 5000
}
})
const configured = extension.configure({
baseURL: 'https://api.test.com'
// TypeScript should catch invalid properties here
})
expect(configured.config.defaultOptions?.baseURL).toBe('https://api.test.com')
})
})
describe('Options Getter', () => {
it('should return readonly frozen options', () => {
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createMockProviderV3 as any,
defaultOptions: { apiKey: 'test-key', timeout: 5000 }
})
const options = extension.options
expect(options).toEqual({ apiKey: 'test-key', timeout: 5000 })
expect(Object.isFrozen(options)).toBe(true)
})
})
describe('Deep Merge in Configure', () => {
it('should deep merge nested objects', () => {
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createMockProviderV3 as any,
defaultOptions: {
apiKey: 'key1',
headers: {
'X-Custom-Header': 'value1',
Authorization: 'Bearer token1'
},
retry: {
maxAttempts: 3,
backoff: 1000
}
}
})
const configured = extension.configure({
headers: {
Authorization: 'Bearer new-token'
},
retry: {
maxAttempts: 5
}
})
expect(configured.config.defaultOptions).toEqual({
apiKey: 'key1',
headers: {
'X-Custom-Header': 'value1',
Authorization: 'Bearer new-token'
},
retry: {
maxAttempts: 5,
backoff: 1000
}
})
})
it('should not mutate original extension', () => {
const original = new ProviderExtension<any>({
name: 'test-provider',
create: createMockProviderV3 as any,
defaultOptions: {
nested: { value: 'original' }
}
})
const configured = original.configure({
nested: { value: 'modified' }
})
expect(original.config.defaultOptions).toEqual({
nested: { value: 'original' }
})
expect(configured.config.defaultOptions).toEqual({
nested: { value: 'modified' }
})
})
})
describe('Instance Caching (Phase 1)', () => {
interface TestSettings {
apiKey: string
baseURL?: string
}
it('should cache and reuse instance with same settings', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createFn as any
})
const settings = { apiKey: 'test-key', baseURL: 'https://api.test.com' }
const instance1 = await extension.createProvider(settings)
const instance2 = await extension.createProvider(settings)
// Should return the same instance
expect(instance1).toBe(instance2)
// Should only call create once
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should create new instance with different settings', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createFn as any
})
const settings1 = { apiKey: 'key1' }
const settings2 = { apiKey: 'key2' }
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
// Should return different instances
expect(instance1).not.toBe(instance2)
// Should call create twice
expect(createFn).toHaveBeenCalledTimes(2)
})
it('should handle undefined settings correctly', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createFn as any
})
const instance1 = await extension.createProvider()
const instance2 = await extension.createProvider()
const instance3 = await extension.createProvider(undefined)
// All should be the same instance (undefined settings)
expect(instance1).toBe(instance2)
expect(instance1).toBe(instance3)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should compute stable hash for same settings in different order', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
// Same settings but different property order
const settings1 = { apiKey: 'key', baseURL: 'url', timeout: 5000 }
const settings2 = { timeout: 5000, apiKey: 'key', baseURL: 'url' }
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
// Should recognize as same settings
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should merge with default options before hashing', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any,
defaultOptions: {
apiKey: 'default-key',
timeout: 5000
}
})
const instance1 = await extension.createProvider({ baseURL: 'url' })
const instance2 = await extension.createProvider({ baseURL: 'url', apiKey: 'default-key', timeout: 5000 })
// Should be same after merging with defaults
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should handle nested objects in settings', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
const settings1 = {
apiKey: 'key',
headers: { Authorization: 'Bearer token', 'X-Custom': 'value' }
}
const settings2 = {
apiKey: 'key',
headers: { 'X-Custom': 'value', Authorization: 'Bearer token' }
}
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
// Should recognize as same (order doesn't matter)
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should support variant suffix parameter', async () => {
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createMockProviderV3 as any,
variants: [
{
suffix: 'chat',
name: 'Test Chat',
transform: (provider) => provider
}
]
})
const settings = { apiKey: 'test-key' }
// Should work when providing a valid variant suffix
await expect(extension.createProvider(settings, 'chat')).resolves.toBeDefined()
// Should throw for unknown variant suffix
await expect(extension.createProvider(settings, 'unknown')).rejects.toThrow('variant "unknown" not found')
})
it('should support dynamic import providers', async () => {
const mockModule = {
createProvider: vi.fn(createMockProviderV3)
}
const extension = new ProviderExtension<TestSettings>({
name: 'lazy-provider',
import: async () => mockModule,
creatorFunctionName: 'createProvider'
})
const instance1 = await extension.createProvider({ apiKey: 'key' })
const instance2 = await extension.createProvider({ apiKey: 'key' })
expect(instance1).toBe(instance2)
expect(mockModule.createProvider).toHaveBeenCalledTimes(1)
})
it('should throw error if creatorFunctionName not found in module', async () => {
const mockModule = {
wrongName: vi.fn(createMockProviderV3)
}
const extension = new ProviderExtension<TestSettings>({
name: 'lazy-provider',
import: async () => mockModule,
creatorFunctionName: 'createProvider'
})
await expect(extension.createProvider({ apiKey: 'key' })).rejects.toThrow(
'creatorFunctionName "createProvider" not found'
)
})
it('should deduplicate concurrent requests with same settings', async () => {
const createFn = vi.fn(async () => {
// Simulate async delay
await new Promise((resolve) => setTimeout(resolve, 10))
return createMockProviderV3()
})
const extension = new ProviderExtension<TestSettings>({
name: 'test-provider',
create: createFn as any
})
const settings = { apiKey: 'test-key' }
// Fire multiple concurrent requests
const [instance1, instance2, instance3] = await Promise.all([
extension.createProvider(settings),
extension.createProvider(settings),
extension.createProvider(settings)
])
// All concurrent requests should return the same instance
expect(instance1).toBe(instance2)
expect(instance2).toBe(instance3)
// Creator should only be called once
expect(createFn).toHaveBeenCalledTimes(1)
// Verify subsequent sequential calls also use cache
const instance4 = await extension.createProvider(settings)
expect(instance4).toBe(instance1)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should handle arrays in settings', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
const settings1 = { apiKey: 'key', tags: ['a', 'b', 'c'] }
const settings2 = { apiKey: 'key', tags: ['a', 'b', 'c'] }
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should differentiate settings with different array values', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
const settings1 = { apiKey: 'key', tags: ['a', 'b'] }
const settings2 = { apiKey: 'key', tags: ['a', 'c'] }
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
expect(instance1).not.toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(2)
})
})
describe('Cache Key Correctness (no hash collisions)', () => {
it('should differentiate settings with different API keys', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
const instance1 = await extension.createProvider({ apiKey: 'sk-key-A', baseURL: 'https://api.example.com' })
const instance2 = await extension.createProvider({ apiKey: 'sk-key-B', baseURL: 'https://api.example.com' })
expect(instance1).not.toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(2)
})
it('should differentiate settings with different base URLs', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
const instance1 = await extension.createProvider({ apiKey: 'same-key', baseURL: 'https://api-a.example.com' })
const instance2 = await extension.createProvider({ apiKey: 'same-key', baseURL: 'https://api-b.example.com' })
expect(instance1).not.toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(2)
})
it('should treat structurally identical settings as the same regardless of construction', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
// Construct settings in completely different ways
const base = { apiKey: 'key', baseURL: 'url' }
const settings1 = { ...base, headers: { Authorization: 'Bearer tok' } }
const settings2 = JSON.parse(JSON.stringify(settings1))
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should differentiate same-base-provider with different variant suffixes', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'azure',
create: createFn as any,
variants: [
{ suffix: 'chat', name: 'Chat', transform: (p: any) => ({ ...p, _variant: 'chat' }) },
{ suffix: 'responses', name: 'Responses', transform: (p: any) => ({ ...p, _variant: 'responses' }) }
]
})
const settings = { apiKey: 'key' }
const chatInstance = await extension.createProvider(settings, 'chat')
const responsesInstance = await extension.createProvider(settings, 'responses')
const baseInstance = await extension.createProvider(settings)
// Variant instances should be different from each other
expect(chatInstance).not.toBe(responsesInstance)
// Base provider is cached when first variant is created, so baseInstance
// is the same object as the unwrapped base (reused across variants)
expect(chatInstance).not.toBe(baseInstance)
expect(responsesInstance).not.toBe(baseInstance)
// createFn called once for 'chat' variant (also caches base), once for 'responses' (reuses cached base)
// base provider request reuses the cached instance from the first variant creation
expect(createFn).toHaveBeenCalledTimes(2)
})
it('should handle settings with functions by treating them uniformly', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
const fetchFn = () => Promise.resolve(new Response())
const settings1 = { apiKey: 'key', fetch: fetchFn }
const settings2 = { apiKey: 'key', fetch: fetchFn }
const instance1 = await extension.createProvider(settings1)
const instance2 = await extension.createProvider(settings2)
// Same function reference → same serialization → same cache hit
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
})
it('should distinguish settings with null vs missing keys', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
// null/undefined serialize identically via stableStringify, so same cache key
const instance1 = await extension.createProvider({ apiKey: 'key', extra: null })
const instance2 = await extension.createProvider({ apiKey: 'key', extra: null })
expect(instance1).toBe(instance2)
expect(createFn).toHaveBeenCalledTimes(1)
// But a truly different value should create a new instance
const instance3 = await extension.createProvider({ apiKey: 'key', extra: 'value' })
expect(instance3).not.toBe(instance1)
expect(createFn).toHaveBeenCalledTimes(2)
})
it('should not collide on similarly-structured but different settings', async () => {
const createFn = vi.fn(createMockProviderV3)
const extension = new ProviderExtension<any>({
name: 'test-provider',
create: createFn as any
})
// These have similar structure but different values - old DJB2 hash could collide
const settingsA = { apiKey: 'aaaa1111', baseURL: 'https://host-a.com', timeout: 3000 }
const settingsB = { apiKey: 'bbbb2222', baseURL: 'https://host-b.com', timeout: 3000 }
const settingsC = { apiKey: 'cccc3333', baseURL: 'https://host-c.com', timeout: 3000 }
const instanceA = await extension.createProvider(settingsA)
const instanceB = await extension.createProvider(settingsB)
const instanceC = await extension.createProvider(settingsC)
expect(instanceA).not.toBe(instanceB)
expect(instanceB).not.toBe(instanceC)
expect(instanceA).not.toBe(instanceC)
expect(createFn).toHaveBeenCalledTimes(3)
})
})
})

View File

@@ -1,562 +0,0 @@
/**
* RegistryManagement Comprehensive Tests
* Tests provider registry management, model resolution, and alias handling
* Covers registration, retrieval, and cleanup operations
*/
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import { createProviderRegistry } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
import { DEFAULT_SEPARATOR, RegistryManagement } from '../RegistryManagement'
// Mock AI SDK
vi.mock('ai', () => ({
createProviderRegistry: vi.fn(),
jsonSchema: vi.fn((schema) => schema)
}))
describe('RegistryManagement', () => {
let registry: RegistryManagement
let mockProvider: ProviderV3
let mockLanguageModel: LanguageModelV3
let mockEmbeddingModel: EmbeddingModelV3
let mockImageModel: ImageModelV3
beforeEach(() => {
vi.clearAllMocks()
// Create mock models using global utilities
mockLanguageModel = createMockLanguageModel({
provider: 'test',
modelId: 'test-model'
})
mockEmbeddingModel = createMockEmbeddingModel({
provider: 'test',
modelId: 'test-embedding'
})
mockImageModel = createMockImageModel({
provider: 'test',
modelId: 'test-image'
})
// Create mock provider
mockProvider = {
specificationVersion: 'v3',
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
imageModel: vi.fn().mockReturnValue(mockImageModel),
transcriptionModel: vi.fn(),
speechModel: vi.fn()
} as ProviderV3
// Setup mock registry
const mockRegistry = {
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
imageModel: vi.fn().mockReturnValue(mockImageModel),
transcriptionModel: vi.fn(),
speechModel: vi.fn()
}
vi.mocked(createProviderRegistry).mockReturnValue(mockRegistry as any)
registry = new RegistryManagement()
})
describe('Constructor and Initialization', () => {
it('should create registry with default separator', () => {
const reg = new RegistryManagement()
expect(reg).toBeInstanceOf(RegistryManagement)
expect(reg.hasProviders()).toBe(false)
})
it('should create registry with custom separator', () => {
const customSeparator = ':'
const reg = new RegistryManagement({ separator: customSeparator })
expect(reg).toBeInstanceOf(RegistryManagement)
})
it('should start with empty provider list', () => {
expect(registry.getRegisteredProviders()).toEqual([])
})
})
describe('Provider Registration', () => {
it('should register a provider', () => {
registry.registerProvider('openai', mockProvider)
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.hasProviders()).toBe(true)
})
it('should register multiple providers', () => {
const provider2 = { ...mockProvider }
registry.registerProvider('openai', mockProvider)
registry.registerProvider('anthropic', provider2)
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.getProvider('anthropic')).toBe(provider2)
})
it('should return this for chaining', () => {
const result = registry.registerProvider('openai', mockProvider)
expect(result).toBe(registry)
})
it('should rebuild registry after registration', () => {
registry.registerProvider('openai', mockProvider)
expect(createProviderRegistry).toHaveBeenCalledWith(
expect.objectContaining({
openai: mockProvider
}),
{ separator: DEFAULT_SEPARATOR }
)
})
it('should register provider with aliases', () => {
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.getProvider('gpt')).toBe(mockProvider)
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
})
it('should track aliases separately', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
expect(registry.isAlias('gpt')).toBe(true)
expect(registry.isAlias('openai')).toBe(false)
})
it('should handle multiple aliases for same provider', () => {
const aliases = ['alias1', 'alias2', 'alias3']
registry.registerProvider('provider', mockProvider, aliases)
aliases.forEach((alias) => {
expect(registry.getProvider(alias)).toBe(mockProvider)
expect(registry.isAlias(alias)).toBe(true)
})
})
})
describe('Bulk Registration', () => {
it('should register multiple providers at once', () => {
const providers = {
openai: mockProvider,
anthropic: { ...mockProvider },
google: { ...mockProvider }
}
registry.registerProviders(providers)
expect(registry.getProvider('openai')).toBe(providers.openai)
expect(registry.getProvider('anthropic')).toBe(providers.anthropic)
expect(registry.getProvider('google')).toBe(providers.google)
})
it('should return this for chaining', () => {
const result = registry.registerProviders({ openai: mockProvider })
expect(result).toBe(registry)
})
})
describe('Provider Retrieval', () => {
beforeEach(() => {
registry.registerProvider('openai', mockProvider)
})
it('should retrieve registered provider', () => {
const provider = registry.getProvider('openai')
expect(provider).toBe(mockProvider)
})
it('should return undefined for unregistered provider', () => {
const provider = registry.getProvider('nonexistent')
expect(provider).toBeUndefined()
})
it('should retrieve provider by alias', () => {
registry.registerProvider('anthropic', mockProvider, ['claude'])
const provider = registry.getProvider('claude')
expect(provider).toBe(mockProvider)
})
it('should get list of all registered providers', () => {
registry.registerProvider('anthropic', mockProvider)
registry.registerProvider('google', mockProvider, ['gemini'])
const providers = registry.getRegisteredProviders()
expect(providers).toContain('openai')
expect(providers).toContain('anthropic')
expect(providers).toContain('google')
expect(providers).toContain('gemini') // Aliases included
})
})
describe('Provider Unregistration', () => {
it('should unregister provider', () => {
registry.registerProvider('openai', mockProvider)
registry.unregisterProvider('openai')
expect(registry.getProvider('openai')).toBeUndefined()
})
it('should unregister provider with all its aliases', () => {
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
registry.unregisterProvider('openai')
expect(registry.getProvider('openai')).toBeUndefined()
expect(registry.getProvider('gpt')).toBeUndefined()
expect(registry.getProvider('chatgpt')).toBeUndefined()
})
it('should unregister only alias when alias is removed', () => {
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
registry.unregisterProvider('gpt')
expect(registry.getProvider('openai')).toBe(mockProvider)
expect(registry.getProvider('gpt')).toBeUndefined()
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
})
it('should handle unregistering non-existent provider', () => {
expect(() => registry.unregisterProvider('nonexistent')).not.toThrow()
})
it('should return this for chaining', () => {
registry.registerProvider('openai', mockProvider)
const result = registry.unregisterProvider('openai')
expect(result).toBe(registry)
})
it('should rebuild registry after unregistration', () => {
registry.registerProvider('openai', mockProvider)
vi.clearAllMocks()
registry.unregisterProvider('openai')
// Should rebuild with empty providers
expect(createProviderRegistry).not.toHaveBeenCalled() // No rebuild when empty
})
})
describe('Model Resolution', () => {
beforeEach(() => {
registry.registerProvider('openai', mockProvider)
})
it('should resolve language model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` as any
const result = registry.languageModel(modelId)
expect(result).toBe(mockLanguageModel)
})
it('should resolve embedding model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}text-embedding-3-small` as any
const result = registry.embeddingModel(modelId)
expect(result).toBe(mockEmbeddingModel)
})
it('should resolve image model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}dall-e-3` as any
const result = registry.imageModel(modelId)
expect(result).toBe(mockImageModel)
})
it('should resolve transcription model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}whisper-1` as any
registry.transcriptionModel(modelId)
// Verify it calls through to the mock registry
expect(createProviderRegistry).toHaveBeenCalled()
})
it('should resolve speech model', () => {
const modelId = `openai${DEFAULT_SEPARATOR}tts-1` as any
registry.speechModel(modelId)
expect(createProviderRegistry).toHaveBeenCalled()
})
it('should throw error when no providers registered', () => {
const emptyRegistry = new RegistryManagement()
expect(() => emptyRegistry.languageModel('openai|gpt-4' as any)).toThrow('No providers registered')
})
})
describe('Alias Management', () => {
it('should resolve provider ID from alias', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
const realId = registry.resolveProviderId('gpt')
expect(realId).toBe('openai')
})
it('should return same ID if not an alias', () => {
registry.registerProvider('openai', mockProvider)
const realId = registry.resolveProviderId('openai')
expect(realId).toBe('openai')
})
it('should check if ID is alias', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
expect(registry.isAlias('gpt')).toBe(true)
expect(registry.isAlias('openai')).toBe(false)
})
it('should get all alias mappings', () => {
const provider2 = { ...mockProvider }
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
registry.registerProvider('anthropic', provider2, ['claude'])
const aliases = registry.getAllAliases()
// Check that all aliases are present
expect(aliases['gpt']).toBe('openai')
expect(aliases['chatgpt']).toBe('openai')
expect(aliases['claude']).toBe('anthropic')
})
it('should return empty object when no aliases', () => {
registry.registerProvider('openai', mockProvider)
const aliases = registry.getAllAliases()
expect(aliases).toEqual({})
})
})
describe('Registry State', () => {
it('should check if has providers', () => {
expect(registry.hasProviders()).toBe(false)
registry.registerProvider('openai', mockProvider)
expect(registry.hasProviders()).toBe(true)
})
it('should clear all providers', () => {
registry.registerProvider('openai', mockProvider, ['gpt'])
registry.registerProvider('anthropic', mockProvider)
registry.clear()
expect(registry.hasProviders()).toBe(false)
expect(registry.getRegisteredProviders()).toEqual([])
expect(registry.getAllAliases()).toEqual({})
})
it('should return this after clear for chaining', () => {
const result = registry.clear()
expect(result).toBe(registry)
})
})
describe('Registry Rebuilding', () => {
it('should rebuild registry when provider added', () => {
registry.registerProvider('openai', mockProvider)
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
})
it('should rebuild registry when provider removed', () => {
registry.registerProvider('openai', mockProvider)
registry.registerProvider('anthropic', mockProvider)
vi.clearAllMocks()
registry.unregisterProvider('openai')
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
})
it('should set registry to null when all providers removed', () => {
registry.registerProvider('openai', mockProvider)
registry.unregisterProvider('openai')
expect(() => registry.languageModel('any|model' as any)).toThrow('No providers registered')
})
it('should rebuild with correct separator', () => {
const customRegistry = new RegistryManagement({ separator: ':' })
customRegistry.registerProvider('openai', mockProvider)
expect(createProviderRegistry).toHaveBeenCalledWith(expect.any(Object), { separator: ':' })
})
})
describe('Global Registry Instance', () => {
it('should have a global instance with default separator', async () => {
const module = await import('../RegistryManagement')
expect(module.globalRegistryManagement).toBeInstanceOf(RegistryManagement)
})
it('should have DEFAULT_SEPARATOR exported', () => {
expect(DEFAULT_SEPARATOR).toBe('|')
})
})
describe('Edge Cases', () => {
it('should handle registering same provider twice', () => {
registry.registerProvider('openai', mockProvider)
const provider2 = { ...mockProvider }
registry.registerProvider('openai', provider2)
expect(registry.getProvider('openai')).toBe(provider2)
})
it('should handle alias conflicts (first wins)', () => {
registry.registerProvider('provider1', mockProvider, ['shared-alias'])
registry.registerProvider('provider2', mockProvider, ['shared-alias'])
// First registered alias wins (the implementation doesn't override)
expect(registry.resolveProviderId('shared-alias')).toBe('provider1')
})
it('should handle empty alias array', () => {
registry.registerProvider('openai', mockProvider, [])
expect(registry.getAllAliases()).toEqual({})
})
it('should handle null registry operations gracefully', () => {
const emptyRegistry = new RegistryManagement()
expect(() => emptyRegistry.languageModel('test|model' as any)).toThrow('No providers registered')
expect(() => emptyRegistry.embeddingModel('test|embed' as any)).toThrow('No providers registered')
expect(() => emptyRegistry.imageModel('test|image' as any)).toThrow('No providers registered')
})
it('should handle special characters in provider IDs', () => {
const specialIds = ['provider-1', 'provider_2', 'provider.3', 'provider:4']
specialIds.forEach((id) => {
registry.registerProvider(id, mockProvider)
expect(registry.getProvider(id)).toBe(mockProvider)
})
})
})
describe('Concurrent Operations', () => {
it('should handle concurrent registrations', () => {
const promises = [
Promise.resolve(registry.registerProvider('provider1', mockProvider)),
Promise.resolve(registry.registerProvider('provider2', mockProvider)),
Promise.resolve(registry.registerProvider('provider3', mockProvider))
]
return Promise.all(promises).then(() => {
expect(registry.getRegisteredProviders()).toHaveLength(3)
})
})
it('should handle mixed operations', () => {
registry.registerProvider('openai', mockProvider)
registry.registerProvider('anthropic', mockProvider)
const provider1 = registry.getProvider('openai')
registry.unregisterProvider('anthropic')
const provider2 = registry.getProvider('openai')
expect(provider1).toBe(provider2)
})
})
describe('Type Safety', () => {
it('should enforce model ID format with template literal types', () => {
registry.registerProvider('openai', mockProvider)
// These should be type-safe
const validId = 'openai|gpt-4' as `${string}${typeof DEFAULT_SEPARATOR}${string}`
expect(() => registry.languageModel(validId)).not.toThrow()
})
it('should return properly typed LanguageModelV3', () => {
registry.registerProvider('openai', mockProvider)
const model = registry.languageModel('openai|gpt-4' as any)
expect(model.specificationVersion).toBe('v3')
expect(model).toHaveProperty('doGenerate')
expect(model).toHaveProperty('doStream')
})
it('should return properly typed EmbeddingModelV3', () => {
registry.registerProvider('openai', mockProvider)
const model = registry.embeddingModel('openai|ada-002' as any)
expect(model.specificationVersion).toBe('v3')
expect(model).toHaveProperty('doEmbed')
})
it('should return properly typed ImageModelV3', () => {
registry.registerProvider('openai', mockProvider)
const model = registry.imageModel('openai|dall-e-3' as any)
expect(model.specificationVersion).toBe('v3')
expect(model).toHaveProperty('doGenerate')
})
})
describe('Memory Management', () => {
it('should properly clean up on clear', () => {
registry.registerProvider('p1', mockProvider, ['a1'])
registry.registerProvider('p2', mockProvider, ['a2'])
registry.clear()
expect(registry.getRegisteredProviders()).toHaveLength(0)
expect(Object.keys(registry.getAllAliases())).toHaveLength(0)
})
it('should properly clean up on unregister', () => {
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
registry.unregisterProvider('openai')
expect(registry.getProvider('openai')).toBeUndefined()
expect(registry.isAlias('gpt')).toBe(false)
expect(registry.isAlias('chatgpt')).toBe(false)
})
})
})

View File

@@ -1,662 +0,0 @@
/**
* 测试真正的 AiProviderRegistry 功能
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'
// 模拟 AI SDK
vi.mock('@ai-sdk/openai', () => ({
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
}))
vi.mock('@ai-sdk/anthropic', () => ({
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
}))
vi.mock('@ai-sdk/azure', () => ({
createAzure: vi.fn(() => ({
name: 'azure-mock',
languageModel: vi.fn((modelId: string) => ({ mode: 'default', modelId })),
chat: vi.fn((modelId: string) => ({ mode: 'chat', modelId })),
responses: vi.fn((modelId: string) => ({ mode: 'responses', modelId })),
embeddingModel: vi.fn(),
imageModel: vi.fn(),
transcriptionModel: vi.fn(),
speechModel: vi.fn()
}))
}))
vi.mock('@ai-sdk/deepseek', () => ({
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
}))
vi.mock('@ai-sdk/google', () => ({
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
}))
vi.mock('@ai-sdk/openai-compatible', () => ({
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
}))
vi.mock('@ai-sdk/xai', () => ({
createXai: vi.fn(() => ({ name: 'xai-mock' }))
}))
import {
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
hasInitializedProviders,
hasProviderConfig,
hasProviderConfigByAlias,
isProviderConfigAlias,
ProviderInitializationError,
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
registerProviderConfig,
resolveProviderConfigId
} from '../registry'
import type { ProviderConfig } from '../schemas'
describe('Provider Registry 功能测试', () => {
beforeEach(() => {
// 清理状态
cleanup()
})
describe('基础功能', () => {
it('能够获取支持的 providers 列表', () => {
const providers = getSupportedProviders()
expect(Array.isArray(providers)).toBe(true)
expect(providers.length).toBeGreaterThan(0)
// 检查返回的数据结构
providers.forEach((provider) => {
expect(provider).toHaveProperty('id')
expect(provider).toHaveProperty('name')
expect(typeof provider.id).toBe('string')
expect(typeof provider.name).toBe('string')
})
// 包含基础 providers
const providerIds = providers.map((p) => p.id)
expect(providerIds).toContain('openai')
expect(providerIds).toContain('anthropic')
expect(providerIds).toContain('google')
})
it('能够获取已初始化的 providers', () => {
// 初始状态下没有已初始化的 providers
expect(getInitializedProviders()).toEqual([])
expect(hasInitializedProviders()).toBe(false)
})
it('能够访问全局注册管理器', () => {
expect(providerRegistry).toBeDefined()
expect(typeof providerRegistry.clear).toBe('function')
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
expect(typeof providerRegistry.hasProviders).toBe('function')
})
it('能够获取语言模型', () => {
// 在没有注册 provider 的情况下,这个函数应该会抛出错误
expect(() => getLanguageModel('non-existent')).toThrow('No providers registered')
})
})
describe('Provider 配置注册', () => {
it('能够注册自定义 provider 配置', () => {
const config: ProviderConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(() => ({ name: 'custom' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfig('custom-provider')).toBe(true)
expect(getProviderConfig('custom-provider')).toEqual(config)
})
it('能够注册带别名的 provider 配置', () => {
const config: ProviderConfig = {
id: 'custom-provider-with-aliases',
name: 'Custom Provider with Aliases',
creator: vi.fn(() => ({ name: 'custom-aliased' })),
supportsImageGeneration: false,
aliases: ['alias-1', 'alias-2']
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
})
it('拒绝无效的配置', () => {
// 缺少必要字段
const invalidConfig = {
id: 'invalid-provider'
// 缺少 name, creator 等
}
const success = registerProviderConfig(invalidConfig as any)
expect(success).toBe(false)
})
it('能够批量注册 provider 配置', () => {
const configs: ProviderConfig[] = [
{
id: 'provider-1',
name: 'Provider 1',
creator: vi.fn(() => ({ name: 'provider-1' })),
supportsImageGeneration: false
},
{
id: 'provider-2',
name: 'Provider 2',
creator: vi.fn(() => ({ name: 'provider-2' })),
supportsImageGeneration: true
},
{
id: '', // 无效配置
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
} as any
]
const successCount = registerMultipleProviderConfigs(configs)
expect(successCount).toBe(2) // 只有前两个成功
expect(hasProviderConfig('provider-1')).toBe(true)
expect(hasProviderConfig('provider-2')).toBe(true)
expect(hasProviderConfig('')).toBe(false)
})
it('能够获取所有配置和别名信息', () => {
// 注册一些配置
registerProviderConfig({
id: 'test-provider',
name: 'Test Provider',
creator: vi.fn(),
supportsImageGeneration: false,
aliases: ['test-alias']
})
const allConfigs = getAllProviderConfigs()
expect(Array.isArray(allConfigs)).toBe(true)
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
const aliases = getAllProviderConfigAliases()
expect(aliases['test-alias']).toBe('test-provider')
expect(isProviderConfigAlias('test-alias')).toBe(true)
})
})
describe('Provider 创建和注册', () => {
it('能够创建 provider 实例', async () => {
const config: ProviderConfig = {
id: 'test-create-provider',
name: 'Test Create Provider',
creator: vi.fn(() => ({ name: 'test-created' })),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 创建 provider 实例
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
expect(provider).toBeDefined()
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
})
it('creates the Azure provider with chat as the default language model', async () => {
const provider = await createProvider('azure', { apiKey: 'test' })
expect(provider.languageModel('gpt-5.4')).toEqual({ mode: 'chat', modelId: 'gpt-5.4' })
})
it('creates the Azure responses provider with responses as the default language model', async () => {
const provider = await createProvider('azure-responses', { apiKey: 'test' })
expect(provider.languageModel('gpt-5.4')).toEqual({ mode: 'responses', modelId: 'gpt-5.4' })
})
it('能够注册 provider 到全局管理器', () => {
const mockProvider = { name: 'mock-provider' }
const config: ProviderConfig = {
id: 'test-register-provider',
name: 'Test Register Provider',
creator: vi.fn(() => mockProvider),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 注册 provider 到全局管理器
const success = registerProvider('test-register-provider', mockProvider)
expect(success).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('test-register-provider')
expect(hasInitializedProviders()).toBe(true)
})
it('registers Azure chat lookups to the chat language model', async () => {
const success = await createAndRegisterProvider('azure', { apiKey: 'test' })
expect(success).toBe(true)
expect(getInitializedProviders()).toEqual(expect.arrayContaining(['azure', 'azure-chat']))
expect(getLanguageModel('azure|gpt-5.4')).toEqual({ mode: 'chat', modelId: 'gpt-5.4' })
expect(getLanguageModel('azure-chat|gpt-5.4')).toEqual({ mode: 'chat', modelId: 'gpt-5.4' })
})
it('能够一步完成创建和注册', async () => {
const config: ProviderConfig = {
id: 'test-create-and-register',
name: 'Test Create and Register',
creator: vi.fn(() => ({ name: 'test-both' })),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 一步完成创建和注册
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
expect(success).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('test-create-and-register')
})
})
describe('Registry 管理', () => {
it('能够清理所有配置和注册的 providers', () => {
// 注册一些配置
registerProviderConfig({
id: 'temp-provider',
name: 'Temp Provider',
creator: vi.fn(() => ({ name: 'temp' })),
supportsImageGeneration: false
})
expect(hasProviderConfig('temp-provider')).toBe(true)
// 清理
cleanup()
expect(hasProviderConfig('temp-provider')).toBe(false)
// 但基础配置应该重新加载
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
})
it('能够单独清理已注册的 providers', () => {
// 清理所有 providers
clearAllProviders()
expect(getInitializedProviders()).toEqual([])
expect(hasInitializedProviders()).toBe(false)
})
it('ProviderInitializationError 错误类工作正常', () => {
const error = new ProviderInitializationError('Test error', 'test-provider')
expect(error.message).toBe('Test error')
expect(error.providerId).toBe('test-provider')
expect(error.name).toBe('ProviderInitializationError')
})
})
describe('错误处理', () => {
it('优雅处理空配置', () => {
const success = registerProviderConfig(null as any)
expect(success).toBe(false)
})
it('优雅处理未定义配置', () => {
const success = registerProviderConfig(undefined as any)
expect(success).toBe(false)
})
it('处理空字符串 ID', () => {
const config = {
id: '',
name: 'Empty ID Provider',
creator: vi.fn(() => ({ name: 'empty' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(false)
})
it('处理创建不存在配置的 provider', async () => {
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
'ProviderConfig not found for id: non-existent-provider'
)
})
it('处理注册不存在配置的 provider', () => {
const mockProvider = { name: 'mock' }
const success = registerProvider('non-existent-provider', mockProvider)
expect(success).toBe(false)
})
it('处理获取不存在配置的情况', () => {
expect(getProviderConfig('non-existent')).toBeUndefined()
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
expect(hasProviderConfig('non-existent')).toBe(false)
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
})
it('处理批量注册时的部分失败', () => {
const mixedConfigs: ProviderConfig[] = [
{
id: 'valid-provider-1',
name: 'Valid Provider 1',
creator: vi.fn(() => ({ name: 'valid-1' })),
supportsImageGeneration: false
},
{
id: '', // 无效配置
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
} as any,
{
id: 'valid-provider-2',
name: 'Valid Provider 2',
creator: vi.fn(() => ({ name: 'valid-2' })),
supportsImageGeneration: true
}
]
const successCount = registerMultipleProviderConfigs(mixedConfigs)
expect(successCount).toBe(2) // 只有两个有效配置成功
expect(hasProviderConfig('valid-provider-1')).toBe(true)
expect(hasProviderConfig('valid-provider-2')).toBe(true)
expect(hasProviderConfig('')).toBe(false)
})
it('处理动态导入失败的情况', async () => {
const config: ProviderConfig = {
id: 'import-test-provider',
name: 'Import Test Provider',
import: vi.fn().mockRejectedValue(new Error('Import failed')),
creatorFunctionName: 'createTest',
supportsImageGeneration: false
}
registerProviderConfig(config)
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
})
})
describe('集成测试', () => {
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
// 初始状态验证
const initialConfigs = getAllProviderConfigs()
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
expect(getInitializedProviders()).toEqual([])
// 注册多个带别名的 provider 配置
const configs: ProviderConfig[] = [
{
id: 'integration-provider-1',
name: 'Integration Provider 1',
creator: vi.fn(() => ({ name: 'integration-1' })),
supportsImageGeneration: false,
aliases: ['alias-1', 'short-name-1']
},
{
id: 'integration-provider-2',
name: 'Integration Provider 2',
creator: vi.fn(() => ({ name: 'integration-2' })),
supportsImageGeneration: true,
aliases: ['alias-2', 'short-name-2']
}
]
const successCount = registerMultipleProviderConfigs(configs)
expect(successCount).toBe(2)
// 验证配置注册成功
expect(hasProviderConfig('integration-provider-1')).toBe(true)
expect(hasProviderConfig('integration-provider-2')).toBe(true)
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
// 验证别名映射
const aliases = getAllProviderConfigAliases()
expect(aliases['alias-1']).toBe('integration-provider-1')
expect(aliases['alias-2']).toBe('integration-provider-2')
// 创建和注册 providers
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
expect(success1).toBe(true)
expect(success2).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('integration-provider-1')
expect(registeredProviders).toContain('integration-provider-2')
expect(hasInitializedProviders()).toBe(true)
// 清理
cleanup()
// 验证清理后的状态
expect(getInitializedProviders()).toEqual([])
expect(hasProviderConfig('integration-provider-1')).toBe(false)
expect(hasProviderConfig('integration-provider-2')).toBe(false)
expect(getAllProviderConfigAliases()).toEqual({})
// 基础配置应该重新加载
expect(hasProviderConfig('openai')).toBe(true)
})
it('正确处理动态导入配置的 provider', async () => {
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
const dynamicImportConfig: ProviderConfig = {
id: 'dynamic-import-provider',
name: 'Dynamic Import Provider',
import: vi.fn().mockResolvedValue(mockModule),
creatorFunctionName: 'createCustomProvider',
supportsImageGeneration: false
}
// 注册配置
const configSuccess = registerProviderConfig(dynamicImportConfig)
expect(configSuccess).toBe(true)
// 创建和注册 provider
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
expect(registerSuccess).toBe(true)
// 验证导入函数被调用
expect(dynamicImportConfig.import).toHaveBeenCalled()
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
// 验证注册成功
expect(getInitializedProviders()).toContain('dynamic-import-provider')
})
it('正确处理大量配置的注册和管理', () => {
const largeConfigList: ProviderConfig[] = []
// 生成50个配置
for (let i = 0; i < 50; i++) {
largeConfigList.push({
id: `bulk-provider-${i}`,
name: `Bulk Provider ${i}`,
creator: vi.fn(() => ({ name: `bulk-${i}` })),
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
aliases: [`alias-${i}`, `short-${i}`]
})
}
const successCount = registerMultipleProviderConfigs(largeConfigList)
expect(successCount).toBe(50)
// 验证所有配置都被正确注册
const allConfigs = getAllProviderConfigs()
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
// 验证别名数量
const aliases = getAllProviderConfigAliases()
const bulkAliases = Object.keys(aliases).filter(
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
)
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
// 随机验证几个配置
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
// 验证别名工作正常
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
expect(isProviderConfigAlias('short-30')).toBe(true)
// 清理能正确处理大量数据
cleanup()
const cleanupAliases = getAllProviderConfigAliases()
expect(
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
).toEqual([])
})
})
describe('边界测试', () => {
it('处理包含特殊字符的 provider IDs', () => {
const specialCharsConfigs: ProviderConfig[] = [
{
id: 'provider-with-dashes',
name: 'Provider With Dashes',
creator: vi.fn(() => ({ name: 'dashes' })),
supportsImageGeneration: false
},
{
id: 'provider_with_underscores',
name: 'Provider With Underscores',
creator: vi.fn(() => ({ name: 'underscores' })),
supportsImageGeneration: false
},
{
id: 'provider.with.dots',
name: 'Provider With Dots',
creator: vi.fn(() => ({ name: 'dots' })),
supportsImageGeneration: false
}
]
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
// 验证支持的特殊字符格式
if (hasProviderConfig('provider-with-dashes')) {
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
}
if (hasProviderConfig('provider_with_underscores')) {
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
}
})
it('处理空的批量注册', () => {
const successCount = registerMultipleProviderConfigs([])
expect(successCount).toBe(0)
// 确保没有额外的配置被添加
const configsBefore = getAllProviderConfigs().length
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
})
it('处理重复的配置注册', () => {
const config: ProviderConfig = {
id: 'duplicate-test-provider',
name: 'Duplicate Test Provider',
creator: vi.fn(() => ({ name: 'duplicate' })),
supportsImageGeneration: false
}
// 第一次注册成功
expect(registerProviderConfig(config)).toBe(true)
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
// 重复注册相同的配置(允许覆盖)
const updatedConfig: ProviderConfig = {
...config,
name: 'Updated Duplicate Test Provider'
}
expect(registerProviderConfig(updatedConfig)).toBe(true)
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
// 验证配置被更新
const retrievedConfig = getProviderConfig('duplicate-test-provider')
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
})
it('处理极长的 ID 和名称', () => {
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
const config: ProviderConfig = {
id: longId,
name: longName,
creator: vi.fn(() => ({ name: 'long-test' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfig(longId)).toBe(true)
const retrievedConfig = getProviderConfig(longId)
expect(retrievedConfig?.name).toBe(longName)
})
it('处理大量别名的配置', () => {
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
const config: ProviderConfig = {
id: 'provider-with-many-aliases',
name: 'Provider With Many Aliases',
creator: vi.fn(() => ({ name: 'many-aliases' })),
supportsImageGeneration: false,
aliases: manyAliases
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
// 验证所有别名都能正确解析
manyAliases.forEach((alias) => {
expect(hasProviderConfigByAlias(alias)).toBe(true)
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
expect(isProviderConfigAlias(alias)).toBe(true)
})
})
})
})

View File

@@ -1,269 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import {
type BaseProviderId,
baseProviderIds,
baseProviderIdSchema,
baseProviders,
type CustomProviderId,
customProviderIdSchema,
providerConfigSchema,
type ProviderId,
providerIdSchema
} from '../schemas'
describe('Provider Schemas', () => {
describe('baseProviders', () => {
it('包含所有预期的基础 providers', () => {
expect(baseProviders).toBeDefined()
expect(Array.isArray(baseProviders)).toBe(true)
expect(baseProviders.length).toBeGreaterThan(0)
// These are the actual base providers defined in schemas.ts
const expectedIds = [
'openai',
'openai-chat',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
'azure-responses',
'deepseek',
'openrouter',
'cherryin',
'cherryin-chat'
]
const actualIds = baseProviders.map((p) => p.id)
expectedIds.forEach((id) => {
expect(actualIds).toContain(id)
})
})
it('每个基础 provider 有必要的属性', () => {
baseProviders.forEach((provider) => {
expect(provider).toHaveProperty('id')
expect(provider).toHaveProperty('name')
expect(provider).toHaveProperty('creator')
expect(provider).toHaveProperty('supportsImageGeneration')
expect(typeof provider.id).toBe('string')
expect(typeof provider.name).toBe('string')
expect(typeof provider.creator).toBe('function')
expect(typeof provider.supportsImageGeneration).toBe('boolean')
})
})
it('provider ID 是唯一的', () => {
const ids = baseProviders.map((p) => p.id)
const uniqueIds = [...new Set(ids)]
expect(ids).toEqual(uniqueIds)
})
})
describe('baseProviderIds', () => {
it('正确提取所有基础 provider IDs', () => {
expect(baseProviderIds).toBeDefined()
expect(Array.isArray(baseProviderIds)).toBe(true)
expect(baseProviderIds.length).toBe(baseProviders.length)
baseProviders.forEach((provider) => {
expect(baseProviderIds).toContain(provider.id)
})
})
})
describe('baseProviderIdSchema', () => {
it('验证有效的基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝无效的基础 provider IDs', () => {
const invalidIds = ['invalid', 'not-exists', '']
invalidIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
})
})
})
describe('customProviderIdSchema', () => {
it('接受有效的自定义 provider IDs', () => {
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
validIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
baseProviderIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
})
})
it('拒绝空字符串', () => {
expect(customProviderIdSchema.safeParse('').success).toBe(false)
})
})
describe('providerIdSchema', () => {
it('接受基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('接受有效的自定义 provider IDs', () => {
const validCustomIds = ['custom-provider', 'my-ai-service']
validCustomIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝无效的 IDs', () => {
const invalidIds = ['', undefined, null, 123]
invalidIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(false)
})
})
})
describe('providerConfigSchema', () => {
it('验证带有 creator 的有效配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
supportsImageGeneration: true
}
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
})
it('验证带有 import 配置的有效配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
import: vi.fn(),
creatorFunctionName: 'createCustom',
supportsImageGeneration: false
}
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
})
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
const invalidConfig = {
id: 'invalid',
name: 'Invalid Provider',
supportsImageGeneration: false
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
it('为 supportsImageGeneration 设置默认值', () => {
const config = {
id: 'test',
name: 'Test',
creator: vi.fn()
}
const result = providerConfigSchema.safeParse(config)
expect(result.success).toBe(true)
if (result.success) {
expect(result.data.supportsImageGeneration).toBe(false)
}
})
it('拒绝使用基础 provider ID 的配置', () => {
const invalidConfig = {
id: 'openai', // 基础 provider ID
name: 'Should Fail',
creator: vi.fn()
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
it('拒绝缺少必需字段的配置', () => {
const invalidConfigs = [
{ name: 'Missing ID', creator: vi.fn() },
{ id: 'missing-name', creator: vi.fn() },
{ id: '', name: 'Empty ID', creator: vi.fn() },
{ id: 'valid-custom', name: '', creator: vi.fn() }
]
invalidConfigs.forEach((config) => {
expect(providerConfigSchema.safeParse(config).success).toBe(false)
})
})
})
describe('Schema 验证功能', () => {
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
})
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
})
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
const customIds = ['custom-provider', 'my-service', 'company-llm']
customIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
})
// 拒绝基础 provider IDs
baseProviderIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
})
})
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
// 基础 IDs
baseProviderIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
// 自定义 IDs
const customIds = ['custom-provider', 'my-service']
customIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('providerConfigSchema 验证完整的 provider 配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
supportsImageGeneration: true
}
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
const invalidConfig = {
id: 'openai', // 不允许基础 provider ID
name: 'OpenAI',
creator: vi.fn()
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
})
describe('类型推导', () => {
it('BaseProviderId 类型正确', () => {
const id: BaseProviderId = 'openai'
expect(baseProviderIds).toContain(id)
})
it('CustomProviderId 类型是字符串', () => {
const id: CustomProviderId = 'custom-provider'
expect(typeof id).toBe('string')
})
it('ProviderId 类型支持基础和自定义 IDs', () => {
const baseId: ProviderId = 'openai'
const customId: ProviderId = 'custom-provider'
expect(typeof baseId).toBe('string')
expect(typeof customId).toBe('string')
})
})
})

View File

@@ -0,0 +1,269 @@
/**
* Provider Types - Type-level Tests
* Tests type utilities and type inference for provider extensions
*/
import type { ProviderV3 } from '@ai-sdk/provider'
import { describe, expectTypeOf, it } from 'vitest'
import type { ProviderExtensionConfig } from '../core/ProviderExtension'
import type {
CoreProviderSettingsMap,
ExtensionConfigToIdResolutionMap,
ExtractExtensionIds,
ExtractProviderIds,
StringKeys,
UnionToIntersection
} from '../types'
describe('Type Utilities', () => {
describe('StringKeys<T>', () => {
it('should extract only string keys from object type', () => {
type TestObj = { foo: 1; bar: 2; 0: 3; 1: 4 }
type Result = StringKeys<TestObj>
expectTypeOf<Result>().toEqualTypeOf<'foo' | 'bar'>()
})
it('should return never for object with no string keys', () => {
type TestObj = { 0: 'a'; 1: 'b' }
type Result = StringKeys<TestObj>
expectTypeOf<Result>().toEqualTypeOf<never>()
})
it('should handle empty object', () => {
type Result = StringKeys<{}>
expectTypeOf<Result>().toEqualTypeOf<never>()
})
it('should preserve literal string keys', () => {
type TestObj = { openai: 1; anthropic: 2; google: 3 }
type Result = StringKeys<TestObj>
expectTypeOf<Result>().toEqualTypeOf<'openai' | 'anthropic' | 'google'>()
})
})
describe('UnionToIntersection<U>', () => {
it('should convert union to intersection', () => {
type Union = { a: 1 } | { b: 2 }
type Result = UnionToIntersection<Union>
expectTypeOf<Result>().toEqualTypeOf<{ a: 1 } & { b: 2 }>()
})
it('should handle single type', () => {
type Single = { a: 1 }
type Result = UnionToIntersection<Single>
expectTypeOf<Result>().toEqualTypeOf<{ a: 1 }>()
})
})
describe('ExtractProviderIds<TConfig>', () => {
it('should extract base name', () => {
type Config = { name: 'openai' }
type Result = ExtractProviderIds<Config>
expectTypeOf<Result>().toEqualTypeOf<'openai'>()
})
it('should extract name and aliases', () => {
type Config = { name: 'anthropic'; aliases: readonly ['claude'] }
type Result = ExtractProviderIds<Config>
expectTypeOf<Result>().toEqualTypeOf<'anthropic' | 'claude'>()
})
it('should extract name and variants', () => {
type Config = { name: 'openai'; variants: readonly [{ suffix: 'chat' }] }
type Result = ExtractProviderIds<Config>
expectTypeOf<Result>().toEqualTypeOf<'openai' | 'openai-chat'>()
})
it('should extract name, aliases, and variants', () => {
type Config = {
name: 'azure'
aliases: readonly ['azure-openai']
variants: readonly [{ suffix: 'responses' }]
}
type Result = ExtractProviderIds<Config>
expectTypeOf<Result>().toEqualTypeOf<'azure' | 'azure-openai' | 'azure-responses'>()
})
it('should handle multiple variants', () => {
type Config = {
name: 'openai'
variants: readonly [{ suffix: 'chat' }, { suffix: 'responses' }]
}
type Result = ExtractProviderIds<Config>
expectTypeOf<Result>().toEqualTypeOf<'openai' | 'openai-chat' | 'openai-responses'>()
})
})
describe('ExtensionConfigToIdResolutionMap<TConfig>', () => {
it('should map base name to itself', () => {
type Config = { name: 'openai' }
type Result = ExtensionConfigToIdResolutionMap<Config>
expectTypeOf<Result>().toEqualTypeOf<{ readonly openai: 'openai' }>()
})
it('should map aliases to base name', () => {
type Config = { name: 'anthropic'; aliases: readonly ['claude'] }
type Result = ExtensionConfigToIdResolutionMap<Config>
expectTypeOf<Result>().toEqualTypeOf<{
readonly anthropic: 'anthropic'
readonly claude: 'anthropic'
}>()
})
it('should map variants to themselves (self-referential)', () => {
type Config = { name: 'azure'; variants: readonly [{ suffix: 'responses' }] }
type Result = ExtensionConfigToIdResolutionMap<Config>
expectTypeOf<Result>().toEqualTypeOf<{
readonly azure: 'azure'
readonly 'azure-responses': 'azure-responses'
}>()
})
it('should handle combined aliases and variants correctly', () => {
type Config = {
name: 'azure'
aliases: readonly ['azure-openai']
variants: readonly [{ suffix: 'responses' }]
}
type Result = ExtensionConfigToIdResolutionMap<Config>
expectTypeOf<Result>().toEqualTypeOf<{
readonly azure: 'azure'
readonly 'azure-openai': 'azure'
readonly 'azure-responses': 'azure-responses'
}>()
})
})
describe('ExtractExtensionIds<T>', () => {
it('should extract IDs from extension with config property', () => {
type MockExtension = {
config: { name: 'test'; aliases: readonly ['test-alias'] }
}
type Result = ExtractExtensionIds<MockExtension>
expectTypeOf<Result>().toEqualTypeOf<'test' | 'test-alias'>()
})
})
describe('ExtensionToSettingsMap<T>', () => {
it('should map provider IDs to settings type', () => {
type MockSettings = { apiKey: string }
type MockConfig = { name: 'mock' }
// This tests the concept - actual implementation depends on ProviderExtension structure
type Result = { [K in ExtractProviderIds<MockConfig>]: MockSettings }
expectTypeOf<Result>().toEqualTypeOf<{ mock: MockSettings }>()
})
})
describe('CoreProviderSettingsMap', () => {
it('should include openai provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('openai')
})
it('should include anthropic provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('anthropic')
})
it('should include google provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('google')
})
it('should include azure provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('azure')
})
it('should include xai provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('xai')
})
it('should include deepseek provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('deepseek')
})
it('should include openrouter provider', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('openrouter')
})
it('should include aliases like claude', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('claude')
})
it('should include variants like openai-chat', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('openai-chat')
})
it('should include variants like azure-responses', () => {
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('azure-responses')
})
})
})
describe('ProviderExtensionConfig Type Constraints', () => {
it('should accept valid minimal config', () => {
type ValidConfig = ProviderExtensionConfig<{ apiKey: string }, ProviderV3, 'test'>
// Should compile without errors
const config: ValidConfig = {
name: 'test',
create: () => ({}) as ProviderV3
}
expectTypeOf(config.name).toEqualTypeOf<'test'>()
})
it('should accept config with aliases', () => {
type ConfigWithAliases = {
name: 'anthropic'
aliases: readonly ['claude']
create: () => ProviderV3
}
const config: ConfigWithAliases = {
name: 'anthropic',
aliases: ['claude'] as const,
create: () => ({}) as ProviderV3
}
expectTypeOf(config.aliases).toEqualTypeOf<readonly ['claude']>()
})
it('should accept config with variants', () => {
type ConfigWithVariants = {
name: 'openai'
variants: readonly [{ suffix: 'chat'; name: string; transform: (p: ProviderV3) => ProviderV3 }]
create: () => ProviderV3
}
const config: ConfigWithVariants = {
name: 'openai',
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
transform: (p) => p
}
] as const,
create: () => ({}) as ProviderV3
}
expectTypeOf(config.variants[0].suffix).toEqualTypeOf<'chat'>()
})
})

View File

@@ -0,0 +1,513 @@
/**
* Extension Registry
* 管理所有 Provider Extensions 的注册、查询和实例化
*/
import type { ProviderV3 } from '@ai-sdk/provider'
import type { CoreProviderSettingsMap, RegisteredProviderId, ToolCapability, ToolFactory } from '../index'
import { type ProviderExtension } from './ProviderExtension'
import { ProviderCreationError } from './utils'
/**
* Provider Extension 注册表
*
* 职责:
* - 注册和管理 Provider Extensions
* - 根据 ID 查找对应的 Extension
* - 创建并注册 provider 实例(包括变体)
*
* @example
* ```typescript
* import { extensionRegistry } from '@cherrystudio/ai-core/provider'
* import { OpenAIExtension } from './extensions/openai'
*
* // 注册 extension
* extensionRegistry.register(OpenAIExtension)
*
* // 批量注册
* extensionRegistry.registerAll([
* OpenAIExtension,
* AzureExtension,
* AnthropicExtension
* ])
*
* // 创建并注册 provider 实例
* await extensionRegistry.createAndRegisterProvider('openai', {
* apiKey: 'sk-xxx'
* })
* ```
*/
export class ExtensionRegistry {
/** Extension 存储: name -> Extension */
private extensions: Map<string, ProviderExtension<any, any, any>> = new Map()
/** 别名映射: alias -> name */
private aliasMap: Map<string, string> = new Map()
/**
* 注册单个 Extension
* 支持链式调用
*/
register(extension: ProviderExtension<any, any, any>): this {
const { name, aliases, variants } = extension.config
// Idempotent: skip if already registered (supports HMR / re-import)
if (this.extensions.has(name)) {
return this
}
this.extensions.set(name, extension)
if (aliases) {
for (const alias of aliases) {
if (this.aliasMap.has(alias)) {
throw new Error(`Provider alias "${alias}" is already registered for "${this.aliasMap.get(alias)}"`)
}
this.aliasMap.set(alias, name)
}
}
if (variants) {
for (const variant of variants) {
const variantId = `${name}-${variant.suffix}`
if (this.aliasMap.has(variantId)) {
throw new Error(
`Provider variant ID "${variantId}" is already registered for "${this.aliasMap.get(variantId)}"`
)
}
this.aliasMap.set(variantId, name)
}
}
return this
}
/**
* 批量注册 Extensions
* 支持 readonly 数组(用于 as const 数组)
*/
registerAll(extensions: readonly ProviderExtension<any, any, any>[]): this {
for (const ext of extensions) {
this.register(ext)
}
return this
}
/**
* 取消注册 Extension
*/
unregister(name: string): boolean {
const extension = this.extensions.get(name)
if (!extension) {
return false
}
extension.clearCache()
this.extensions.delete(name)
if (extension.config.aliases) {
for (const alias of extension.config.aliases) {
this.aliasMap.delete(alias)
}
}
if (extension.config.variants) {
for (const variant of extension.config.variants) {
this.aliasMap.delete(`${name}-${variant.suffix}`)
}
}
return true
}
/**
* 获取 Extension支持别名
*/
get(id: string): ProviderExtension<any, any, any> | undefined {
if (this.extensions.has(id)) {
return this.extensions.get(id)
}
const realName = this.aliasMap.get(id)
if (realName) {
return this.extensions.get(realName)
}
return undefined
}
/**
* 获取 Extension
*
* @param id - Provider ID必须是 RegisteredProviderId
* @returns Extension 或 undefined
*
* @example
* ```typescript
* const ext = extensionRegistry.getTyped('openai')
* if (ext) {
* const provider = await ext.createProvider({
* apiKey: 'sk-...'
* })
* }
* ```
*/
getTyped<T extends RegisteredProviderId>(id: T): ProviderExtension<any, any, any> | undefined {
return this.get(id)
}
/**
* 检查 Extension 是否已注册
*/
has(id: string): boolean {
return this.extensions.has(id) || this.aliasMap.has(id)
}
/**
* 获取所有已注册的 Extension
*/
getAll(): ProviderExtension<any, any, any>[] {
return Array.from(this.extensions.values())
}
/**
* 获取所有已注册的 provider IDs包含变体
* 返回类型安全的 RegisteredProviderId 数组,自动去重
*/
getAllProviderIds(): RegisteredProviderId[] {
const ids = new Set<string>()
for (const extension of this.extensions.values()) {
for (const id of extension.getProviderIds()) {
ids.add(id)
}
}
return Array.from(ids) as RegisteredProviderId[]
}
/**
* 根据 base ID + mode 解析到完整的 provider ID
*
* 支持别名:如果 baseId 是别名,会先解析到规范 ID
*
* @param baseId - 基础 provider ID可以是别名
* @param mode - 模式(如 'chat', 'responses'
* @returns 完整的 provider ID如果无法解析则返回 null
*
* @example
* ```typescript
* resolveProviderIdWithMode('openai', 'chat') // → 'openai-chat'
* resolveProviderIdWithMode('azure', 'responses') // → 'azure-responses'
* resolveProviderIdWithMode('gemini', 'chat') // → null (google 没有 chat 变体)
* resolveProviderIdWithMode('openai') // → 'openai' (没有 mode)
* ```
*/
resolveProviderIdWithMode(baseId: string, mode?: string): string | null {
// 如果没有 mode直接返回解析后的 ID
if (!mode) {
const extension = this.get(baseId)
return extension ? extension.config.name : null
}
// 获取 extension支持别名
const extension = this.get(baseId)
if (!extension) {
return null
}
// 检查是否有对应的变体
if (!extension.config.variants) {
return null
}
// 查找匹配的变体
const variant = extension.config.variants.find((v: { suffix: string }) => v.suffix === mode)
if (!variant) {
return null
}
// 返回变体 ID: ${name}-${suffix}
return `${extension.config.name}-${variant.suffix}`
}
/**
* 反向解析:从完整 ID 提取 base ID 和 mode
*
* 遍历所有 extensions 的变体,匹配 `${name}-${suffix}` 模式
*
* @param providerId - 完整的 provider ID
* @returns 解析结果,如果无法解析返回 null
*
* @example
* ```typescript
* parseProviderId('openai-chat') // → { baseId: 'openai', mode: 'chat', isVariant: true }
* parseProviderId('azure-responses') // → { baseId: 'azure', mode: 'responses', isVariant: true }
* parseProviderId('openai') // → { baseId: 'openai', isVariant: false }
* parseProviderId('oai') // → { baseId: 'openai', isVariant: false } (别名)
* parseProviderId('unknown') // → null
* ```
*/
parseProviderId(providerId: string): { baseId: RegisteredProviderId; mode?: string; isVariant: boolean } | null {
// 先遍历所有 extensions查找匹配的变体优先于别名检查
for (const ext of this.extensions.values()) {
if (!ext.config.variants) {
continue
}
// 检查每个变体
for (const variant of ext.config.variants) {
const variantId = `${ext.config.name}-${variant.suffix}`
if (variantId === providerId) {
return {
baseId: ext.config.name as RegisteredProviderId,
mode: variant.suffix,
isVariant: true
}
}
}
}
// 再检查是否是已注册的 extension直接或通过别名
const extension = this.get(providerId)
if (extension) {
// 是基础 ID 或别名,不是变体
return {
baseId: extension.config.name as RegisteredProviderId,
isVariant: false
}
}
// 无法解析
return null
}
/**
* 检查是否为变体 ID
*
* @param id - Provider ID
* @returns 如果是变体 ID 返回 true
*
* @example
* ```typescript
* isVariant('openai-chat') // → true
* isVariant('azure-responses') // → true
* isVariant('openai') // → false
* isVariant('unknown') // → false
* ```
*/
isVariant(id: string): boolean {
const parsed = this.parseProviderId(id)
return parsed?.isVariant ?? false
}
/**
* 获取基础 provider ID
*
* 对于变体ID返回其基础provider ID
* 对于基础ID或别名返回规范的provider ID
* 对于未知ID返回null
*
* @param id - Provider ID可以是基础ID、变体ID或别名
* @returns 基础 provider ID如果无法解析则返回 null
*
* @example
* ```typescript
* getBaseProviderId('openai-chat') // → 'openai' (变体)
* getBaseProviderId('azure-responses') // → 'azure' (变体)
* getBaseProviderId('openai') // → 'openai' (基础ID)
* getBaseProviderId('oai') // → 'openai' (别名)
* getBaseProviderId('unknown') // → null
* ```
*/
getBaseProviderId(id: string): RegisteredProviderId | null {
const parsed = this.parseProviderId(id)
return parsed?.baseId ?? null
}
/**
* 获取变体的模式/后缀
*
* @param variantId - 变体 ID
* @returns 模式/后缀,如果不是变体则返回 null
*
* @example
* ```typescript
* getVariantMode('openai-chat') // → 'chat'
* getVariantMode('azure-responses') // → 'responses'
* getVariantMode('openai') // → null (不是变体)
* getVariantMode('unknown') // → null
* ```
*/
getVariantMode(variantId: string): string | null {
const parsed = this.parseProviderId(variantId)
return parsed?.mode ?? null
}
/** 获取 variant 的 resolveModel 函数(类型安全在 extension 声明处保证) */
getModelResolver(providerId: string): ((provider: ProviderV3, modelId: string) => any) | undefined {
const parsed = this.parseProviderId(providerId)
if (!parsed) return undefined
const extension = this.get(parsed.baseId)
if (!extension) return undefined
// Variant resolveModel类型安全在 extension 声明处校验)
if (parsed.isVariant && parsed.mode) {
const variant = extension.getVariant(parsed.mode)
if (variant?.resolveModel) return variant.resolveModel
}
return undefined
}
/**
* 获取某个基础 provider 的所有变体 IDs
*
* @param baseId - 基础 provider ID可以是别名
* @returns 变体 ID 数组,如果没有变体则返回空数组
*
* @example
* ```typescript
* getVariants('openai') // → ['openai-chat']
* getVariants('azure') // → ['azure-responses']
* getVariants('google') // → ['google-chat']
* getVariants('xai') // → [] (没有变体)
* getVariants('unknown') // → [] (未注册)
* ```
*/
getVariants(baseId: string): string[] {
const extension = this.get(baseId)
if (!extension?.config.variants) {
return []
}
return extension.config.variants.map((v: { suffix: string }) => `${extension.config.name}-${v.suffix}`)
}
/** 获取指定 provider 的工具工厂(变体优先,回退到 base */
getToolFactory(providerId: string, capability: ToolCapability): ToolFactory | undefined {
const parsed = this.parseProviderId(providerId)
if (!parsed) return undefined
const { baseId, mode, isVariant } = parsed
const extension = this.get(baseId)
if (!extension) return undefined
// For variants, check variant-level toolFactories first
if (isVariant && mode) {
const variant = extension.getVariant(mode)
if (variant?.toolFactories?.[capability]) {
return variant.toolFactories[capability]
}
}
// Fall back to base extension's toolFactories
return extension.config.toolFactories?.[capability]
}
/**
* 解析工具能力:返回 factory + provider 实例
*
* 1. Direct — provider 自己有 toolFactories
* 2. Aggregator fallback — 从 model.provider 段解析(如 "aihubmix.google" → google extension
*/
async resolveToolCapability(
providerId: string,
capability: ToolCapability,
modelProvider?: string
): Promise<{ factory: ToolFactory; provider: ProviderV3 } | undefined> {
// 1. Direct: provider 自己有 toolFactories
const directFactory = this.getToolFactory(providerId, capability)
if (directFactory) {
const provider = await this.getToolProvider(providerId)
if (provider) return { factory: directFactory, provider }
}
// 2. Aggregator fallback: 从 model.provider 段解析真实 provider
// e.g., "aihubmix.google" → try "google" → found via google extension
// e.g., "cherryin.gemini" → try "gemini" → found via alias → google extension
if (typeof modelProvider === 'string') {
const segments = modelProvider.split('.')
for (let i = segments.length - 1; i >= 0; i--) {
const factory = this.getToolFactory(segments[i], capability)
if (factory) {
const provider = await this.getToolProvider(segments[i])
if (provider) return { factory, provider }
}
}
}
return undefined
}
/** Get base provider for .tools extraction (cached or dummy instance) */
private async getToolProvider(providerId: string): Promise<ProviderV3 | undefined> {
const parsed = this.parseProviderId(providerId)
if (!parsed) return undefined
const extension = this.get(parsed.baseId)
if (!extension) return undefined
const cached = extension.getCachedProvider()
if (cached) return cached
try {
return await extension.createProvider({ apiKey: '_tool_descriptor' })
} catch {
return undefined
}
}
/**
* 清空所有注册
*/
clear(): void {
this.extensions.clear()
this.aliasMap.clear()
}
/**
* 创建 provider 实例
*
* 支持两种调用方式:
* 1. 类型安全版本 - 使用已注册的 provider ID获得完整的类型推导
* 2. 动态版本 - 使用任意字符串 ID用于测试或动态注册的 provider
*
* @param id - Provider ID
* @param settings - Provider 配置
* @returns Provider 实例
*/
async createProvider<T extends RegisteredProviderId>(id: T, settings: CoreProviderSettingsMap[T]): Promise<ProviderV3>
async createProvider(id: string, settings?: unknown): Promise<ProviderV3>
async createProvider(id: string, settings?: unknown): Promise<ProviderV3> {
const parsed = this.parseProviderId(id)
if (!parsed) {
throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`)
}
const { baseId, mode: variantSuffix } = parsed
const extension = this.get(baseId)
if (!extension) {
throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`)
}
try {
return await extension.createProvider(settings, variantSuffix)
} catch (error) {
throw new ProviderCreationError(
`Failed to create provider "${id}"`,
id,
error instanceof Error ? error : new Error(String(error))
)
}
}
}
/**
* 全局 Extension Registry 实例
* 单例模式,确保整个应用只有一个注册表
*/
export const extensionRegistry = new ExtensionRegistry()

View File

@@ -0,0 +1,344 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import { LRUCache } from 'lru-cache'
import { deepMergeObjects } from '../../utils'
import type { ProviderVariant, ToolFactoryMap } from '../types'
export type ProviderCreatorFunction<TSettings = any> = (settings?: TSettings) => ProviderV3 | Promise<ProviderV3>
/**
* Provider 模块类型
* 动态导入的模块应该包含至少一个创建函数
* 允许 default 导出和其他属性
*/
export type ProviderModule<TSettings = any> = Record<string, any> & {
[K: string]: ProviderCreatorFunction<TSettings> | any
}
/**
* Provider Extension 配置基础接口
*
* @typeParam TSettings - Provider 配置类型
* @typeParam TProvider - 实际 provider 类型(用于 variants
* @typeParam TName - Provider 名称类型(用于字面量推导)
*/
interface ProviderExtensionConfigBase<
TSettings = any,
TProvider extends ProviderV3 = ProviderV3,
TName extends string = string
> {
/** Provider 唯一标识 */
name: TName
/** 别名列表(可选) */
aliases?: readonly string[]
/** 默认配置选项 */
defaultOptions?: Partial<TSettings>
/** 是否支持图像生成 */
supportsImageGeneration?: boolean
/**
* Provider 变体配置
* 用于注册同一 provider 的不同模式
*/
variants?: readonly ProviderVariant<TSettings, TProvider>[]
/**
* Tool factory 映射
* 声明该 provider 支持的工具能力(如 webSearch
* 工具工厂从 provider 实例的 .tools 属性提取
*/
toolFactories?: ToolFactoryMap<TProvider>
}
/**
* Provider Extension 配置接口 - 使用 create 函数
*/
interface ProviderExtensionConfigWithCreate<
TSettings = any,
TProvider extends ProviderV3 = ProviderV3,
TName extends string = string
> extends ProviderExtensionConfigBase<TSettings, TProvider, TName> {
create: ProviderCreatorFunction<TSettings>
import?: never
creatorFunctionName?: never
}
/**
* Provider Extension 配置接口 - 使用动态导入
*/
interface ProviderExtensionConfigWithImport<
TSettings = any,
TProvider extends ProviderV3 = ProviderV3,
TName extends string = string
> extends ProviderExtensionConfigBase<TSettings, TProvider, TName> {
create?: never
import: () => Promise<ProviderModule<TSettings>>
creatorFunctionName: string
}
/**
* Provider Extension 配置接口
* 使用联合类型确保 create 和 import 互斥
*
* @typeParam TSettings - Provider 配置类型
* @typeParam TProvider - 实际 provider 类型(用于 variants
* @typeParam TName - Provider 名称类型(用于字面量推导)
*/
export type ProviderExtensionConfig<
TSettings = any,
TProvider extends ProviderV3 = ProviderV3,
TName extends string = string
> =
| ProviderExtensionConfigWithCreate<TSettings, TProvider, TName>
| ProviderExtensionConfigWithImport<TSettings, TProvider, TName>
/**
* Provider Extension 类
*
* @typeParam TSettings - Provider 配置类型
* @typeParam TProvider - 实际 provider 类型(用于 variants
* @typeParam TConfig - 配置对象类型(幻影类型参数,用于自动推导 Provider IDs
*/
export class ProviderExtension<
TSettings = any,
TProvider extends ProviderV3 = ProviderV3,
TConfig extends ProviderExtensionConfig<TSettings, TProvider, string> = ProviderExtensionConfig<
TSettings,
TProvider,
string
>
> {
/** Provider 实例缓存 - 按 settings hash 存储LRU 自动清理 */
private instances: LRUCache<string, TProvider>
/** In-flight promise map - 防止并发创建相同 settings 的 provider */
private pendingCreations: Map<string, Promise<TProvider>> = new Map()
constructor(public readonly config: TConfig) {
if (!config.name) {
throw new Error('ProviderExtension: name is required')
}
this.instances = new LRUCache<string, TProvider>({
max: 10,
updateAgeOnGet: true
})
}
/**
* 静态工厂方法 - 创建 Provider Extension
*/
static create<
const TConfig extends ProviderExtensionConfig<any, any, string>,
TSettings = TConfig extends ProviderExtensionConfig<infer S, any, any> ? S : any,
TProvider extends ProviderV3 = TConfig extends ProviderExtensionConfig<any, infer P, any> ? P : ProviderV3
>(config: TConfig | (() => TConfig)): ProviderExtension<TSettings, TProvider, TConfig>
static create(config: any): ProviderExtension<any, any, any> {
const resolvedConfig = typeof config === 'function' ? config() : config
return new ProviderExtension(resolvedConfig)
}
/**
* Options getter - 只读配置
*/
get options(): Readonly<Partial<TSettings>> {
return Object.freeze({ ...this.config.defaultOptions })
}
/**
* 计算 settings 的稳定 hash
*/
private computeHash(settings?: TSettings, variantSuffix?: string): string {
const baseKey = (() => {
if (settings === undefined || settings === null) {
return 'default'
}
const stableStringify = (obj: any): string => {
if (obj === null || obj === undefined) return 'null'
if (typeof obj === 'function') return '"[function]"'
if (typeof obj !== 'object') return JSON.stringify(obj)
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
const keys = Object.keys(obj).sort()
const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
return `{${pairs.join(',')}}`
}
return stableStringify(settings)
})()
return variantSuffix ? `${baseKey}:${variantSuffix}` : baseKey
}
/**
* 创建 Provider 实例
* 相同 settings 会复用实例,不同 settings 会创建新实例
*/
async createProvider(settings?: TSettings, variantSuffix?: string): Promise<TProvider> {
if (variantSuffix) {
const variant = this.getVariant(variantSuffix)
if (!variant) {
throw new Error(
`ProviderExtension "${this.config.name}": variant "${variantSuffix}" not found. ` +
`Available variants: ${this.config.variants?.map((v) => v.suffix).join(', ') || 'none'}`
)
}
}
const mergedSettings = deepMergeObjects(this.config.defaultOptions || {}, settings || {}) as TSettings
const hash = this.computeHash(mergedSettings, variantSuffix)
const cachedInstance = this.instances.get(hash)
if (cachedInstance) {
return cachedInstance
}
const pending = this.pendingCreations.get(hash)
if (pending) {
return pending
}
const creationPromise = this._doCreateProvider(mergedSettings, variantSuffix, hash)
this.pendingCreations.set(hash, creationPromise)
try {
return await creationPromise
} finally {
this.pendingCreations.delete(hash)
}
}
/**
* 获取基础 provider 实例(无变体转换)
* 用于访问 provider 的 .tools 属性
*/
async getBaseProvider(settings?: TSettings): Promise<TProvider> {
return this.createProvider(settings)
}
private async _doCreateProvider(
mergedSettings: TSettings,
variantSuffix: string | undefined,
hash: string
): Promise<TProvider> {
let baseProvider: ProviderV3
if (this.config.create) {
baseProvider = await Promise.resolve(this.config.create(mergedSettings))
} else if (this.config.import && this.config.creatorFunctionName) {
const module = await this.config.import()
const creatorFn = module[this.config.creatorFunctionName]
if (!creatorFn || typeof creatorFn !== 'function') {
throw new Error(
`ProviderExtension "${this.config.name}": creatorFunctionName "${this.config.creatorFunctionName}" not found in imported module`
)
}
baseProvider = await Promise.resolve(creatorFn(mergedSettings))
} else {
throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`)
}
let finalProvider: TProvider
if (variantSuffix) {
const variant = this.getVariant(variantSuffix)!
if (variant.transform) {
const baseHash = this.computeHash(mergedSettings)
if (!this.instances.has(baseHash)) {
this.instances.set(baseHash, baseProvider as TProvider)
}
finalProvider = (await Promise.resolve(
variant.transform(baseProvider as TProvider, mergedSettings)
)) as TProvider
} else {
finalProvider = baseProvider as TProvider
}
} else {
finalProvider = baseProvider as TProvider
}
this.instances.set(hash, finalProvider)
return finalProvider
}
/**
* 配置 provider链式调用
* 返回一个新的 Extension 实例,不修改原实例
*/
configure(settings: Partial<TSettings>): ProviderExtension<TSettings, TProvider> {
return new ProviderExtension({
...this.config,
defaultOptions: deepMergeObjects(this.config.defaultOptions || ({} as any), settings)
})
}
/**
* 获取所有 provider IDs包含变体和别名
*/
getProviderIds(): string[] {
const ids = [this.config.name, ...(this.config.aliases || [])]
if (this.config.variants) {
for (const variant of this.config.variants) {
ids.push(`${this.config.name}-${variant.suffix}`)
}
}
return ids
}
/**
* 检查给定 ID 是否属于此 Extension
*/
hasProviderId(id: string): boolean {
return this.getProviderIds().includes(id)
}
/**
* 获取变体配置
*/
getVariant(suffix: string): ProviderVariant<TSettings, TProvider> | undefined {
return this.config.variants?.find((v) => v.suffix === suffix)
}
/**
* 清除所有缓存的 Provider 实例
*/
clearCache(): void {
this.instances.clear()
this.pendingCreations.clear()
}
/**
* 获取已缓存的 provider 实例(如果存在)
*/
getCachedProvider(): TProvider | undefined {
for (const [key, value] of this.instances.entries()) {
if (!key.includes(':')) return value
}
for (const [, value] of this.instances.entries()) {
return value
}
return undefined
}
/**
* 获取缓存统计信息
*/
getCacheStats(): { cachedInstances: number } {
return {
cachedInstances: this.instances.size
}
}
}

View File

@@ -0,0 +1,312 @@
/**
* Provider 初始化器
* 负责根据配置创建 providers 并注册到全局管理器
* 使用新的 Extension 系统
*/
import type { AnthropicProvider, AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { createAnthropic } from '@ai-sdk/anthropic'
import type { AzureOpenAIProvider, AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createAzure } from '@ai-sdk/azure'
import type { DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { createDeepSeek } from '@ai-sdk/deepseek'
import type { GoogleGenerativeAIProvider, GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import type { OpenAIProvider, OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAI } from '@ai-sdk/openai'
import type { OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import type { ProviderV3 } from '@ai-sdk/provider'
import type { XaiProvider, XaiProviderSettings } from '@ai-sdk/xai'
import { createXai } from '@ai-sdk/xai'
import type { CherryInProvider, CherryInProviderSettings } from '@cherrystudio/ai-sdk-provider'
import { createCherryIn } from '@cherrystudio/ai-sdk-provider'
import type { OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
import { customProvider } from 'ai'
import type { OpenRouterSearchConfig } from '../../plugins/built-in/webSearchPlugin'
import type { ExtensionConfigToIdResolutionMap, ExtractExtensionIds, UnionToIntersection } from '../types'
import { extensionRegistry } from './ExtensionRegistry'
import type { ProviderExtensionConfig } from './ProviderExtension'
import { ProviderExtension } from './ProviderExtension'
// ==================== Core Extensions ====================
const AnthropicExtension = ProviderExtension.create({
name: 'anthropic',
aliases: ['claude'] as const,
supportsImageGeneration: false,
create: createAnthropic,
toolFactories: {
webSearch:
(provider) => (config: NonNullable<Parameters<AnthropicProvider['tools']['webSearch_20250305']>[0]>) => ({
tools: { webSearch: provider.tools.webSearch_20250305(config) }
}),
urlContext:
(provider) => (config: NonNullable<Parameters<AnthropicProvider['tools']['webFetch_20260209']>[0]>) => ({
tools: { webSearch: provider.tools.webFetch_20260209(config) }
})
}
} as const satisfies ProviderExtensionConfig<AnthropicProviderSettings, AnthropicProvider, 'anthropic'>)
/**
* Azure Extension
*/
const AzureExtension = ProviderExtension.create({
name: 'azure',
aliases: ['azure-openai'] as const,
supportsImageGeneration: true,
create: (settings) => {
const provider = createAzure(settings)
// Default to chat mode (AI SDK defaults to responses API)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
},
toolFactories: {
webSearch:
(provider: AzureOpenAIProvider) =>
(config: NonNullable<Parameters<AzureOpenAIProvider['tools']['webSearchPreview']>[0]>) => ({
tools: { webSearch: provider.tools.webSearchPreview(config) }
})
},
variants: [
{
suffix: 'responses',
name: 'Azure OpenAI Responses',
transform: (provider) =>
customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
},
// Azure 上的 Claude 模型走 Anthropic SDK
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
{
suffix: 'anthropic',
name: 'Azure Anthropic',
transform: (_provider, settings) =>
createAnthropic({
baseURL: (settings?.baseURL ?? '') + '/anthropic/v1',
apiKey: settings?.apiKey ?? '',
headers: settings?.headers
})
}
] as const
} as const satisfies ProviderExtensionConfig<AzureOpenAIProviderSettings, AzureOpenAIProvider, 'azure'>)
const CherryInExtension = ProviderExtension.create({
name: 'cherryin',
supportsImageGeneration: true,
create: createCherryIn,
variants: [
{
suffix: 'chat',
name: 'CherryIN Chat',
transform: (provider) =>
customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
}
] as const
} as const satisfies ProviderExtensionConfig<CherryInProviderSettings, CherryInProvider, 'cherryin'>)
const DeepSeekExtension = ProviderExtension.create({
name: 'deepseek',
supportsImageGeneration: false,
create: createDeepSeek
} as const satisfies ProviderExtensionConfig<DeepSeekProviderSettings, ProviderV3, 'deepseek'>)
const GoogleExtension = ProviderExtension.create({
name: 'google',
aliases: ['google-ai', 'gemini', 'google-gemini'] as const,
supportsImageGeneration: true,
create: createGoogleGenerativeAI,
toolFactories: {
webSearch:
(provider: GoogleGenerativeAIProvider) =>
(config: NonNullable<Parameters<GoogleGenerativeAIProvider['tools']['googleSearch']>[0]>) => ({
tools: { webSearch: provider.tools.googleSearch(config) }
}),
urlContext: (provider) => (config) => ({
tools: {
urlContext: provider.tools.urlContext(config)
}
})
}
} as const satisfies ProviderExtensionConfig<GoogleGenerativeAIProviderSettings, GoogleGenerativeAIProvider, 'google'>)
const OpenAICompatibleExtension = ProviderExtension.create({
name: 'openai-compatible',
supportsImageGeneration: true,
create: (settings) => {
if (!settings) {
throw new Error('OpenAI Compatible provider requires settings')
}
return createOpenAICompatible(settings)
}
} as const satisfies ProviderExtensionConfig<OpenAICompatibleProviderSettings, ProviderV3, 'openai-compatible'>)
const OpenAIExtension = ProviderExtension.create({
name: 'openai',
aliases: ['openai-response'] as const,
supportsImageGeneration: true,
create: createOpenAI,
toolFactories: {
webSearch:
(provider: OpenAIProvider) => (config: NonNullable<Parameters<OpenAIProvider['tools']['webSearch']>[0]>) => ({
tools: { webSearch: provider.tools.webSearch(config) }
})
},
variants: [
{
suffix: 'chat',
name: 'OpenAI Chat',
resolveModel: (provider: OpenAIProvider, modelId: string) => provider.chat(modelId),
toolFactories: {
webSearch:
(provider: OpenAIProvider) =>
(config: NonNullable<Parameters<OpenAIProvider['tools']['webSearchPreview']>[0]>) => ({
tools: { webSearch: provider.tools.webSearchPreview(config) }
})
}
}
] as const
} as const satisfies ProviderExtensionConfig<OpenAIProviderSettings, OpenAIProvider, 'openai'>)
const OpenRouterExtension = ProviderExtension.create({
name: 'openrouter',
// TODO: 实现注册后修改拓展配置
aliases: ['tokenflux'] as const,
supportsImageGeneration: true,
create: createOpenRouter,
toolFactories: {
webSearch: () => (config: OpenRouterSearchConfig) => ({
providerOptions: { openrouter: config }
})
}
} as const satisfies ProviderExtensionConfig<OpenRouterProviderSettings, ProviderV3, 'openrouter'>)
const XaiExtension = ProviderExtension.create({
name: 'xai',
aliases: ['grok'] as const,
supportsImageGeneration: true,
create: createXai,
variants: [
{
suffix: 'responses',
name: 'xAI Responses',
resolveModel: (provider: XaiProvider, modelId: string) => provider.responses(modelId),
toolFactories: {
webSearch:
(provider: XaiProvider) =>
(config: {
webSearch?: Parameters<XaiProvider['tools']['webSearch']>[0]
xSearch?: Parameters<XaiProvider['tools']['xSearch']>[0]
}) => ({
tools: {
webSearch: provider.tools.webSearch(config?.webSearch ?? {}),
xSearch: provider.tools.xSearch(config?.xSearch ?? {})
}
})
}
}
] as const
} as const satisfies ProviderExtensionConfig<XaiProviderSettings, XaiProvider, 'xai'>)
/**
* 核心 provider extensions 列表
*/
export const coreExtensions = [
OpenAIExtension,
AnthropicExtension,
AzureExtension,
GoogleExtension,
XaiExtension,
DeepSeekExtension,
OpenRouterExtension,
OpenAICompatibleExtension,
CherryInExtension
] as const
/**
* 核心 Provider IDs 类型
* 从 coreExtensions 数组自动提取所有 provider IDs包括 aliases 和 variants
*
*/
export type CoreProviderId = ExtractExtensionIds<(typeof coreExtensions)[number]>
type ExtensionConfigs = (typeof coreExtensions)[number]['config']
type ProviderIdsMap = UnionToIntersection<ExtensionConfigToIdResolutionMap<ExtensionConfigs>>
export const registeredProviderIds: ProviderIdsMap = (() => {
const map = {} as ProviderIdsMap
coreExtensions.forEach((ext) => {
const config = ext.config as ProviderExtensionConfig<any, any, CoreProviderId>
const name = config.name
;(map as Record<string, CoreProviderId>)[name] = name
if (config.aliases) {
config.aliases.forEach((alias) => {
;(map as Record<string, CoreProviderId>)[alias] = name
})
}
if (config.variants) {
config.variants.forEach((variant) => {
;(map as Record<string, CoreProviderId>)[`${name}-${variant.suffix}`] = name
})
}
})
return map
})()
// ==================== 初始化 Extension Registry ====================
/**
* 注册所有通用 extensions 到全局 registry
* 在模块加载时自动执行
*
* 注意:只注册通用的 provider extensionsOpenAI, Anthropic, Google 等)
* 项目特定的 extensions 应该在应用层单独注册
*/
// register() is idempotent — safe to call on HMR / re-import
extensionRegistry.registerAll(coreExtensions)
/**
* Provider 初始化错误类型
*/
class ProviderInitializationError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderInitializationError'
}
}
/**
* 检查是否有对应的 Provider Extension
*/
export function hasProviderConfig(providerId: string): boolean {
return extensionRegistry.has(providerId)
}
// ==================== 导出错误类型 ====================
export { ProviderInitializationError }

View File

@@ -1,3 +1,10 @@
/**
* Provider
* utils.ts errors.ts
*/
// ==================== 私钥格式化工具 ====================
/**
* PEM头部和尾部
*/
@@ -84,3 +91,20 @@ function reconstructPemKey(key: string): string {
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
}
// ==================== 错误类 ====================
/**
* Provider
* provider
*/
export class ProviderCreationError extends Error {
constructor(
message: string,
public providerId: string,
public cause: Error
) {
super(message)
this.name = 'ProviderCreationError'
}
}

View File

@@ -1,291 +0,0 @@
/**
* AI Provider 配置工厂
* 提供类型安全的 Provider 配置构建器
*/
import type { ProviderId, ProviderSettingsMap } from './types'
/**
* 通用配置基础类型,包含所有 Provider 共有的属性
*/
export interface BaseProviderConfig {
apiKey?: string
baseURL?: string
timeout?: number
headers?: Record<string, string>
fetch?: typeof globalThis.fetch
}
/**
* 完整的配置类型结合基础配置、AI SDK 配置和特定 Provider 配置
*/
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
type ConfigHandler<T extends ProviderId> = (
builder: ProviderConfigBuilder<T>,
provider: CompleteProviderConfig<T>
) => void
const configHandlers: {
[K in ProviderId]?: ConfigHandler<K>
} = {
azure: (builder, provider) => {
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
const azureProvider = provider as CompleteProviderConfig<'azure'>
azureBuilder.withAzureConfig({
apiVersion: azureProvider.apiVersion,
resourceName: azureProvider.resourceName
})
}
}
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
constructor(private providerId: T) {}
/**
* 设置 API Key
*/
withApiKey(apiKey: string): this
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
withApiKey(apiKey: string, options?: any): this {
this.config.apiKey = apiKey
// 类型安全的 OpenAI 特定配置
if (this.providerId === 'openai' && options) {
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
if (options.organization) {
openaiConfig.organization = options.organization
}
if (options.project) {
openaiConfig.project = options.project
}
}
return this
}
/**
* 设置基础 URL
*/
withBaseURL(baseURL: string) {
this.config.baseURL = baseURL
return this
}
/**
* 设置请求配置
*/
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
if (options.headers) {
this.config.headers = { ...this.config.headers, ...options.headers }
}
if (options.fetch) {
this.config.fetch = options.fetch
}
return this
}
/**
* Azure OpenAI 特定配置
*/
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
withAzureConfig(options: any): any {
if (this.providerId === 'azure') {
const azureConfig = this.config as CompleteProviderConfig<'azure'>
if (options.apiVersion) {
azureConfig.apiVersion = options.apiVersion
}
if (options.resourceName) {
azureConfig.resourceName = options.resourceName
}
}
return this
}
/**
* 设置自定义参数
*/
withCustomParams(params: Record<string, any>) {
Object.assign(this.config, params)
return this
}
/**
* 构建最终配置
*/
build(): ProviderSettingsMap[T] {
return this.config as ProviderSettingsMap[T]
}
}
/**
* Provider 配置工厂
* 提供便捷的配置创建方法
*/
export class ProviderConfigFactory {
/**
* 创建配置构建器
*/
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
return new ProviderConfigBuilder(providerId)
}
/**
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
*/
static fromProvider<T extends ProviderId>(
providerId: T,
provider: CompleteProviderConfig<T>,
options?: {
headers?: Record<string, string>
[key: string]: any
}
): ProviderSettingsMap[T] {
const builder = new ProviderConfigBuilder<T>(providerId)
// 设置基本配置
if (provider.apiKey) {
builder.withApiKey(provider.apiKey)
}
if (provider.baseURL) {
builder.withBaseURL(provider.baseURL)
}
// 设置请求配置
if (options?.headers) {
builder.withRequestConfig({
headers: options.headers
})
}
// 使用配置处理器模式 - 更加优雅和可扩展
const handler = configHandlers[providerId]
if (handler) {
handler(builder, provider)
}
// 添加其他自定义参数
if (options) {
const customOptions = { ...options }
delete customOptions.headers // 已经处理过了
if (Object.keys(customOptions).length > 0) {
builder.withCustomParams(customOptions)
}
}
return builder.build()
}
/**
* 快速创建 OpenAI 配置
*/
static createOpenAI(
apiKey: string,
options?: {
baseURL?: string
organization?: string
project?: string
}
) {
const builder = this.builder('openai')
// 使用类型安全的重载
if (options?.organization || options?.project) {
builder.withApiKey(apiKey, {
organization: options.organization,
project: options.project
})
} else {
builder.withApiKey(apiKey)
}
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
}
/**
* 快速创建 Anthropic 配置
*/
static createAnthropic(
apiKey: string,
options?: {
baseURL?: string
}
) {
return this.builder('anthropic')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
.build()
}
/**
* 快速创建 Azure OpenAI 配置
*/
static createAzureOpenAI(
apiKey: string,
options: {
baseURL: string
apiVersion?: string
resourceName?: string
}
) {
return this.builder('azure')
.withApiKey(apiKey)
.withBaseURL(options.baseURL)
.withAzureConfig({
apiVersion: options.apiVersion,
resourceName: options.resourceName
})
.build()
}
/**
* 快速创建 Google 配置
*/
static createGoogle(
apiKey: string,
options?: {
baseURL?: string
projectId?: string
location?: string
}
) {
return this.builder('google')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
.build()
}
/**
* 快速创建 Vertex AI 配置
*/
static createVertexAI() {
// credentials: {
// clientEmail: string
// privateKey: string
// },
// options?: {
// project?: string
// location?: string
// }
// return this.builder('google-vertex')
// .withGoogleCredentials(credentials)
// .withGoogleVertexConfig({
// project: options?.project,
// location: options?.location
// })
// .build()
}
static createOpenAICompatible(baseURL: string, apiKey: string) {
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
}
}
/**
* 便捷的配置创建函数
*/
export const createProviderConfig = ProviderConfigFactory.fromProvider
export const providerConfigBuilder = ProviderConfigFactory.builder

View File

@@ -4,81 +4,50 @@
// ==================== 核心管理器 ====================
// Provider 注册表管理器
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
// Provider 核心功能
export {
// 状态管理
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getImageModel,
// 工具函数
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
getTextEmbeddingModel,
hasInitializedProviders,
// 工具函数
hasProviderConfig,
// 别名支持
hasProviderConfigByAlias,
isProviderConfigAlias,
// 错误类型
ProviderInitializationError,
// 全局访问
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
// 统一Provider系统
registerProviderConfig,
resolveProviderConfigId
} from './registry'
export { coreExtensions, hasProviderConfig } from './core/initialization'
// ==================== 基础数据和类型 ====================
// 基础Provider数据源
export { baseProviderIds, baseProviders, isBaseProvider } from './schemas'
// 类型定义
export type { AiSdkModel, ProviderError } from './types'
// 类型定义和Schema
// 类型提取工具
export type {
BaseProviderId,
CustomProviderId,
DynamicProviderRegistration,
ProviderConfig,
ProviderId
} from './schemas' // 从 schemas 导出的类型
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
export type {
AiSdkModel,
DynamicProviderRegistry,
ExtensibleProviderSettingsMap,
ProviderError,
ProviderSettingsMap,
ProviderTypeRegistrar
CoreProviderSettingsMap,
ExtensionConfigToIdResolutionMap,
ExtensionToSettingsMap,
ExtractProviderIds,
StringKeys,
UnionToIntersection
} from './types'
// ==================== 工具函数 ====================
// Provider配置工厂
// 工具函数和错误类
export { formatPrivateKey, ProviderCreationError } from './core/utils'
// ==================== Provider Extension 系统 ====================
// Extension 核心类和类型
export {
type BaseProviderConfig,
createProviderConfig,
ProviderConfigBuilder,
providerConfigBuilder,
ProviderConfigFactory
} from './factory'
type ProviderCreatorFunction,
ProviderExtension,
type ProviderExtensionConfig,
type ProviderModule
} from './core/ProviderExtension'
// 工具函数
export { formatPrivateKey } from './utils'
// ==================== 扩展功能 ====================
// Hub Provider 功能
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
// Extension Registry
export { ExtensionRegistry, extensionRegistry } from './core/ExtensionRegistry'
export type { ProviderVariant } from './types'
export type {
ExtractToolConfig,
ExtractToolConfigMap,
ProviderId,
RegisteredProviderId,
ToolCapability,
ToolFactory,
ToolFactoryMap,
ToolFactoryPatch,
WebSearchToolConfigMap
} from './types'

View File

@@ -1,314 +0,0 @@
/**
* Provider 初始化器
* 负责根据配置创建 providers 并注册到全局管理器
* 集成了来自 ModelCreator 的特殊处理逻辑
*/
import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders, type ProviderConfig } from './schemas'
/**
* Provider 初始化错误类型
*/
class ProviderInitializationError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderInitializationError'
}
}
// ==================== 全局管理器导出 ====================
export { globalRegistryManagement as providerRegistry }
// ==================== 便捷访问方法 ====================
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.embeddingModel(id as any)
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
// ==================== 工具函数 ====================
/**
* 获取支持的 Providers 列表
*/
export function getSupportedProviders(): Array<{
id: string
name: string
}> {
return baseProviders.map((provider) => ({
id: provider.id,
name: provider.name
}))
}
/**
* 获取所有已初始化的 providers
*/
export function getInitializedProviders(): string[] {
return globalRegistryManagement.getRegisteredProviders()
}
/**
* 检查是否有任何已初始化的 providers
*/
export function hasInitializedProviders(): boolean {
return globalRegistryManagement.hasProviders()
}
// ==================== 统一Provider配置系统 ====================
// 全局Provider配置存储
const providerConfigs = new Map<string, ProviderConfig>()
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
const providerConfigAliases = new Map<string, string>() // alias -> realId
/**
* 初始化内置配置 - 将baseProviders转换为统一格式
*/
function initializeBuiltInConfigs(): void {
baseProviders.forEach((provider) => {
const config: ProviderConfig = {
id: provider.id,
name: provider.name,
creator: provider.creator as any, // 类型转换以兼容多种creator签名
supportsImageGeneration: provider.supportsImageGeneration || false
}
providerConfigs.set(provider.id, config)
})
}
// 启动时自动注册内置配置
initializeBuiltInConfigs()
/**
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
*/
export function registerProviderConfig(config: ProviderConfig): boolean {
try {
// 验证配置
if (!config || !config.id || !config.name) {
return false
}
// 检查是否与已有配置冲突(包括内置配置)
if (providerConfigs.has(config.id)) {
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
}
// 存储配置(内置和用户配置统一处理)
providerConfigs.set(config.id, config)
// 处理别名
if (config.aliases && config.aliases.length > 0) {
config.aliases.forEach((alias) => {
if (providerConfigAliases.has(alias)) {
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
}
providerConfigAliases.set(alias, config.id)
})
}
return true
} catch (error) {
console.error(`Failed to register ProviderConfig:`, error)
return false
}
}
/**
* 步骤2: 创建Provider - 根据配置执行实际创建
*/
export async function createProvider(providerId: string, options: any): Promise<any> {
// 支持通过别名查找配置
const config = getProviderConfigByAlias(providerId)
if (!config) {
throw new Error(`ProviderConfig not found for id: ${providerId}`)
}
try {
let creator: (options: any) => any
if (config.creator) {
// 方式1: 直接执行 creator
creator = config.creator
} else if (config.import && config.creatorFunctionName) {
// 方式2: 动态导入并执行
const module = await config.import()
creator = (module as any)[config.creatorFunctionName]
if (!creator || typeof creator !== 'function') {
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
}
} else {
throw new Error('No valid creator method provided in ProviderConfig')
}
// 使用真实配置创建provider实例
return creator(options)
} catch (error) {
console.error(`Failed to create provider "${providerId}":`, error)
throw error
}
}
/**
* 步骤3: 注册Provider到全局管理器
*/
export function registerProvider(providerId: string, provider: any): boolean {
try {
const config = providerConfigs.get(providerId)
if (!config) {
console.error(`ProviderConfig not found for id: ${providerId}`)
return false
}
// 获取aliases配置
const aliases = config.aliases
// 处理特殊provider逻辑
if (providerId === 'openai') {
// 注册默认 openai
globalRegistryManagement.registerProvider(providerId, provider, aliases)
// 创建并注册 openai-chat 变体
const openaiChatProvider = customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
} else if (providerId === 'azure') {
globalRegistryManagement.registerProvider(providerId, provider, aliases)
// Keep the resolver's azure-chat fallback aligned with the local provider creator.
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider)
} else {
// 其他provider直接注册
globalRegistryManagement.registerProvider(providerId, provider, aliases)
}
return true
} catch (error) {
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
return false
}
}
/**
* 便捷函数: 一次性完成创建+注册
*/
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
try {
// 步骤2: 创建provider
const provider = await createProvider(providerId, options)
// 步骤3: 注册到全局管理器
return registerProvider(providerId, provider)
} catch (error) {
console.error(`Failed to create and register provider "${providerId}":`, error)
return false
}
}
/**
* 批量注册Provider配置
*/
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
let successCount = 0
configs.forEach((config) => {
if (registerProviderConfig(config)) {
successCount++
}
})
return successCount
}
/**
* 检查是否有对应的Provider配置
*/
export function hasProviderConfig(providerId: string): boolean {
return providerConfigs.has(providerId)
}
/**
* 通过别名或ID检查是否有对应的Provider配置
*/
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
const realId = resolveProviderConfigId(aliasOrId)
return providerConfigs.has(realId)
}
/**
* 获取所有Provider配置
*/
export function getAllProviderConfigs(): ProviderConfig[] {
return Array.from(providerConfigs.values())
}
/**
* 根据ID获取Provider配置
*/
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
return providerConfigs.get(providerId)
}
/**
* 通过别名或ID获取Provider配置
*/
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
// 先检查是否为别名如果是则解析为真实ID
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
return providerConfigs.get(realId)
}
/**
* 解析真实的ProviderConfig ID去别名化
*/
export function resolveProviderConfigId(aliasOrId: string): string {
return providerConfigAliases.get(aliasOrId) || aliasOrId
}
/**
* 检查是否为ProviderConfig别名
*/
export function isProviderConfigAlias(id: string): boolean {
return providerConfigAliases.has(id)
}
/**
* 获取所有ProviderConfig别名映射关系
*/
export function getAllProviderConfigAliases(): Record<string, string> {
const result: Record<string, string> = {}
providerConfigAliases.forEach((realId, alias) => {
result[alias] = realId
})
return result
}
/**
* 清理所有Provider配置和已注册的providers
*/
export function cleanup(): void {
providerConfigs.clear()
providerConfigAliases.clear() // 清理别名映射
globalRegistryManagement.clear()
// 重新初始化内置配置
initializeBuiltInConfigs()
}
export function clearAllProviders(): void {
globalRegistryManagement.clear()
}
// ==================== 导出错误类型 ====================
export { ProviderInitializationError }

View File

@@ -1,237 +0,0 @@
/**
* Provider Config 定义
*/
import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import type { ProviderV3 } from '@ai-sdk/provider'
import { createXai } from '@ai-sdk/xai'
import { type CherryInProviderSettings, createCherryIn } from '@cherrystudio/ai-sdk-provider'
import { createOpenRouter, type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
import { customProvider, wrapProvider } from 'ai'
import * as z from 'zod'
/**
* 基础 Provider IDs
*/
export const baseProviderIds = [
'openai',
'openai-chat',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
'azure-responses',
'deepseek',
'openrouter',
'cherryin',
'cherryin-chat'
] as const
/**
* 基础 Provider ID Schema
*/
export const baseProviderIdSchema = z.enum(baseProviderIds)
/**
* 基础 Provider ID
*/
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export const isBaseProvider = (id: ProviderId): id is BaseProviderId => {
return baseProviderIdSchema.safeParse(id).success
}
type BaseProvider = {
id: BaseProviderId
name: string
creator: (options: any) => ProviderV3
supportsImageGeneration: boolean
}
/**
* 基础 Providers 定义
* 作为唯一数据源,避免重复维护
*/
export const baseProviders = [
{
id: 'openai',
name: 'OpenAI',
creator: createOpenAI,
supportsImageGeneration: true
},
{
id: 'openai-chat',
name: 'OpenAI Chat',
creator: (options: OpenAIProviderSettings) => {
const provider = createOpenAI(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
creator: createOpenAICompatible,
supportsImageGeneration: true
},
{
id: 'anthropic',
name: 'Anthropic',
creator: createAnthropic,
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
creator: createGoogleGenerativeAI,
supportsImageGeneration: true
},
{
id: 'xai',
name: 'xAI (Grok)',
creator: (options: Parameters<typeof createXai>[0]) => {
const provider = createXai(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
creator: (options: AzureOpenAIProviderSettings) => {
const provider = createAzure(options)
return customProvider({
fallbackProvider: {
// Cherry's "azure" path is the chat/deployment-based variant.
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'azure-responses',
name: 'Azure OpenAI Responses',
creator: (options: AzureOpenAIProviderSettings) => {
const provider = createAzure(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
},
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
creator: createDeepSeek,
supportsImageGeneration: false
},
{
id: 'openrouter',
name: 'OpenRouter',
creator: (options?: OpenRouterProviderSettings) => {
const provider = createOpenRouter(options)
return wrapProvider({ provider, languageModelMiddleware: [] })
},
supportsImageGeneration: true
},
{
id: 'cherryin',
name: 'CherryIN',
creator: createCherryIn,
supportsImageGeneration: true
},
{
id: 'cherryin-chat',
name: 'CherryIN Chat',
creator: (options: CherryInProviderSettings) => {
const provider = createCherryIn(options)
return customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
},
supportsImageGeneration: true
}
] as const satisfies BaseProvider[]
/**
* 用户自定义 Provider ID Schema
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
*/
export const customProviderIdSchema = z
.string()
.min(1)
.refine((id) => !baseProviderIds.includes(id as any), {
message: 'Custom provider ID cannot conflict with base provider IDs'
})
/**
* Provider ID Schema - 支持基础和自定义
*/
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
/**
* Provider 配置 Schema
* 用于Provider的配置验证
*/
export const providerConfigSchema = z
.object({
id: customProviderIdSchema, // 只允许自定义ID
name: z.string().min(1),
creator: z
.function({
input: z.any(),
output: z.any()
})
.optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional(),
aliases: z.array(z.string()).optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
/**
* Provider ID 类型 - 基于 zod schema 推导
*/
export type ProviderId = z.infer<typeof providerIdSchema>
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
/**
* Provider 配置类型
*/
export type ProviderConfig = z.infer<typeof providerConfigSchema>
/**
* 兼容性类型别名
* @deprecated 使用 ProviderConfig 替代
*/
export type DynamicProviderRegistration = ProviderConfig

View File

@@ -1,101 +0,0 @@
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
import { type XaiProviderSettings } from '@ai-sdk/xai'
import type {
EmbeddingModel,
EmbeddingModelUsage,
ImageModel,
ImageModelUsage,
LanguageModel,
LanguageModelUsage,
SpeechModel,
TranscriptionModel
} from 'ai'
// 导入基于 Zod 的 ProviderId 类型
import { type ProviderId as ZodProviderId } from './schemas'
export interface ExtensibleProviderSettingsMap {
// 基础的静态providers
openai: OpenAIProviderSettings
'openai-responses': OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
deepseek: DeepSeekProviderSettings
}
// 动态扩展的provider类型注册表
export interface DynamicProviderRegistry {
[key: string]: any
}
// 合并基础和动态provider类型
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
// 错误类型
export class ProviderError extends Error {
constructor(
message: string,
public providerId: string,
public code?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderError'
}
}
// 动态ProviderId类型 - 基于 Zod Schema支持运行时扩展和验证
export type ProviderId = ZodProviderId
export interface ProviderTypeRegistrar {
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
getProviderSettings<T extends string>(providerId: T): any
}
// 重新导出所有类型供外部使用
export type {
AnthropicProviderSettings,
AzureOpenAIProviderSettings,
DeepSeekProviderSettings,
GoogleGenerativeAIProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
XaiProviderSettings
}
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel
export type AiSdkProvider = ProviderV2 | ProviderV3
export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
export const METHOD_MAP = {
text: 'languageModel',
image: 'imageModel',
embedding: 'embeddingModel',
transcription: 'transcriptionModel',
speech: 'speechModel'
} as const satisfies Record<AiSdkModelType, keyof ProviderV3>
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV3>
export type AiSdkModelReturnMap = {
text: LanguageModel
image: ImageModel
embedding: EmbeddingModel
transcription: TranscriptionModel
speech: SpeechModel
}
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]

View File

@@ -0,0 +1,222 @@
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
import type {
EmbeddingModel,
EmbeddingModelUsage,
ImageModel,
ImageModelUsage,
LanguageModel,
LanguageModelUsage,
SpeechModel,
TranscriptionModel
} from 'ai'
import type { coreExtensions } from '../core/initialization'
import type { ProviderExtension } from '../core/ProviderExtension'
import type { ToolFactoryMap } from './toolFactory'
// ============================================================================
// Type Utilities
// ============================================================================
/**
* 提取对象类型中的字符串键
* @example StringKeys<{ foo: 1, 0: 2 }> = 'foo'
*/
export type StringKeys<T> = Extract<keyof T, string>
/** 从 coreExtensions 自动提取的 Provider ID literal union */
export type RegisteredProviderId = StringKeys<CoreProviderSettingsMap>
/** 允许已注册 ID有自动补全和任意字符串动态 provider */
export type ProviderId = RegisteredProviderId | (string & {})
// 错误类型
export class ProviderError extends Error {
constructor(
message: string,
public providerId: string,
public code?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderError'
}
}
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel
export type AiSdkProvider = ProviderV2 | ProviderV3
export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
const METHOD_MAP = {
text: 'languageModel',
image: 'imageModel',
embedding: 'embeddingModel',
transcription: 'transcriptionModel',
speech: 'speechModel'
} as const satisfies Record<AiSdkModelType, keyof ProviderV3>
type AiSdkModelReturnMap = {
text: LanguageModel
image: ImageModel
embedding: EmbeddingModel
transcription: TranscriptionModel
speech: SpeechModel
}
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
// ============================================================================
// Provider Extension 类型定义
// ============================================================================
/** Provider 变体配置 */
export interface ProviderVariant<TSettings = any, TProvider extends ProviderV3 = ProviderV3> {
suffix: string
name: string
/** 类型安全的模型解析provider.responses(modelId) / provider.chat(modelId) */
resolveModel?: (provider: TProvider, modelId: string) => LanguageModel
/** 替换整个 provider如 azure-anthropic简单方法切换用 resolveModel */
transform?: (baseProvider: TProvider, settings?: TSettings) => ProviderV3
toolFactories?: ToolFactoryMap<TProvider>
}
// ============================================================================
// Provider ID Type Extraction Utilities
// ============================================================================
/**
* Extract all Provider IDs from an extension config
* 保留字面量类型,避免被推断为 string
*/
export type ExtractProviderIds<TConfig> = TConfig extends { name: infer TName }
? TName extends string
?
| TName
| (TConfig extends { aliases: infer TAliases }
? TAliases extends readonly string[]
? TAliases[number]
: never
: never)
| (TConfig extends { variants: infer TVariants }
? TVariants extends readonly any[]
? TVariants[number] extends { suffix: infer TSuffix }
? TSuffix extends string
? `${TName}-${TSuffix}`
: never
: never
: never
: never)
: never
: never
/**
* Extract Provider IDs from a ProviderExtension instance
*/
export type ExtractExtensionIds<T> = T extends { config: infer TConfig } ? ExtractProviderIds<TConfig> : never
/**
* Extract Settings type from a ProviderExtension instance
*
* @example
* ```typescript
* type Settings = ExtractExtensionSettings<typeof OpenAIExtension>
* // => OpenAIProviderSettings
* ```
*/
export type ExtractExtensionSettings<T> = T extends ProviderExtension<infer TSettings, any, any> ? TSettings : never
/**
* Map all Provider IDs from an Extension to its Settings type
*/
export type ExtensionToSettingsMap<T> = T extends ProviderExtension<infer TSettings, any, infer TConfig>
? { [K in ExtractProviderIds<TConfig>]: TSettings }
: never
// ============================================================================
// Provider Settings Map - Auto-extracted from Extensions
// ============================================================================
/**
* Core Provider Settings Map
*/
export type CoreProviderSettingsMap = UnionToIntersection<ExtensionToSettingsMap<(typeof coreExtensions)[number]>>
// 辅助类型:提取所有变体 ID
type ExtractVariantIds<TConfig, TName extends string> = TConfig extends {
variants: readonly { suffix: infer TSuffix extends string }[]
}
? `${TName}-${TSuffix}`
: never
export type ExtensionConfigToIdResolutionMap<TConfig> = TConfig extends { name: infer TName extends string }
? {
readonly [K in
| TName
| (TConfig extends { aliases: readonly (infer TAlias extends string)[] } ? TAlias : never)
| ExtractVariantIds<TConfig, TName>]: K extends ExtractVariantIds<TConfig, TName>
? K // 变体 → 自身
: TName // 基础名和别名 → TName
}
: never
/**
* Provider IDs Map Type with Literal Type Inference
*/
export type UnionToIntersection<U> = (U extends any ? (x: U) => void : never) extends (x: infer I) => void ? I : never
export type { ToolCapability, ToolFactory, ToolFactoryMap, ToolFactoryPatch } from './toolFactory'
// ============================================================================
// Tool Config Type Extraction (from extension declarations via as const)
// ============================================================================
/** Extract a capability's config type from an extension's toolFactories */
export type ExtractToolConfig<TExt, K extends string> = TExt extends {
config: { toolFactories?: { [P in K]?: (provider: any) => (config: infer C) => any } }
}
? C
: never
/** Extract config from variant-level toolFactories (e.g., openai-chat) */
type ExtractVariantToolConfig<TExt, K extends string> = TExt extends {
config: {
name: infer TName extends string
variants?: readonly (infer V)[]
}
}
? V extends {
suffix: infer TSuffix extends string
toolFactories?: { [P in K]?: (provider: any) => (config: infer C) => any }
}
? { id: `${TName}-${TSuffix}`; config: C }
: never
: never
/** Extract { [providerId]: ConfigType } map from all extensions for a capability */
export type ExtractToolConfigMap<TExtUnion, K extends string> = UnionToIntersection<
| (TExtUnion extends any
? ExtractToolConfig<TExtUnion, K> extends never
? never
: TExtUnion extends { config: { name: infer TName extends string } }
? { [P in TName]?: ExtractToolConfig<TExtUnion, K> }
: never
: never)
// Variant configs: name-suffix → config
| (TExtUnion extends any
? ExtractVariantToolConfig<TExtUnion, K> extends never
? never
: ExtractVariantToolConfig<TExtUnion, K> extends { id: infer TId extends string; config: infer C }
? { [P in TId]?: C }
: never
: never)
>
/** Auto-extracted from coreExtensions' toolFactories.webSearch declarations */
export type WebSearchToolConfigMap = ExtractToolConfigMap<(typeof coreExtensions)[number], 'webSearch'>

View File

@@ -0,0 +1,32 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import type { ToolSet } from 'ai'
/**
* 跨 provider 的工具能力标识
*
* 各 SDK 的工具键名不同OpenAI: webSearch, Anthropic: webSearch_20250305, Google: googleSearch
* 但表达的是同一种能力。Plugin 通过 ToolCapability 进行跨 provider 统一查找。
*/
export type ToolCapability = 'webSearch' | 'fileSearch' | 'codeExecution' | 'urlContext'
/** 工具工厂返回的 patch描述要合并到 params 的修改 */
export interface ToolFactoryPatch {
tools?: ToolSet
providerOptions?: Record<string, any>
}
/**
* 工具工厂函数 — 形状约束
*
* 使用 `...args: any[]` 而非 `config: Record<string, any>`
* 这样 `as const satisfies` 不会擦除声明时的具体 config 类型。
* `ExtractToolConfig` 可从声明中提取具体 config 类型。
*/
export type ToolFactory<TProvider extends ProviderV3 = ProviderV3> = (
provider: TProvider
) => (...args: any[]) => ToolFactoryPatch
/** Map of ToolCapability keys to their factory functions. */
export type ToolFactoryMap<TProvider extends ProviderV3 = ProviderV3> = {
[K in ToolCapability]?: ToolFactory<TProvider>
}

View File

@@ -5,11 +5,10 @@
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { createMockImageModel, createMockLanguageModel, createMockProviderV3, mockProviderConfigs } from '@test-utils'
import { generateImage, generateText, streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../../../__tests__'
import { globalModelResolver } from '../../models'
import { ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
@@ -29,31 +28,15 @@ vi.mock('ai', async (importOriginal) => {
}
})
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn(),
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
vi.mock('../../models', () => ({
globalModelResolver: {
resolveLanguageModel: vi.fn(),
resolveImageModel: vi.fn()
}
}))
describe('RuntimeExecutor - Model Resolution', () => {
let executor: RuntimeExecutor<'openai'>
let executor: RuntimeExecutor
let mockLanguageModel: LanguageModelV3
let mockImageModel: ImageModelV3
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
specificationVersion: 'v3',
provider: 'openai',
@@ -66,8 +49,14 @@ describe('RuntimeExecutor - Model Resolution', () => {
modelId: 'dall-e-3'
})
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel)
mockProvider = createMockProviderV3({
provider: 'openai',
languageModel: vi.fn(() => mockLanguageModel),
imageModel: vi.fn(() => mockImageModel)
})
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
vi.mocked(generateText).mockResolvedValue({
text: 'Test response',
finishReason: 'stop',
@@ -89,61 +78,16 @@ describe('RuntimeExecutor - Model Resolution', () => {
})
describe('Language Model Resolution (String modelId)', () => {
it('should resolve string modelId using globalModelResolver', async () => {
it('should resolve string modelId through provider', async () => {
await executor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Hello' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
mockProviderConfigs.openai
)
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
})
it('should pass provider settings to model resolver', async () => {
const customExecutor = RuntimeExecutor.create('anthropic', {
apiKey: 'sk-test',
baseURL: 'https://api.anthropic.com'
})
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
await customExecutor.generateText({
model: 'claude-3-5-sonnet',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('claude-3-5-sonnet', 'anthropic', {
apiKey: 'sk-test',
baseURL: 'https://api.anthropic.com'
})
})
it('should resolve traditional format modelId', async () => {
await executor.generateText({
model: 'gpt-4-turbo',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4-turbo', 'openai', expect.any(Object))
})
it('should resolve namespaced format modelId', async () => {
await executor.generateText({
model: 'aihubmix|anthropic|claude-3',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'aihubmix|anthropic|claude-3',
'openai',
expect.any(Object)
)
})
it('should use resolved model for generation', async () => {
it('should pass resolved model to generateText', async () => {
await executor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Hello' }]
@@ -156,13 +100,31 @@ describe('RuntimeExecutor - Model Resolution', () => {
)
})
it('should resolve traditional format modelId', async () => {
await executor.generateText({
model: 'gpt-4-turbo',
messages: [{ role: 'user', content: 'Test' }]
})
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4-turbo')
})
it('should resolve namespaced format modelId', async () => {
await executor.generateText({
model: 'aihubmix|anthropic|claude-3',
messages: [{ role: 'user', content: 'Test' }]
})
expect(mockProvider.languageModel).toHaveBeenCalledWith('aihubmix|anthropic|claude-3')
})
it('should work with streamText', async () => {
await executor.streamText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Stream test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled()
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
model: mockLanguageModel
@@ -184,8 +146,8 @@ describe('RuntimeExecutor - Model Resolution', () => {
messages: [{ role: 'user', content: 'Test' }]
})
// Should NOT call resolver for direct model
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
// Should NOT call provider for direct model
expect(mockProvider.languageModel).not.toHaveBeenCalled()
// Should use the model directly
expect(generateText).toHaveBeenCalledWith(
@@ -195,42 +157,6 @@ describe('RuntimeExecutor - Model Resolution', () => {
)
})
it('should accept V2 model object without validation (plugin engine handles it)', async () => {
const v2Model = {
specificationVersion: 'v2',
provider: 'openai',
modelId: 'gpt-4',
doGenerate: vi.fn()
} as any
// The plugin engine accepts any model object directly without validation
// V3 validation only happens when resolving string modelIds
await expect(
executor.generateText({
model: v2Model,
messages: [{ role: 'user', content: 'Test' }]
})
).resolves.toBeDefined()
})
it('should accept any model object without checking specification version', async () => {
const v2Model = {
specificationVersion: 'v2',
provider: 'custom-provider',
modelId: 'custom-model',
doGenerate: vi.fn()
} as any
// Direct model objects bypass validation
// The executor trusts that plugins/users provide valid models
await expect(
executor.generateText({
model: v2Model,
messages: [{ role: 'user', content: 'Test' }]
})
).resolves.toBeDefined()
})
it('should accept model object with streamText', async () => {
const directModel = createMockLanguageModel({
specificationVersion: 'v3'
@@ -241,7 +167,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
messages: [{ role: 'user', content: 'Stream' }]
})
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
expect(mockProvider.languageModel).not.toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
model: directModel
@@ -251,13 +177,13 @@ describe('RuntimeExecutor - Model Resolution', () => {
})
describe('Image Model Resolution', () => {
it('should resolve string image modelId using globalModelResolver', async () => {
it('should resolve string image modelId through provider', async () => {
await executor.generateImage({
model: 'dall-e-3',
prompt: 'A beautiful sunset'
})
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
})
it('should accept direct ImageModelV3 object', async () => {
@@ -272,7 +198,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
prompt: 'Test image'
})
expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled()
expect(mockProvider.imageModel).not.toHaveBeenCalled()
expect(generateImage).toHaveBeenCalledWith(
expect.objectContaining({
model: directImageModel
@@ -286,12 +212,13 @@ describe('RuntimeExecutor - Model Resolution', () => {
prompt: 'Namespaced image'
})
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai')
expect(mockProvider.imageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3')
})
it('should throw ImageModelResolutionError on resolution failure', async () => {
const resolutionError = new Error('Model not found')
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError)
mockProvider.imageModel.mockImplementation(() => {
throw new Error('Model not found')
})
await expect(
executor.generateImage({
@@ -302,7 +229,9 @@ describe('RuntimeExecutor - Model Resolution', () => {
})
it('should include modelId and providerId in ImageModelResolutionError', async () => {
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found'))
mockProvider.imageModel.mockImplementation(() => {
throw new Error('Not found')
})
try {
await executor.generateImage({
@@ -337,101 +266,70 @@ describe('RuntimeExecutor - Model Resolution', () => {
describe('Provider-Specific Model Resolution', () => {
it('should resolve models for OpenAI provider', async () => {
const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
const openaiModel = createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' })
const openaiProvider = createMockProviderV3({
provider: 'openai',
languageModel: vi.fn(() => openaiModel)
})
const openaiExecutor = RuntimeExecutor.create('openai', openaiProvider, mockProviderConfigs.openai)
await openaiExecutor.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object))
expect(openaiProvider.languageModel).toHaveBeenCalledWith('gpt-4')
})
it('should resolve models for Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet' })
const anthropicProvider = createMockProviderV3({
provider: 'anthropic',
languageModel: vi.fn(() => anthropicModel)
})
const anthropicExecutor = RuntimeExecutor.create('anthropic', anthropicProvider, mockProviderConfigs.anthropic)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'claude-3-5-sonnet',
'anthropic',
expect.any(Object)
)
expect(anthropicProvider.languageModel).toHaveBeenCalledWith('claude-3-5-sonnet')
})
it('should resolve models for Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash' })
const googleProvider = createMockProviderV3({
provider: 'google',
languageModel: vi.fn(() => googleModel)
})
const googleExecutor = RuntimeExecutor.create('google', googleProvider, mockProviderConfigs.google)
await googleExecutor.generateText({
model: 'gemini-2.0-flash',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gemini-2.0-flash',
'google',
expect.any(Object)
)
expect(googleProvider.languageModel).toHaveBeenCalledWith('gemini-2.0-flash')
})
it('should resolve models for OpenAI-compatible provider', async () => {
const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible'])
const compatModel = createMockLanguageModel({ provider: 'openai-compatible', modelId: 'custom-model' })
const compatProvider = createMockProviderV3({
provider: 'openai-compatible',
languageModel: vi.fn(() => compatModel)
})
const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(
compatProvider,
mockProviderConfigs['openai-compatible']
)
await compatibleExecutor.generateText({
model: 'custom-model',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'custom-model',
'openai-compatible',
expect.any(Object)
)
})
})
describe('OpenAI Mode Handling', () => {
it('should pass mode setting to model resolver', async () => {
const executorWithMode = RuntimeExecutor.create('openai', {
...mockProviderConfigs.openai,
mode: 'chat'
})
await executorWithMode.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.objectContaining({
mode: 'chat'
})
)
})
it('should handle responses mode', async () => {
const executorWithMode = RuntimeExecutor.create('openai', {
...mockProviderConfigs.openai,
mode: 'responses'
})
await executorWithMode.generateText({
model: 'gpt-4',
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
'gpt-4',
'openai',
expect.objectContaining({
mode: 'responses'
})
)
expect(compatProvider.languageModel).toHaveBeenCalledWith('custom-model')
})
})
@@ -442,11 +340,13 @@ describe('RuntimeExecutor - Model Resolution', () => {
messages: [{ role: 'user', content: 'Test' }]
})
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object))
expect(mockProvider.languageModel).toHaveBeenCalledWith('')
})
it('should handle model resolution errors gracefully', async () => {
vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found'))
mockProvider.languageModel.mockImplementation(() => {
throw new Error('Model not found')
})
await expect(
executor.generateText({
@@ -465,7 +365,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
await Promise.all(promises)
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3)
expect(mockProvider.languageModel).toHaveBeenCalledTimes(3)
})
it('should accept model object even without specificationVersion', async () => {
@@ -476,7 +376,6 @@ describe('RuntimeExecutor - Model Resolution', () => {
} as any
// Plugin engine doesn't validate direct model objects
// It's the user's responsibility to provide valid models
await expect(
executor.generateText({
model: invalidModel,
@@ -492,7 +391,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
specificationVersion: 'v3'
})
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model)
mockProvider.languageModel.mockReturnValue(v3Model)
await executor.generateText({
model: 'gpt-4',
@@ -516,7 +415,6 @@ describe('RuntimeExecutor - Model Resolution', () => {
} as any
// Direct models bypass validation in the plugin engine
// Only resolved models (from string IDs) are validated
await expect(
executor.generateText({
model: v1Model,

View File

@@ -1,52 +1,56 @@
import type { ImageModelV3 } from '@ai-sdk/provider'
import { createMockImageModel, createMockProviderV3 } from '@test-utils'
import { generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { type AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock dependencies
vi.mock('ai', () => ({
experimental_generateImage: vi.fn(),
generateImage: vi.fn(),
jsonSchema: vi.fn((schema) => schema),
NoImageGeneratedError: class NoImageGeneratedError extends Error {
static isInstance = vi.fn()
constructor() {
super('No image generated')
this.name = 'NoImageGeneratedError'
vi.mock('ai', async (importOriginal) => {
const actual = (await importOriginal()) as Record<string, unknown>
return {
...actual,
experimental_generateImage: vi.fn(),
generateImage: vi.fn(),
jsonSchema: vi.fn((schema) => schema),
NoImageGeneratedError: class NoImageGeneratedError extends Error {
static isInstance = vi.fn()
constructor() {
super('No image generated')
this.name = 'NoImageGeneratedError'
}
}
}
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
imageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
})
describe('RuntimeExecutor.generateImage', () => {
let executor: RuntimeExecutor<'openai'>
let executor: RuntimeExecutor
let mockImageModel: ImageModelV3
let mockProvider: any
let mockGenerateImageResult: any
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks()
// Create executor instance
executor = RuntimeExecutor.create('openai', {
apiKey: 'test-key'
})
// Mock image model
mockImageModel = {
mockImageModel = createMockImageModel({
modelId: 'dall-e-3',
provider: 'openai'
} as ImageModelV3
})
// Create mock provider with imageModel as a spy
mockProvider = createMockProviderV3({
provider: 'openai',
imageModel: vi.fn(() => mockImageModel)
})
// Create executor instance
executor = RuntimeExecutor.create('openai', mockProvider, {
apiKey: 'test-key'
})
// Mock generateImage result
mockGenerateImageResult = {
@@ -71,8 +75,6 @@ describe('RuntimeExecutor.generateImage', () => {
responses: []
}
// Setup mocks to avoid "No providers registered" error
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
})
@@ -80,7 +82,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should generate a single image with minimal parameters', async () => {
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@@ -96,7 +98,8 @@ describe('RuntimeExecutor.generateImage', () => {
prompt: 'A beautiful landscape'
})
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
// Pre-created model is used directly, provider.imageModel is not called
expect(mockProvider.imageModel).not.toHaveBeenCalled()
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful landscape'
@@ -224,6 +227,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@@ -269,6 +273,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@@ -309,6 +314,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@@ -325,7 +331,8 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Error handling', () => {
it('should handle model creation errors', async () => {
const modelError = new Error('Failed to get image model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
// Since mockProvider.imageModel is already a vi.fn() spy, we can mock it directly
mockProvider.imageModel.mockImplementation(() => {
throw modelError
})
@@ -336,7 +343,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should handle ImageModelResolutionError correctly', async () => {
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
mockProvider.imageModel.mockImplementation(() => {
throw resolutionError
})
@@ -353,7 +360,7 @@ describe('RuntimeExecutor.generateImage', () => {
it('should handle ImageModelResolutionError without provider', async () => {
const resolutionError = new ImageModelResolutionError('unknown-model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
mockProvider.imageModel.mockImplementation(() => {
throw resolutionError
})
@@ -398,6 +405,7 @@ describe('RuntimeExecutor.generateImage', () => {
const executorWithPlugin = RuntimeExecutor.create(
'openai',
mockProvider,
{
apiKey: 'test-key'
},
@@ -436,23 +444,43 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Multiple providers support', () => {
it('should work with different providers', async () => {
const googleExecutor = RuntimeExecutor.create('google', {
const googleImageModel = createMockImageModel({
provider: 'google',
modelId: 'imagen-3.0-generate-002'
})
const googleProvider = createMockProviderV3({
provider: 'google',
imageModel: vi.fn(() => googleImageModel)
})
const googleExecutor = RuntimeExecutor.create('google', googleProvider, {
apiKey: 'google-key'
})
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
expect(googleProvider.imageModel).toHaveBeenCalledWith('imagen-3.0-generate-002')
})
it('should support xAI Grok image models', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', {
const xaiImageModel = createMockImageModel({
provider: 'xai',
modelId: 'grok-2-image'
})
const xaiProvider = createMockProviderV3({
provider: 'xai',
imageModel: vi.fn(() => xaiImageModel)
})
const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, {
apiKey: 'xai-key'
})
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
expect(xaiProvider.imageModel).toHaveBeenCalledWith('grok-2-image')
})
})

View File

@@ -3,18 +3,18 @@
* Tests non-streaming text generation across all providers with various parameters
*/
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockLanguageModel,
createMockProviderV3,
mockCompleteResponses,
mockProviderConfigs,
testMessages,
testTools
} from '../../../__tests__'
} from '@test-utils'
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
@@ -26,28 +26,26 @@ vi.mock('ai', async (importOriginal) => {
}
})
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateText', () => {
let executor: RuntimeExecutor<'openai'>
let executor: RuntimeExecutor
let mockLanguageModel: any
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
mockProvider = createMockProviderV3({
provider: 'openai',
languageModel: vi.fn(() => mockLanguageModel)
})
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
})
@@ -231,75 +229,87 @@ describe('RuntimeExecutor.generateText', () => {
describe('Multiple Providers', () => {
it('should work with Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
const anthropicProvider = createMockProviderV3({
provider: 'anthropic',
languageModel: vi.fn(() => anthropicModel)
})
const anthropicExecutor = RuntimeExecutor.create('anthropic', anthropicProvider, mockProviderConfigs.anthropic)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet-20241022',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
expect(anthropicProvider.languageModel).toHaveBeenCalledWith('claude-3-5-sonnet-20241022')
})
it('should work with Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
const googleProvider = createMockProviderV3({
provider: 'google',
languageModel: vi.fn(() => googleModel)
})
const googleExecutor = RuntimeExecutor.create('google', googleProvider, mockProviderConfigs.google)
await googleExecutor.generateText({
model: 'gemini-2.0-flash-exp',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
expect(googleProvider.languageModel).toHaveBeenCalledWith('gemini-2.0-flash-exp')
})
it('should work with xAI provider', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
const xaiModel = createMockLanguageModel({
provider: 'xai',
modelId: 'grok-2-latest'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
const xaiProvider = createMockProviderV3({
provider: 'xai',
languageModel: vi.fn(() => xaiModel)
})
const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, mockProviderConfigs.xai)
await xaiExecutor.generateText({
model: 'grok-2-latest',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
expect(xaiProvider.languageModel).toHaveBeenCalledWith('grok-2-latest')
})
it('should work with DeepSeek provider', async () => {
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
const deepseekModel = createMockLanguageModel({
provider: 'deepseek',
modelId: 'deepseek-chat'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
const deepseekProvider = createMockProviderV3({
provider: 'deepseek',
languageModel: vi.fn(() => deepseekModel)
})
const deepseekExecutor = RuntimeExecutor.create('deepseek', deepseekProvider, mockProviderConfigs.deepseek)
await deepseekExecutor.generateText({
model: 'deepseek-chat',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
expect(deepseekProvider.languageModel).toHaveBeenCalledWith('deepseek-chat')
})
})
@@ -325,7 +335,9 @@ describe('RuntimeExecutor.generateText', () => {
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
testPlugin
])
const result = await executorWithPlugin.generateText({
model: 'gpt-4',
@@ -364,7 +376,10 @@ describe('RuntimeExecutor.generateText', () => {
})
}
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
const executorWithPlugins = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
plugin1,
plugin2
])
await executorWithPlugins.generateText({
model: 'gpt-4',
@@ -404,7 +419,9 @@ describe('RuntimeExecutor.generateText', () => {
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
errorPlugin
])
await expect(
executorWithPlugin.generateText({
@@ -425,7 +442,7 @@ describe('RuntimeExecutor.generateText', () => {
it('should handle model not found error', async () => {
const error = new Error('Model not found: invalid-model')
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
mockProvider.languageModel.mockImplementationOnce(() => {
throw error
})

View File

@@ -5,10 +5,10 @@
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import { createMockImageModel, createMockLanguageModel, createMockMiddleware } from '@test-utils'
import { wrapLanguageModel } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createMockImageModel, createMockLanguageModel, createMockMiddleware } from '../../../__tests__'
import { ModelResolutionError, RecursiveDepthError } from '../../errors'
import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins'
import { PluginEngine } from '../pluginEngine'

View File

@@ -3,12 +3,17 @@
* Tests streaming text generation across all providers with various parameters
*/
import {
collectStreamChunks,
createMockLanguageModel,
createMockProviderV3,
mockProviderConfigs,
testMessages
} from '@test-utils'
import { streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
@@ -20,28 +25,25 @@ vi.mock('ai', async (importOriginal) => {
}
})
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.streamText', () => {
let executor: RuntimeExecutor<'openai'>
let executor: RuntimeExecutor
let mockLanguageModel: any
let mockProvider: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
mockProvider = createMockProviderV3({
provider: 'openai',
languageModel: () => mockLanguageModel
})
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
})
describe('Basic Functionality', () => {
@@ -416,7 +418,9 @@ describe('RuntimeExecutor.streamText', () => {
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
testPlugin
])
const mockStream = {
textStream: (async function* () {
@@ -509,7 +513,9 @@ describe('RuntimeExecutor.streamText', () => {
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
errorPlugin
])
await expect(
executorWithPlugin.streamText({
@@ -519,11 +525,12 @@ describe('RuntimeExecutor.streamText', () => {
).rejects.toThrow('Stream error')
// onError receives the original error and context with core fields
// context.model is the resolved LanguageModel (updated after resolveModel hook)
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
model: 'gpt-4'
model: expect.objectContaining({ modelId: 'gpt-4' })
})
)
})

View File

@@ -2,34 +2,54 @@
* 运行时执行器
* 专注于插件化的AI调用处理
*/
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
import type { ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { LanguageModel } from 'ai'
import { generateImage as _generateImage, generateText as _generateText, streamText as _streamText } from 'ai'
import {
createProviderRegistry,
embedMany as _embedMany,
generateImage as _generateImage,
generateText as _generateText,
streamText as _streamText
} from 'ai'
import { globalModelResolver } from '../models'
import { type ModelConfig } from '../models/types'
import { isV3Model } from '../models/utils'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { type ProviderId } from '../providers'
import { type AiPlugin, definePlugin } from '../plugins'
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import type { generateImageParams, generateTextParams, RuntimeConfig, streamTextParams } from './types'
import type {
EmbedManyParams,
EmbedManyResult,
generateImageParams,
generateImageResult,
generateTextParams,
RuntimeConfig,
streamTextParams
} from './types'
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
export class RuntimeExecutor<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
> {
public pluginEngine: PluginEngine<T>
// private options: ProviderSettingsMap[T]
private config: RuntimeConfig<T>
private config: RuntimeConfig<TSettingsMap, T>
private registry: ReturnType<typeof createProviderRegistry>
constructor(config: RuntimeConfig<T>) {
// if (!isProviderSupported(config.providerId)) {
// throw new Error(`Unsupported provider: ${config.providerId}`)
// }
// 存储options供后续使用
// this.options = config.options
constructor(config: RuntimeConfig<TSettingsMap, T>) {
this.config = config
// 创建插件客户端
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
// Some v3 providers (e.g., @openrouter/ai-sdk-provider) expose textEmbeddingModel
// but not embeddingModel. Patch for AI SDK registry compatibility.
const provider = config.provider
if (!provider.embeddingModel && provider.textEmbeddingModel) {
provider.embeddingModel = (modelId: string) => provider.textEmbeddingModel!(modelId)
}
this.registry = createProviderRegistry({
[config.providerId]: provider
})
}
private createResolveModelPlugin() {
@@ -58,8 +78,9 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
private createConfigureContextPlugin() {
return definePlugin({
name: '_internal_configureContext',
configureContext: async (context: AiRequestContext) => {
context.executor = this
configureContext: async () => {
// Placeholder for future context configuration
// Previously set executor and baseProvider, now handled by registry
}
})
}
@@ -120,7 +141,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
/**
* 生成图像
*/
generateImage(params: generateImageParams): Promise<ReturnType<typeof _generateImage>> {
async generateImage(params: generateImageParams): Promise<generateImageResult> {
try {
const { model } = params
@@ -148,19 +169,39 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}
}
/**
* 批量嵌入文本
*/
async embedMany(params: EmbedManyParams): Promise<EmbedManyResult> {
const { model: modelOrId, ...options } = params
// 解析 embedding 模型
const embeddingModel =
typeof modelOrId === 'string'
? this.registry.embeddingModel(`${this.config.providerId}:${modelOrId}` as `${string}:${string}`)
: modelOrId
return _embedMany({
model: embeddingModel,
...options
})
}
// === 辅助方法 ===
/**
* 解析模型:将字符串 modelId 解析为 model 对象
* middleware 的应用由 pluginEngine 统一处理
*
* 对于有 modelResolver 的配置(如 xAI responses, OpenAI chat
* 使用 resolver 函数解析模型,而不是通过 registry.languageModel()。
* resolver 在 extension 声明处类型安全地捕获了具体 provider 方法。
*/
private async resolveModel(modelOrId: LanguageModel): Promise<LanguageModelV3> {
if (typeof modelOrId === 'string') {
return await globalModelResolver.resolveLanguageModel(
modelOrId,
this.config.providerId,
this.config.providerSettings
)
if (this.config.modelResolver) {
return this.config.modelResolver(modelOrId)
}
return this.registry.languageModel(`${this.config.providerId}:${modelOrId}` as `${string}:${string}`)
} else {
if (!isV3Model(modelOrId)) {
throw new Error(
@@ -178,13 +219,8 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
private async resolveImageModel(modelOrId: ImageModelV3 | string): Promise<ImageModelV3> {
try {
if (typeof modelOrId === 'string') {
// 字符串modelId使用新的ModelResolver解析
return await globalModelResolver.resolveImageModel(
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
this.config.providerId // fallback provider
)
return this.registry.imageModel(`${this.config.providerId}:${modelOrId}` as `${string}:${string}`)
} else {
// 已经是模型,直接返回
return modelOrId
}
} catch (error) {
@@ -201,27 +237,37 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
/**
* 创建执行器 - 支持已知provider的类型安全
*/
static create<T extends ProviderId>(
static create<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(
providerId: T,
options: ModelConfig<T>['providerSettings'],
plugins?: AiPlugin[]
): RuntimeExecutor<T> {
return new RuntimeExecutor({
provider: ProviderV3,
options: TSettingsMap[T],
plugins?: AiPlugin[],
modelResolver?: (modelId: string) => any
): RuntimeExecutor<TSettingsMap, T> {
return new RuntimeExecutor<TSettingsMap, T>({
providerId,
provider,
providerSettings: options,
plugins
plugins,
modelResolver
})
}
/**
* 创建OpenAI Compatible执行器
* ✅ Now accepts provider instance directly
*/
static createOpenAICompatible(
options: ModelConfig<'openai-compatible'>['providerSettings'],
provider: ProviderV3, // ✅ Accept provider instance
options: CoreProviderSettingsMap['openai-compatible'],
plugins: AiPlugin[] = []
): RuntimeExecutor<'openai-compatible'> {
return new RuntimeExecutor({
): RuntimeExecutor<CoreProviderSettingsMap, 'openai-compatible'> {
return new RuntimeExecutor<CoreProviderSettingsMap, 'openai-compatible'>({
providerId: 'openai-compatible',
provider, // ✅ Pass provider to config
providerSettings: options,
plugins
})

View File

@@ -7,76 +7,113 @@
export { RuntimeExecutor } from './executor'
// 导出类型
export type { RuntimeConfig } from './types'
export type { EmbedManyParams, EmbedManyResult, RuntimeConfig } from './types'
// === 便捷工厂函数 ===
import { type AiPlugin } from '../plugins'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
import { extensionRegistry } from '../providers'
import { type CoreProviderSettingsMap, type StringKeys } from '../providers/types'
import { RuntimeExecutor } from './executor'
/**
* 创建运行时执行器 - 支持类型安全的已知provider
* 自动确保 provider 已初始化
*/
export function createExecutor<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
plugins?: AiPlugin[]
): RuntimeExecutor<T> {
return RuntimeExecutor.create(providerId, options, plugins)
}
export async function createExecutor<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(providerId: T, options: TSettingsMap[T], plugins?: AiPlugin[]): Promise<RuntimeExecutor<TSettingsMap, T>> {
if (!extensionRegistry.has(providerId)) {
throw new Error(`Provider extension "${providerId}" not registered`)
}
/**
* 创建OpenAI Compatible执行器
*/
export function createOpenAICompatibleExecutor(
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
plugins: AiPlugin[] = []
): RuntimeExecutor<'openai-compatible'> {
return RuntimeExecutor.createOpenAICompatible(options, plugins)
}
const provider = await extensionRegistry.createProvider(providerId, options || {})
// === 直接调用API无需创建executor实例===
// Extract model resolver from variant's resolveModel declaration (type-safe at extension level)
const resolver = extensionRegistry.getModelResolver(providerId as string)
const modelResolver = resolver ? (modelId: string) => resolver(provider, modelId) : undefined
return RuntimeExecutor.create<TSettingsMap, T>(providerId, provider, options, plugins, modelResolver)
}
/**
* 直接流式文本生成
*/
export async function streamText<T extends ProviderId>(
export async function streamText<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
options: TSettingsMap[T],
params: Parameters<RuntimeExecutor<TSettingsMap, T>['streamText']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
const executor = createExecutor(providerId, options, plugins)
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['streamText']>> {
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
return executor.streamText(params)
}
/**
* 直接生成文本
*/
export async function generateText<T extends ProviderId>(
export async function generateText<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
options: TSettingsMap[T],
params: Parameters<RuntimeExecutor<TSettingsMap, T>['generateText']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
const executor = createExecutor(providerId, options, plugins)
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['generateText']>> {
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
return executor.generateText(params)
}
/**
* 直接生成图像 - 支持middlewares
*/
export async function generateImage<T extends ProviderId>(
export async function generateImage<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(
providerId: T,
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
options: TSettingsMap[T],
params: Parameters<RuntimeExecutor<TSettingsMap, T>['generateImage']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
const executor = createExecutor(providerId, options, plugins)
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['generateImage']>> {
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
return executor.generateImage(params)
}
/**
* 直接批量嵌入文本
* AI SDK v6 只有 embedMany没有 embed
*/
export async function embedMany<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
>(
providerId: T,
options: TSettingsMap[T],
params: Parameters<RuntimeExecutor<TSettingsMap, T>['embedMany']>[0],
plugins?: AiPlugin[]
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['embedMany']>> {
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
return executor.embedMany(params)
}
/**
* 创建 OpenAI Compatible 执行器
*/
export async function createOpenAICompatibleExecutor(
options: CoreProviderSettingsMap['openai-compatible'],
plugins?: AiPlugin[]
): Promise<RuntimeExecutor<CoreProviderSettingsMap, 'openai-compatible'>> {
const provider = await extensionRegistry.createProvider('openai-compatible', options)
return RuntimeExecutor.createOpenAICompatible(provider, options, plugins)
}
// === Agent 功能预留 ===
// 未来将在 ../agents/ 文件夹中添加:
// - AgentExecutor.ts

View File

@@ -14,13 +14,13 @@ import {
type StreamTextParams,
type StreamTextResult
} from '../plugins'
import { type ProviderId } from '../providers/types'
import type { RegisteredProviderId } from '../providers'
/**
* 插件增强的 AI 客户端
* 专注于插件处理不暴露用户API
*/
export class PluginEngine<T extends ProviderId = ProviderId> {
export class PluginEngine<T extends string = RegisteredProviderId> {
/**
* Plugin storage with explicit any/any generics
*
@@ -36,7 +36,6 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
constructor(
private readonly providerId: T,
// private readonly options: ProviderSettingsMap[T],
plugins: AiPlugin[] = []
) {
this.basePlugins = plugins
@@ -352,6 +351,9 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
throw new ModelResolutionError(modelId, this.providerId)
}
resolvedModel = resolved
// 更新 context.model 为已解析的 LanguageModel 实例
// 后续 plugin如 providerToolPlugin需要 model.provider 来识别聚合供应商的协议
context.model = resolvedModel
}
if (!resolvedModel) {
@@ -359,7 +361,10 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// 2.5 应用 context.middlewares 到模型
if (typeof model !== 'string' && context.middlewares && context.middlewares.length > 0) {
if (context.middlewares && context.middlewares.length > 0) {
if (typeof resolvedModel === 'string') {
throw new Error(`Model must be resolved before applying middlewares, got string: ${resolvedModel}`)
}
resolvedModel = wrapLanguageModel({
model: resolvedModel as LanguageModelV3,
middleware: context.middlewares

View File

@@ -1,24 +1,43 @@
/**
* Runtime 层类型定义
*/
import type { ImageModelV3 } from '@ai-sdk/provider'
import type { generateImage, generateText, streamText } from 'ai'
import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { embedMany, generateImage, generateText, streamText } from 'ai'
import { type ModelConfig } from '../models/types'
import { type AiPlugin } from '../plugins'
import { type ProviderId } from '../providers/types'
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
/**
* 运行时执行器配置
*
* @typeParam TSettingsMap - Provider Settings Map默认 CoreProviderSettingsMap
* @typeParam T - Provider ID 类型(从 TSettingsMap 的键推断)
*/
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
export interface RuntimeConfig<
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
> {
providerId: T
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
provider: ProviderV3
providerSettings: TSettingsMap[T]
plugins?: AiPlugin[]
/**
* 模型解析函数
* 从 variant 的 resolveModel 声明中提取(类型安全在 extension 声明处保证)。
* 不提供时使用 AI SDK 默认的 provider.languageModel()。
*/
modelResolver?: (modelId: string) => any
}
export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'model'> & {
model: string | ImageModelV3
}
export type generateImageResult = Awaited<ReturnType<typeof generateImage>>
export type generateTextParams = Parameters<typeof generateText>[0]
export type streamTextParams = Parameters<typeof streamText>[0]
// Embedding types (AI SDK v6 only has embedMany, no embed)
export type EmbedManyParams = Omit<Parameters<typeof embedMany>[0], 'model'> & {
model: string | EmbeddingModelV3
}
export type EmbedManyResult = Awaited<ReturnType<typeof embedMany>>

View File

@@ -0,0 +1 @@
export type PlainObject = Record<string, any>

View File

@@ -0,0 +1,17 @@
import type { PlainObject } from '../types'
export const isPlainObject = (value: unknown): value is PlainObject => {
return typeof value === 'object' && value !== null && !Array.isArray(value)
}
export function deepMergeObjects<T extends PlainObject>(target: T, source: PlainObject): T {
const result: PlainObject = { ...target }
Object.entries(source).forEach(([key, value]) => {
if (isPlainObject(value) && isPlainObject(result[key])) {
result[key] = deepMergeObjects(result[key], value)
} else {
result[key] = value
}
})
return result as T
}

View File

@@ -6,46 +6,38 @@
// 导入内部使用的类和函数
// ==================== 主要用户接口 ====================
export {
createExecutor,
createOpenAICompatibleExecutor,
generateImage,
generateText,
streamText
} from './core/runtime'
export { createExecutor, embedMany, generateImage, generateText, streamText } from './core/runtime'
// ==================== Embedding 类型 ====================
export type { EmbedManyParams, EmbedManyResult } from './core/runtime'
// ==================== 高级API ====================
export { isV2Model, isV3Model, globalModelResolver as modelResolver } from './core/models'
export { isV2Model, isV3Model } from './core/models'
// ==================== 插件系统 ====================
export type {
AiPlugin,
AiRequestContext,
AiRequestMetadata,
GenerateTextParams,
GenerateTextResult,
HookResult,
PluginManagerConfig,
RecursiveCallFn,
StreamTextParams,
StreamTextResult
} from './core/plugins'
export { createContext, definePlugin, PluginManager } from './core/plugins'
export { definePlugin } from './core/plugins'
export { PluginEngine } from './core/runtime/pluginEngine'
// ==================== 类型工具 ====================
export type { AiSdkModel } from './core/providers'
// ==================== 选项 ====================
export {
createAnthropicOptions,
createGoogleOptions,
createOpenAIOptions,
type ExtractProviderOptions,
mergeProviderOptions,
type ProviderOptionsMap,
type TypedProviderOptions
} from './core/options'
export type {
AiSdkModel,
ExtractToolConfig,
ExtractToolConfigMap,
ProviderId,
ToolCapability,
ToolFactory,
ToolFactoryMap,
ToolFactoryPatch,
WebSearchToolConfigMap
} from './core/providers'
// ==================== 错误处理 ====================
export {
@@ -53,11 +45,6 @@ export {
ModelResolutionError,
ParameterValidationError,
PluginExecutionError,
ProviderConfigError,
RecursiveDepthError,
TemplateLoadError
} from './core/errors'
// ==================== 包信息 ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'

View File

@@ -1,2 +1,2 @@
// 重新导出插件类型
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
export type { AiPlugin, AiRequestContext } from './core/plugins/types'

View File

@@ -1,12 +1,12 @@
/**
* Test Utilities
* Helper functions for testing AI Core functionality
* Common Test Utilities
* General-purpose helper functions for testing
*/
import { expect, vi } from 'vitest'
import type { ProviderId } from '../fixtures/mock-providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
import type { ProviderId } from '../mocks/providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../mocks/providers'
/**
* Creates a test provider with streaming support

View File

@@ -10,15 +10,14 @@ import type {
LanguageModelV3Middleware,
ProviderV3
} from '@ai-sdk/provider'
import type { ProviderId } from '@test-utils'
import type { Tool, ToolSet } from 'ai'
import { tool } from 'ai'
import { MockLanguageModelV3 } from 'ai/test'
import { vi } from 'vitest'
import * as z from 'zod'
import type { StreamTextParams, StreamTextResult } from '../../core/plugins'
import type { ProviderId } from '../../core/providers/types'
import type { AiRequestContext } from '../../types'
import type { AiRequestContext, StreamTextParams, StreamTextResult } from '../../src/core/plugins/types'
/**
* Type for partial overrides that allows omitting the model field
@@ -60,7 +59,8 @@ export function createMockContext(overrides?: ContextOverrides): AiRequestContex
isRecursiveCall: false,
recursiveDepth: 0,
maxRecursiveDepth: 10,
extensions: new Map()
extensions: new Map(),
pluginState: {}
}
if (overrides) {

View File

@@ -0,0 +1,351 @@
/**
* Model Test Utilities
* Provides comprehensive mock creators for AI SDK v3 models and related test utilities
*/
import type {
EmbeddingModelV3,
ImageModelV3,
LanguageModelV3,
LanguageModelV3Middleware,
ProviderV3
} from '@ai-sdk/provider'
import type { Tool, ToolSet } from 'ai'
import { tool } from 'ai'
import { MockLanguageModelV3 } from 'ai/test'
import { vi } from 'vitest'
import * as z from 'zod'
import type { StreamTextParams, StreamTextResult } from '../../src/core/plugins'
import type { RegisteredProviderId } from '../../src/core/providers/types'
import type { AiRequestContext } from '../../src/types'
/**
* Type for partial overrides that allows omitting the model field
* The model will be automatically added by createMockContext
*/
type ContextOverrides = Partial<Omit<AiRequestContext<StreamTextParams, StreamTextResult>, 'originalParams'>> & {
originalParams?: Partial<Omit<StreamTextParams, 'model'>> & { model?: StreamTextParams['model'] }
}
/**
* Creates a mock AiRequestContext with type safety
* The model field is automatically added to originalParams if not provided
*
* @example
* ```ts
* const context = createMockContext({
* providerId: 'openai',
* metadata: { requestId: 'test-123' }
* })
* ```
*/
export function createMockContext(overrides?: ContextOverrides): AiRequestContext<StreamTextParams, StreamTextResult> {
const mockModel = new MockLanguageModelV3({
provider: 'test-provider',
modelId: 'test-model'
})
const base: AiRequestContext<StreamTextParams, StreamTextResult> = {
providerId: 'openai' as RegisteredProviderId,
model: mockModel,
originalParams: {
model: mockModel,
messages: [{ role: 'user', content: 'Test message' }]
} as StreamTextParams,
metadata: {},
startTime: Date.now(),
requestId: 'test-request-id',
recursiveCall: vi.fn(),
isRecursiveCall: false,
recursiveDepth: 0,
maxRecursiveDepth: 10,
extensions: new Map(),
pluginState: {}
}
if (overrides) {
// Ensure model is always present in originalParams
const mergedOriginalParams = {
...base.originalParams,
...overrides.originalParams,
model: overrides.originalParams?.model ?? mockModel
}
return {
...base,
...overrides,
originalParams: mergedOriginalParams as StreamTextParams
}
}
return base
}
/**
* Creates a mock embedding model with customizable behavior
* Compliant with AI SDK v3 specification
*
* @example
* ```ts
* const embeddingModel = createMockEmbeddingModel({
* provider: 'openai',
* modelId: 'text-embedding-3-small',
* maxEmbeddingsPerCall: 2048
* })
* ```
*/
export function createMockEmbeddingModel(overrides?: Partial<EmbeddingModelV3>): EmbeddingModelV3 {
return {
specificationVersion: 'v3',
provider: 'mock-provider',
modelId: 'mock-embedding-model',
maxEmbeddingsPerCall: 100,
supportsParallelCalls: true,
doEmbed: vi.fn().mockResolvedValue({
embeddings: [
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.6, 0.7, 0.8, 0.9, 1.0]
],
usage: {
inputTokens: 10,
totalTokens: 10
},
rawResponse: { headers: {} }
}),
...overrides
} as EmbeddingModelV3
}
/**
* Creates a complete mock ProviderV3 with all model types
* Useful for testing provider registration and management
*
* @example
* ```ts
* const provider = createMockProviderV3({
* provider: 'openai',
* languageModel: customLanguageModel,
* imageModel: customImageModel
* })
* ```
*/
export function createMockProviderV3(overrides?: {
provider?: string
languageModel?: (modelId: string) => LanguageModelV3
imageModel?: (modelId: string) => ImageModelV3
embeddingModel?: (modelId: string) => EmbeddingModelV3
}): ProviderV3 {
const defaultLanguageModel = (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
defaultObjectGenerationMode: 'tool',
supportedUrls: {},
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30,
inputTokenDetails: {},
outputTokenDetails: {}
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
yield { type: 'text-delta', textDelta: 'Mock ' }
yield { type: 'text-delta', textDelta: 'streaming ' }
yield { type: 'text-delta', textDelta: 'response' }
yield {
type: 'finish',
finishReason: 'stop',
usage: {
inputTokens: 10,
outputTokens: 15,
totalTokens: 25,
inputTokenDetails: {},
outputTokenDetails: {}
}
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
})
}) as LanguageModelV3
const defaultImageModel = (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
maxImagesPerCall: undefined,
doGenerate: vi.fn().mockResolvedValue({
images: [
{
base64: 'mock-base64-image-data',
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
mimeType: 'image/png'
}
],
warnings: []
})
}) as ImageModelV3
const defaultEmbeddingModel = (modelId: string) =>
({
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
modelId,
maxEmbeddingsPerCall: 100,
supportsParallelCalls: true,
doEmbed: vi.fn().mockResolvedValue({
embeddings: [
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.6, 0.7, 0.8, 0.9, 1.0]
],
usage: {
inputTokens: 10,
totalTokens: 10
},
rawResponse: { headers: {} }
})
}) as EmbeddingModelV3
return {
specificationVersion: 'v3',
provider: overrides?.provider ?? 'mock-provider',
languageModel: vi.fn(overrides?.languageModel ?? defaultLanguageModel),
imageModel: vi.fn(overrides?.imageModel ?? defaultImageModel),
embeddingModel: vi.fn(overrides?.embeddingModel ?? defaultEmbeddingModel)
} as ProviderV3
}
/**
* Creates a mock middleware for testing middleware chains
* Supports both generate and stream wrapping
*
* @example
* ```ts
* const middleware = createMockMiddleware({
* name: 'test-middleware'
* })
* ```
*/
export function createMockMiddleware(): LanguageModelV3Middleware {
return {
specificationVersion: 'v3',
wrapGenerate: vi.fn((doGenerate) => doGenerate),
wrapStream: vi.fn((doStream) => doStream)
}
}
/**
* Creates a type-safe function tool for testing using AI SDK's tool() function
*
* @example
* ```ts
* const weatherTool = createMockTool('getWeather', 'Get current weather')
* ```
*/
export function createMockTool(name: string, description?: string): Tool<{ value?: string }, string> {
return tool({
description: description || `Mock tool: ${name}`,
inputSchema: z.object({
value: z.string().optional()
}),
execute: vi.fn(async () => 'mock result')
})
}
/**
* Creates a provider-defined tool for testing
*/
export function createMockProviderTool(name: string, description?: string): { type: 'provider'; description: string } {
return {
type: 'provider' as const,
description: description || `Mock provider tool: ${name}`
}
}
/**
* Creates a ToolSet with multiple tools
*
* @example
* ```ts
* const tools = createMockToolSet({
* getWeather: 'function',
* searchDatabase: 'function',
* nativeSearch: 'provider'
* })
* ```
*/
export function createMockToolSet(tools: Record<string, 'function' | 'provider'>): ToolSet {
const toolSet: ToolSet = {}
for (const [name, type] of Object.entries(tools)) {
if (type === 'function') {
toolSet[name] = createMockTool(name)
} else {
toolSet[name] = createMockProviderTool(name) as Tool
}
}
return toolSet
}
/**
* Creates mock stream params for testing
*
* @example
* ```ts
* const params = createMockStreamParams({
* messages: [{ role: 'user', content: 'Custom message' }],
* temperature: 0.7
* })
* ```
*/
export function createMockStreamParams(overrides?: Partial<StreamTextParams>): StreamTextParams {
return {
messages: [{ role: 'user', content: 'Test message' }],
...overrides
} as StreamTextParams
}
/**
* Common mock model instances for quick testing
*/
export const mockModels = {
/** Standard language model for general testing */
language: new MockLanguageModelV3({
provider: 'test-provider',
modelId: 'test-model'
}),
/** Mock OpenAI GPT-4 model */
gpt4: new MockLanguageModelV3({
provider: 'openai',
modelId: 'gpt-4'
}),
/** Mock Anthropic Claude model */
claude: new MockLanguageModelV3({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
}),
/** Mock Google Gemini model */
gemini: new MockLanguageModelV3({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
} as const

View File

@@ -0,0 +1,13 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Mocks
export * from './mocks/providers'
export * from './mocks/responses'
// Helpers
export * from './helpers/common'
export * from './helpers/model'
export * from './helpers/provider'

View File

@@ -11,11 +11,15 @@
"noEmitOnError": false,
"outDir": "./dist",
"resolveJsonModule": true,
"rootDir": "./src",
"rootDir": ".",
"skipLibCheck": true,
"strict": true,
"target": "ES2020"
"target": "ES2020",
"paths": {
"@test-utils": ["./test_utils"],
"@test-utils/*": ["./test_utils/*"]
}
},
"exclude": ["node_modules", "dist"],
"include": ["src/**/*"]
"include": ["src/**/*", "test_utils/**/*"]
}

View File

@@ -8,13 +8,14 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url))
export default defineConfig({
test: {
globals: true,
setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')]
setupFiles: [path.resolve(__dirname, './test_utils/setup.ts')]
},
resolve: {
alias: {
'@': path.resolve(__dirname, './src'),
'@test-utils': path.resolve(__dirname, './test_utils'),
// Mock external packages that may not be available in test environment
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts')
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './test_utils/mocks/ai-sdk-provider.ts')
}
},
esbuild: {

View File

@@ -1,196 +0,0 @@
diff --git a/client.js b/client.js
index c2b9cd6e46f9f66f901af259661bc2d2f8b38936..9b6b3af1a6573e1ccaf3a1c5f41b48df198cbbe0 100644
--- a/client.js
+++ b/client.js
@@ -26,7 +26,7 @@ Object.defineProperty(exports, "__esModule", { value: true });
exports.AnthropicVertex = exports.BaseAnthropic = void 0;
const client_1 = require("@anthropic-ai/sdk/client");
const Resources = __importStar(require("@anthropic-ai/sdk/resources/index"));
-const google_auth_library_1 = require("google-auth-library");
+// const google_auth_library_1 = require("google-auth-library");
const env_1 = require("./internal/utils/env.js");
const values_1 = require("./internal/utils/values.js");
const headers_1 = require("./internal/headers.js");
@@ -56,7 +56,7 @@ class AnthropicVertex extends client_1.BaseAnthropic {
throw new Error('No region was given. The client should be instantiated with the `region` option or the `CLOUD_ML_REGION` environment variable should be set.');
}
super({
- baseURL: baseURL || `https://${region}-aiplatform.googleapis.com/v1`,
+ baseURL: baseURL || (region === 'global' ? 'https://aiplatform.googleapis.com/v1' : `https://${region}-aiplatform.googleapis.com/v1`),
...opts,
});
this.messages = makeMessagesResource(this);
@@ -64,22 +64,22 @@ class AnthropicVertex extends client_1.BaseAnthropic {
this.region = region;
this.projectId = projectId;
this.accessToken = opts.accessToken ?? null;
- this._auth =
- opts.googleAuth ?? new google_auth_library_1.GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
- this._authClientPromise = this._auth.getClient();
+ // this._auth =
+ // opts.googleAuth ?? new google_auth_library_1.GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
+ // this._authClientPromise = this._auth.getClient();
}
validateHeaders() {
// auth validation is handled in prepareOptions since it needs to be async
}
- async prepareOptions(options) {
- const authClient = await this._authClientPromise;
- const authHeaders = await authClient.getRequestHeaders();
- const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
- if (!this.projectId && projectId) {
- this.projectId = projectId;
- }
- options.headers = (0, headers_1.buildHeaders)([authHeaders, options.headers]);
- }
+ // async prepareOptions(options) {
+ // const authClient = await this._authClientPromise;
+ // const authHeaders = await authClient.getRequestHeaders();
+ // const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
+ // if (!this.projectId && projectId) {
+ // this.projectId = projectId;
+ // }
+ // options.headers = (0, headers_1.buildHeaders)([authHeaders, options.headers]);
+ // }
buildRequest(options) {
if ((0, values_1.isObj)(options.body)) {
// create a shallow copy of the request body so that code that mutates it later
diff --git a/client.mjs b/client.mjs
index 70274cbf38f69f87cbcca9567e77e4a7b938cf90..4dea954b6f4afad565663426b7adfad5de973a7d 100644
--- a/client.mjs
+++ b/client.mjs
@@ -1,6 +1,6 @@
import { BaseAnthropic } from '@anthropic-ai/sdk/client';
import * as Resources from '@anthropic-ai/sdk/resources/index';
-import { GoogleAuth } from 'google-auth-library';
+// import { GoogleAuth } from 'google-auth-library';
import { readEnv } from "./internal/utils/env.mjs";
import { isObj } from "./internal/utils/values.mjs";
import { buildHeaders } from "./internal/headers.mjs";
@@ -29,7 +29,7 @@ export class AnthropicVertex extends BaseAnthropic {
throw new Error('No region was given. The client should be instantiated with the `region` option or the `CLOUD_ML_REGION` environment variable should be set.');
}
super({
- baseURL: baseURL || `https://${region}-aiplatform.googleapis.com/v1`,
+ baseURL: baseURL || (region === 'global' ? 'https://aiplatform.googleapis.com/v1' : `https://${region}-aiplatform.googleapis.com/v1`),
...opts,
});
this.messages = makeMessagesResource(this);
@@ -37,22 +37,22 @@ export class AnthropicVertex extends BaseAnthropic {
this.region = region;
this.projectId = projectId;
this.accessToken = opts.accessToken ?? null;
- this._auth =
- opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
- this._authClientPromise = this._auth.getClient();
+ // this._auth =
+ // opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
+ //this._authClientPromise = this._auth.getClient();
}
validateHeaders() {
// auth validation is handled in prepareOptions since it needs to be async
}
- async prepareOptions(options) {
- const authClient = await this._authClientPromise;
- const authHeaders = await authClient.getRequestHeaders();
- const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
- if (!this.projectId && projectId) {
- this.projectId = projectId;
- }
- options.headers = buildHeaders([authHeaders, options.headers]);
- }
+ // async prepareOptions(options) {
+ // const authClient = await this._authClientPromise;
+ // const authHeaders = await authClient.getRequestHeaders();
+ // const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
+ // if (!this.projectId && projectId) {
+ // this.projectId = projectId;
+ // }
+ // options.headers = buildHeaders([authHeaders, options.headers]);
+ // }
buildRequest(options) {
if (isObj(options.body)) {
// create a shallow copy of the request body so that code that mutates it later
diff --git a/src/client.ts b/src/client.ts
index a6f9c6be65e4189f4f9601fb560df3f68e7563eb..37b1ad2802e3ca0dae4ca35f9dcb5b22dcf09796 100644
--- a/src/client.ts
+++ b/src/client.ts
@@ -12,22 +12,22 @@ export { BaseAnthropic } from '@anthropic-ai/sdk/client';
const DEFAULT_VERSION = 'vertex-2023-10-16';
const MODEL_ENDPOINTS = new Set<string>(['/v1/messages', '/v1/messages?beta=true']);
-export type ClientOptions = Omit<CoreClientOptions, 'apiKey' | 'authToken'> & {
- region?: string | null | undefined;
- projectId?: string | null | undefined;
- accessToken?: string | null | undefined;
-
- /**
- * Override the default google auth config using the
- * [google-auth-library](https://www.npmjs.com/package/google-auth-library) package.
- *
- * Note that you'll likely have to set `scopes`, e.g.
- * ```ts
- * new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' })
- * ```
- */
- googleAuth?: GoogleAuth | null | undefined;
-};
+// export type ClientOptions = Omit<CoreClientOptions, 'apiKey' | 'authToken'> & {
+// region?: string | null | undefined;
+// projectId?: string | null | undefined;
+// accessToken?: string | null | undefined;
+
+// /**
+// * Override the default google auth config using the
+// * [google-auth-library](https://www.npmjs.com/package/google-auth-library) package.
+// *
+// * Note that you'll likely have to set `scopes`, e.g.
+// * ```ts
+// * new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' })
+// * ```
+// */
+// googleAuth?: GoogleAuth | null | undefined;
+// };
export class AnthropicVertex extends BaseAnthropic {
region: string;
@@ -74,9 +74,9 @@ export class AnthropicVertex extends BaseAnthropic {
this.projectId = projectId;
this.accessToken = opts.accessToken ?? null;
- this._auth =
- opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
- this._authClientPromise = this._auth.getClient();
+ // this._auth =
+ // opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
+ // this._authClientPromise = this._auth.getClient();
}
messages: MessagesResource = makeMessagesResource(this);
@@ -86,17 +86,17 @@ export class AnthropicVertex extends BaseAnthropic {
// auth validation is handled in prepareOptions since it needs to be async
}
- protected override async prepareOptions(options: FinalRequestOptions): Promise<void> {
- const authClient = await this._authClientPromise;
+ // protected override async prepareOptions(options: FinalRequestOptions): Promise<void> {
+ // const authClient = await this._authClientPromise;
- const authHeaders = await authClient.getRequestHeaders();
- const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
- if (!this.projectId && projectId) {
- this.projectId = projectId;
- }
+ // const authHeaders = await authClient.getRequestHeaders();
+ // const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
+ // if (!this.projectId && projectId) {
+ // this.projectId = projectId;
+ // }
- options.headers = buildHeaders([authHeaders, options.headers]);
- }
+ // options.headers = buildHeaders([authHeaders, options.headers]);
+ // }
override buildRequest(options: FinalRequestOptions): {
req: FinalizedRequestInit;

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,154 @@
diff --git a/dist/index.d.mts b/dist/index.d.mts
index 00a528a923dfd113b43bbbc99fd4218121690e2e..825db3a81a0e7c3e109c2226a8c74e58f29c3427 100644
--- a/dist/index.d.mts
+++ b/dist/index.d.mts
@@ -4259,6 +4259,12 @@ declare function generateImage({ model: modelArg, prompt: promptArg, n, maxImage
Only applicable for HTTP-based providers.
*/
headers?: Record<string, string>;
+ /**
+ Custom download function to use for URLs.
+
+ By default, files are downloaded if the model returns URLs instead of binary data.
+ */
+ experimental_download?: DownloadFunction | undefined;
}): Promise<GenerateImageResult>;
/**
diff --git a/dist/index.d.ts b/dist/index.d.ts
index 00a528a923dfd113b43bbbc99fd4218121690e2e..825db3a81a0e7c3e109c2226a8c74e58f29c3427 100644
--- a/dist/index.d.ts
+++ b/dist/index.d.ts
@@ -4259,6 +4259,12 @@ declare function generateImage({ model: modelArg, prompt: promptArg, n, maxImage
Only applicable for HTTP-based providers.
*/
headers?: Record<string, string>;
+ /**
+ Custom download function to use for URLs.
+
+ By default, files are downloaded if the model returns URLs instead of binary data.
+ */
+ experimental_download?: DownloadFunction | undefined;
}): Promise<GenerateImageResult>;
/**
diff --git a/dist/index.js b/dist/index.js
index b72259eb460c25f479cf9087a5c3c2e39bacf060..4d43da4202089bff9d774894888463f137f09358 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -8237,7 +8237,8 @@ async function generateImage({
providerOptions,
maxRetries: maxRetriesArg,
abortSignal,
- headers
+ headers,
+ experimental_download: download2
}) {
var _a14, _b;
const model = resolveImageModel(modelArg);
@@ -8286,21 +8287,33 @@ async function generateImage({
outputTokens: void 0,
totalTokens: void 0
};
+ const downloadFn = download2 ?? createDefaultDownloadFunction();
for (const result of results) {
- images.push(
- ...result.images.map(
- (image) => {
- var _a15;
- return new DefaultGeneratedFile({
- data: image,
- mediaType: (_a15 = detectMediaType({
- data: image,
- signatures: imageMediaTypeSignatures
- })) != null ? _a15 : "image/png"
- });
+ const processedImages = await Promise.all(
+ result.images.map(async (image) => {
+ var _a15;
+ // 检查是否为 URL 字符串
+ if (typeof image === "string" && (image.startsWith("http://") || image.startsWith("https://"))) {
+ const downloaded = await downloadFn([{ url: new URL(image), isUrlSupportedByModel: false }]);
+ const downloadedData = downloaded[0];
+ if (downloadedData) {
+ return new DefaultGeneratedFile({
+ data: downloadedData.data,
+ mediaType: downloadedData.mediaType ?? "image/png"
+ });
+ }
}
- )
+ // 原有逻辑base64/Uint8Array
+ return new DefaultGeneratedFile({
+ data: image,
+ mediaType: (_a15 = detectMediaType({
+ data: image,
+ signatures: imageMediaTypeSignatures
+ })) != null ? _a15 : "image/png"
+ });
+ })
);
+ images.push(...processedImages);
warnings.push(...result.warnings);
if (result.usage != null) {
totalUsage = addImageModelUsage(totalUsage, result.usage);
diff --git a/dist/index.mjs b/dist/index.mjs
index a6538d517173782802369ef4f90dac2a8cf9d75a..0977129233232d8bfcfa3c4936f5e9addeb8ca4a 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -8177,7 +8177,8 @@ async function generateImage({
providerOptions,
maxRetries: maxRetriesArg,
abortSignal,
- headers
+ headers,
+ experimental_download: download2
}) {
var _a14, _b;
const model = resolveImageModel(modelArg);
@@ -8226,21 +8227,33 @@ async function generateImage({
outputTokens: void 0,
totalTokens: void 0
};
+ const downloadFn = download2 ?? createDefaultDownloadFunction();
for (const result of results) {
- images.push(
- ...result.images.map(
- (image) => {
- var _a15;
- return new DefaultGeneratedFile({
- data: image,
- mediaType: (_a15 = detectMediaType({
- data: image,
- signatures: imageMediaTypeSignatures
- })) != null ? _a15 : "image/png"
- });
+ const processedImages = await Promise.all(
+ result.images.map(async (image) => {
+ var _a15;
+ // 检查是否为 URL 字符串
+ if (typeof image === "string" && (image.startsWith("http://") || image.startsWith("https://"))) {
+ const downloaded = await downloadFn([{ url: new URL(image), isUrlSupportedByModel: false }]);
+ const downloadedData = downloaded[0];
+ if (downloadedData) {
+ return new DefaultGeneratedFile({
+ data: downloadedData.data,
+ mediaType: downloadedData.mediaType ?? "image/png"
+ });
+ }
}
- )
+ // 原有逻辑base64/Uint8Array
+ return new DefaultGeneratedFile({
+ data: image,
+ mediaType: (_a15 = detectMediaType({
+ data: image,
+ signatures: imageMediaTypeSignatures
+ })) != null ? _a15 : "image/png"
+ });
+ })
);
+ images.push(...processedImages);
warnings.push(...result.warnings);
if (result.usage != null) {
totalUsage = addImageModelUsage(totalUsage, result.usage);

1233
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,6 +15,11 @@ import { SystemProviderIds } from '@types'
import { formatVertexApiHost } from './utils/api'
type HostFormatter = {
match: (provider: Provider) => boolean
format: (provider: Provider, appendApiVersion: boolean) => string | Promise<string>
}
/**
* Format and normalize the API host URL for a provider.
* Handles provider-specific URL formatting rules (e.g., appending version paths, Azure formatting).
@@ -23,37 +28,40 @@ import { formatVertexApiHost } from './utils/api'
* @returns A new provider instance with the formatted API host.
*/
export async function formatProviderApiHost(provider: Provider): Promise<Provider> {
// WARNING: if any changes are made here, please sync it to src/renderer/src/aiCore/provider/providerConfig.ts:formatProviderApiHost
// NOTE: It's async to support Vertex API host formatting
const formatted = { ...provider }
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
}
// Anthropic is special: uses anthropicApiHost as source and syncs both fields
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isOllamaProvider(formatted)) {
formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = await formatVertexApiHost(formatted.apiHost)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion)
return formatted
}
const formatters: HostFormatter[] = [
{
match: (p) => p.id === SystemProviderIds.copilot || p.id === SystemProviderIds.github,
format: (p) => formatApiHost(p.apiHost, false)
},
{ match: isCherryAIProvider, format: (p) => formatApiHost(p.apiHost, false) },
{ match: isPerplexityProvider, format: (p) => formatApiHost(p.apiHost, false) },
{ match: isOllamaProvider, format: (p) => formatOllamaApiHost(p.apiHost) },
{ match: isGeminiProvider, format: (p, av) => formatApiHost(p.apiHost, av, 'v1beta') },
{ match: isAzureOpenAIProvider, format: (p) => formatAzureOpenAIApiHost(p.apiHost) },
{ match: isVertexProvider, format: (p) => formatVertexApiHost(p.apiHost) }
]
const formatter = formatters.find((f) => f.match(provider))
formatted.apiHost = formatter
? await formatter.format(formatted, appendApiVersion)
: formatApiHost(formatted.apiHost, appendApiVersion)
return formatted
}

View File

@@ -671,9 +671,10 @@ class FileStorage {
const parseResult = parseDataUrl(base64Data)
const base64String = parseResult?.data ?? base64Data
const ext = parseResult?.mediaType ? this.getExtensionFromMimeType(parseResult.mediaType) : '.png'
const buffer = Buffer.from(base64String, 'base64')
const uuid = uuidv4()
const ext = '.png'
const destPath = path.join(this.storageDir, uuid + ext)
logger.debug('Saving base64 image:', {
@@ -1560,6 +1561,8 @@ class FileStorage {
'image/jpeg': '.jpg',
'image/png': '.png',
'image/gif': '.gif',
'image/webp': '.webp',
'image/bmp': '.bmp',
'application/pdf': '.pdf',
'text/plain': '.txt',
'application/msword': '.doc',

View File

@@ -0,0 +1,514 @@
import { createExecutor } from '@cherrystudio/ai-core'
import type { generateImageResult } from '@cherrystudio/ai-core/core/runtime/types'
import { loggerService } from '@logger'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { normalizeGatewayModels } from '@renderer/services/models/ModelAdapter'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import {
type Assistant,
type EditImageParams,
type GenerateImageParams,
type Model,
type Provider,
SystemProviderIds
} from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
import { gateway } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import { buildPlugins } from './plugins/PluginBuilder'
import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
import { listModels } from './services/listModels'
import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
const logger = loggerService.withContext('AiProvider')
export type AiProviderConfig = AiSdkMiddlewareConfig & {
assistant: Assistant
// topicId for tracing
topicId?: string
callType: string
}
export default class AiProvider {
private config?: ProviderConfig
private actualProvider: Provider
private model?: Model
/**
* Constructor for AiProvider
*
* @param modelOrProvider - Model or Provider object
* @param provider - Optional Provider object (only used when first param is Model)
*
* @remarks
* **Important behavior notes**:
*
* 1. When called with `(model)`:
* - Calls `getActualProvider(model)` to retrieve and format the provider
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
*
* 2. When called with `(model, provider)`:
* - The provided provider will be adapted via `adaptProvider`
* - URL formatting behavior depends on the adapted result
*
* 3. When called with `(provider)`:
* - The provider will be adapted via `adaptProvider`
* - Used for operations that don't need a model (e.g., fetchModels)
*
* @example
* ```typescript
* // Recommended: Auto-format URL
* const ai = new AiProvider(model)
*
* // Provider will be adapted
* const ai = new AiProvider(model, customProvider)
*
* // For operations that don't need a model
* const ai = new AiProvider(provider)
* ```
*/
constructor(model: Model, provider?: Provider)
constructor(provider: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
if (this.isModel(modelOrProvider)) {
// 传入的是 Model
this.model = modelOrProvider
this.actualProvider = provider
? adaptProvider({ provider, model: modelOrProvider })
: getActualProvider(modelOrProvider)
// 注意config 可能是同步值或 Promise在 completions() 中会统一处理
const configOrPromise = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
this.config = configOrPromise instanceof Promise ? undefined : configOrPromise
} else {
// 传入的是 Provider
this.actualProvider = adaptProvider({ provider: modelOrProvider })
// model为可选某些操作如fetchModels不需要model
}
}
/**
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
*/
private isModel(obj: Model | Provider): obj is Model {
return 'provider' in obj && typeof obj.provider === 'string'
}
public getActualProvider() {
return this.actualProvider
}
public async completions(modelId: string, params: StreamTextParams, middlewareConfig: AiProviderConfig) {
// 检查model是否存在
if (!this.model) {
throw new Error('Model is required for completions. Please use constructor with model parameter.')
}
// Config is now set in constructor, ApiService handles key rotation before passing provider
if (!this.config) {
// If config wasn't set in constructor (when provider only), generate it now
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, this.model))
}
logger.debug('Using provider config for completions', this.config)
// 注意:模型对象将由 createExecutor 内部处理,不再需要预先创建
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
// 类型守卫:确保 system 是 string、Array 或 undefined
const system = params.system
let systemParam: string | Array<any> | undefined
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
systemParam = system
} else {
// SystemModelMessage 类型,转换为 string
systemParam = undefined
}
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
params.system = undefined // 清除原有system避免重复
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
}
if (middlewareConfig.topicId && getEnableDeveloperMode()) {
// TypeScript类型窄化确保topicId是string类型
const traceConfig = {
...middlewareConfig,
topicId: middlewareConfig.topicId
}
return await this._completionsForTrace(modelId, params, traceConfig, this.config)
} else {
return await this.modernCompletions(modelId, params, middlewareConfig, this.config)
}
}
/**
* 带trace支持的completions方法
* 类似于legacy的completionsForTrace确保AI SDK spans在正确的trace上下文中
*/
private async _completionsForTrace(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiProviderConfig & { topicId: string },
providerConfig: ProviderConfig
): Promise<CompletionsResult> {
const traceName = `${this.actualProvider.name}.${modelId}.${middlewareConfig.callType}`
const traceParams: StartSpanParams = {
name: traceName,
tag: 'LLM',
topicId: middlewareConfig.topicId,
modelName: middlewareConfig.assistant.model?.name, // 使用modelId而不是provider名称
inputs: params
}
logger.info('Starting AI SDK trace span', {
traceName,
topicId: middlewareConfig.topicId,
modelId,
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolNames: params.tools ? Object.keys(params.tools) : []
})
const span = addSpan(traceParams)
if (!span) {
logger.warn('Failed to create span, falling back to regular completions', {
topicId: middlewareConfig.topicId,
modelId,
traceName
})
return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
}
try {
logger.info('Created parent span, now calling completions', {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: middlewareConfig.topicId,
modelId,
parentSpanCreated: true
})
const result = await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
logger.info('Completions finished, ending parent span', {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: middlewareConfig.topicId,
modelId,
resultLength: result.getText().length
})
// 标记span完成
endSpan({
topicId: middlewareConfig.topicId,
outputs: result,
span,
modelName: modelId // 使用modelId保持一致性
})
return result
} catch (error) {
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: middlewareConfig.topicId,
modelId
})
// 标记span出错
endSpan({
topicId: middlewareConfig.topicId,
error: error as Error,
span,
modelName: modelId // 使用modelId保持一致性
})
throw error
}
}
/**
* 使用现代化AI SDK的completions实现
*/
/**
* Note: This implementation always uses `executor.streamText` and never
* calls `generateText`, even when `onChunk` is not provided.
*/
private async modernCompletions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiProviderConfig,
providerConfig: ProviderConfig
): Promise<CompletionsResult> {
const plugins = buildPlugins({
provider: this.actualProvider,
model: this.model!,
config: middlewareConfig
})
// 用构建好的插件数组创建executor
const executor = await createExecutor<AppProviderSettingsMap>(
providerConfig.providerId,
providerConfig.providerSettings,
plugins
)
// 创建带有中间件的执行器
if (middlewareConfig.onChunk) {
const accumulate = this.model!.supported_text_delta !== false // true and undefined
const adapter = new AiSdkToChunkAdapter(
middlewareConfig.onChunk,
middlewareConfig.mcpTools,
accumulate,
middlewareConfig.enableWebSearch,
undefined,
undefined,
providerConfig.providerId
)
const streamResult = await executor.streamText({
...params,
model: modelId,
experimental_context: { onChunk: middlewareConfig.onChunk }
})
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else {
// Since no onChunk is provided, the external consumer would not handle error chunk.
// So we need to capture the actual stream error so we can throw it instead of the
// generic NoTextGeneratedError ("No output generated. Check the stream
// for errors.") that AI SDK raises when streamResult.text is accessed
// after a failed stream.
let streamError: unknown = undefined
const streamResult = await executor.streamText({
...params,
model: modelId,
onError({ error }) {
streamError = error
}
})
// 强制消费流,不然await streamResult.text会阻塞
await streamResult?.consumeStream({
onError(error) {
if (!streamError) {
streamError = error
}
}
})
try {
const finalText = await streamResult.text
const usage = await streamResult.totalUsage
return {
getText: () => finalText,
usage
}
} catch (error) {
// If we captured the real stream error, throw that instead of the
// generic NoTextGeneratedError so callers get actionable diagnostics.
if (streamError) {
throw streamError
}
throw error
}
}
}
/**
* 获取模型列表
* 使用 ModelListService 统一处理各 Provider 的模型列表获取
*/
public async models(): Promise<Model[]> {
// Gateway provider 使用 AI SDK 的 gateway API
if (this.actualProvider.id === SystemProviderIds.gateway) {
const gatewayModels = (await gateway.getAvailableModels()).models
return normalizeGatewayModels(this.actualProvider, gatewayModels)
}
// 使用新的 ModelListService
return await listModels(this.actualProvider)
}
/**
* 获取嵌入模型的维度
* 使用 AI SDK embedMany 测试获取维度
*/
public async getEmbeddingDimensions(model: Model): Promise<number> {
// 确保 config 已定义
if (!this.config) {
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
}
const executor = await createExecutor<AppProviderSettingsMap>(
this.config.providerId,
this.config.providerSettings,
[]
)
// 使用 AI SDK embedMany 测试获取维度
const result = await executor.embedMany({
model: model.id,
values: ['test']
})
return result.embeddings[0].length
}
/**
* 懒加载初始化 config
* 当 constructor 只传入 provider 时config 不会被初始化
* 此方法根据 modelId 从 provider 的 models 中查找真实 Model 并生成 config
*/
private async ensureConfig(modelId: string): Promise<void> {
if (this.config) {
return
}
// 从 provider 的 models 中查找真实的 model
const model = this.actualProvider.models.find((m) => getLowerBaseModelName(m.id) === getLowerBaseModelName(modelId))
if (!model) {
throw new Error(`Model "${modelId}" not found in provider "${this.actualProvider.id}"`)
}
this.actualProvider = adaptProvider({ provider: this.actualProvider, model })
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
}
/**
* 生成图像
* 使用现代化 AI SDK 实现,不再 fallback 到 legacy
*/
public async generateImage(params: GenerateImageParams): Promise<string[]> {
await this.ensureConfig(params.model)
return await this.modernGenerateImage(params, this.config!)
}
/**
* 编辑图像 - 基于输入图像和文本提示生成新图像
* 内部使用 AI SDK 的 generateImage通过 prompt.images 参数实现编辑功能
*/
public async editImage(params: EditImageParams): Promise<string[]> {
await this.ensureConfig(params.model)
return await this.modernEditImage(params, this.config!)
}
/**
* 使用现代化 AI SDK 的图像生成实现
*/
private async modernGenerateImage(params: GenerateImageParams, providerConfig: ProviderConfig): Promise<string[]> {
const { model, prompt, imageSize, batchSize, signal } = params
// 转换参数格式
const aiSdkParams = {
prompt,
size: (imageSize || '1024x1024') as `${number}x${number}`,
n: batchSize || 1,
...(signal && { abortSignal: signal })
}
const executor = await createExecutor<AppProviderSettingsMap>(
providerConfig.providerId,
providerConfig.providerSettings,
[]
)
const result = await executor.generateImage({
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
...aiSdkParams
})
return this.convertImageResult(result)
}
/**
* 使用现代化 AI SDK 的图像编辑实现
* 通过 AI SDK 的 generateImage 并传入 prompt.images 参数实现编辑功能
*/
private async modernEditImage(params: EditImageParams, providerConfig: ProviderConfig): Promise<string[]> {
const { model, prompt, inputImages, mask, imageSize, signal } = params
const executor = await createExecutor<AppProviderSettingsMap>(
providerConfig.providerId,
providerConfig.providerSettings,
[]
)
// 使用 AI SDK 的 generateImage通过 prompt.images 实现编辑
const result = await executor.generateImage({
model: model,
prompt: {
text: prompt,
images: inputImages, // 输入图像(必需)
...(mask && { mask }) // 可选的 mask用于 inpainting
},
size: (imageSize || '1024x1024') as `${number}x${number}`,
...(signal && { abortSignal: signal })
})
return this.convertImageResult(result)
}
/**
* 转换图像生成结果格式
*/
private convertImageResult(result: generateImageResult): string[] {
const images: string[] = []
if (result.images) {
for (const image of result.images) {
if (image.base64) {
images.push(`data:${image.mediaType || 'image/png'};base64,${image.base64}`)
}
}
}
return images
}
public getBaseURL(): string {
return this.actualProvider.apiHost || ''
}
public getApiKey(): string {
const apiKey = this.actualProvider.apiKey
if (!apiKey || apiKey.trim() === '') {
return ''
}
const keys = apiKey
.split(',')
.map((key) => key.trim())
.filter(Boolean)
if (keys.length === 0) {
return ''
}
if (keys.length === 1) {
return keys[0]
}
// Multi-key rotation
const keyName = `provider:${this.actualProvider.id}:last_used_key`
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
}

View File

@@ -306,25 +306,6 @@ export class AiSdkToChunkAdapter {
this.toolCallHandler.handleToolResult(chunk)
break
// === 步骤相关事件 ===
// case 'start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
// case 'start-step':
// this.onChunk({
// type: ChunkType.BLOCK_CREATED
// })
// break
// case 'step-finish':
// this.onChunk({
// type: ChunkType.TEXT_COMPLETE,
// text: final.text || '' // TEXT_COMPLETE 需要 text 字段
// })
// final.text = ''
// break
case 'finish-step': {
const { providerMetadata, finishReason } = chunk
// googel web search

View File

@@ -1,16 +1 @@
/**
* Cherry Studio AI Core - 统一入口点
*
* 这是新的统一入口,保持向后兼容性
* 默认导出legacy AiProvider以保持现有代码的兼容性
*/
// 导出Legacy AiProvider作为默认导出保持向后兼容
export { default } from './legacy/index'
// 同时导出Modern AiProvider供新代码使用
export { default as ModernAiProvider } from './index_new'
// 导出一些常用的类型和工具
export * from './legacy/clients/types'
export * from './legacy/middleware/schemas'
export { default as AiProvider, type AiProviderConfig } from './AiProvider'

View File

@@ -1,609 +0,0 @@
/**
* Cherry Studio AI Core - 新版本入口
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
*
* 融合方案:简化实现,专注于核心功能
* 1. 优先使用新AI SDK
* 2. 暂时保持接口兼容性
*/
import type { AiSdkModel } from '@cherrystudio/ai-core'
import { createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { SUPPORTED_IMAGE_ENDPOINT_LIST } from '@renderer/utils'
import type { IdleTimeoutHandle } from '@renderer/utils/IdleTimeoutController'
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
import { gateway, type LanguageModel, type Provider as AiSdkProvider } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory'
import {
adaptProvider,
getActualProvider,
isModernSdkSupported,
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { AiSdkConfig } from './types'
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
const logger = loggerService.withContext('ModernAiProvider')
export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
assistant: Assistant
// topicId for tracing
topicId?: string
callType: string
idleTimeout?: IdleTimeoutHandle
}
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: AiSdkConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null
/**
* Constructor for ModernAiProvider
*
* @param modelOrProvider - Model or Provider object
* @param provider - Optional Provider object (only used when first param is Model)
*
* @remarks
* **Important behavior notes**:
*
* 1. When called with `(model)`:
* - Calls `getActualProvider(model)` to retrieve and format the provider
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
*
* 2. When called with `(model, provider)`:
* - The provided provider will be adapted via `adaptProvider`
* - URL formatting behavior depends on the adapted result
*
* 3. When called with `(provider)`:
* - The provider will be adapted via `adaptProvider`
* - Used for operations that don't need a model (e.g., fetchModels)
*
* @example
* ```typescript
* // Recommended: Auto-format URL
* const ai = new ModernAiProvider(model)
*
* // Provider will be adapted
* const ai = new ModernAiProvider(model, customProvider)
*
* // For operations that don't need a model
* const ai = new ModernAiProvider(provider)
* ```
*/
constructor(model: Model, provider?: Provider)
constructor(provider: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider)
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
if (this.isModel(modelOrProvider)) {
// 传入的是 Model
this.model = modelOrProvider
this.actualProvider = provider
? adaptProvider({ provider, model: modelOrProvider })
: getActualProvider(modelOrProvider)
// 只保存配置不预先创建executor
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
} else {
// 传入的是 Provider
this.actualProvider = adaptProvider({ provider: modelOrProvider })
// model为可选某些操作如fetchModels不需要model
}
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
}
/**
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
*/
private isModel(obj: Model | Provider): obj is Model {
return 'provider' in obj && typeof obj.provider === 'string'
}
public getActualProvider() {
return this.actualProvider
}
/**
* Note: This method routes text completions through `modernCompletions`,
* which only calls `streamText` (no `generateText` path).
*/
public async completions(modelId: string, params: StreamTextParams, providerConfig: ModernAiProviderConfig) {
// 检查model是否存在
if (!this.model) {
throw new Error('Model is required for completions. Please use constructor with model parameter.')
}
// Config is now set in constructor, ApiService handles key rotation before passing provider
if (!this.config) {
// If config wasn't set in constructor (when provider only), generate it now
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
}
logger.debug('Using provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
// 准备特殊配置
await prepareSpecialProviderConfig(this.actualProvider, this.config)
// 提前创建本地 provider 实例
if (!this.localProvider) {
this.localProvider = await createAiSdkProvider(this.config)
}
if (!this.localProvider) {
throw new Error('Local provider not created')
}
// 根据endpoint类型创建对应的模型
let model: AiSdkModel | undefined
if (providerConfig.isImageGenerationEndpoint) {
model = this.localProvider.imageModel(modelId)
} else {
model = this.localProvider.languageModel(modelId)
}
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
// 类型守卫:确保 system 是 string、Array 或 undefined
const system = params.system
let systemParam: string | Array<any> | undefined
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
systemParam = system
} else {
// SystemModelMessage 类型,转换为 string
systemParam = undefined
}
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
params.system = undefined // 清除原有system避免重复
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
}
if (providerConfig.topicId && getEnableDeveloperMode()) {
// TypeScript类型窄化确保topicId是string类型
const traceConfig = {
...providerConfig,
topicId: providerConfig.topicId
}
return await this._completionsForTrace(model, params, traceConfig)
} else {
return await this._completionsOrImageGeneration(model, params, providerConfig)
}
}
private async _completionsOrImageGeneration(
model: AiSdkModel,
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
// ai-gateway不是image/generation 端点所以就先不走legacy了
if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
}
const legacyParams: CompletionsParams = {
callType: 'chat',
messages: config.uiMessages, // 使用原始的 UI 消息格式
assistant: config.assistant,
streamOutput: config.streamOutput ?? true,
onChunk: config.onChunk,
topicId: config.topicId,
mcpTools: config.mcpTools,
enableWebSearch: config.enableWebSearch
}
// 调用 legacy 的 completions会自动使用 ImageGenerationMiddleware
return await this.legacyProvider.completions(legacyParams)
}
return await this.modernCompletions(model as LanguageModel, params, config)
}
/**
* 带trace支持的completions方法
* 类似于legacy的completionsForTrace确保AI SDK spans在正确的trace上下文中
*/
private async _completionsForTrace(
model: AiSdkModel,
params: StreamTextParams,
config: ModernAiProviderConfig & { topicId: string }
): Promise<CompletionsResult> {
const modelId = this.model!.id
const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
const traceParams: StartSpanParams = {
name: traceName,
tag: 'LLM',
topicId: config.topicId,
modelName: config.assistant.model?.name, // 使用modelId而不是provider名称
inputs: params
}
logger.info('Starting AI SDK trace span', {
traceName,
topicId: config.topicId,
modelId,
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolNames: params.tools ? Object.keys(params.tools) : [],
isImageGeneration: config.isImageGenerationEndpoint
})
const span = addSpan(traceParams)
if (!span) {
logger.warn('Failed to create span, falling back to regular completions', {
topicId: config.topicId,
modelId,
traceName
})
return await this._completionsOrImageGeneration(model, params, config)
}
try {
logger.info('Created parent span, now calling completions', {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: config.topicId,
modelId,
parentSpanCreated: true
})
const result = await this._completionsOrImageGeneration(model, params, config)
logger.info('Completions finished, ending parent span', {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: config.topicId,
modelId,
resultLength: result.getText().length
})
// 标记span完成
endSpan({
topicId: config.topicId,
outputs: result,
span,
modelName: modelId // 使用modelId保持一致性
})
return result
} catch (error) {
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
topicId: config.topicId,
modelId
})
// 标记span出错
endSpan({
topicId: config.topicId,
error: error as Error,
span,
modelName: modelId // 使用modelId保持一致性
})
throw error
}
}
/**
* 使用现代化AI SDK的completions实现
*/
/**
* Note: This implementation always uses `executor.streamText` and never
* calls `generateText`, even when `onChunk` is not provided.
*/
private async modernCompletions(
model: LanguageModel,
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
// 根据条件构建插件数组
const plugins = buildPlugins({
provider: this.actualProvider,
model: this.model!,
config
})
// 用构建好的插件数组创建executor
const executor = createExecutor(this.config!.providerId, this.config!.options, plugins)
// 创建带有中间件的执行器
if (config.onChunk) {
const accumulate = this.model!.supported_text_delta !== false // true and undefined
const adapter = new AiSdkToChunkAdapter(
config.onChunk,
config.mcpTools,
accumulate,
config.enableWebSearch,
undefined,
undefined,
this.config!.providerId,
config.idleTimeout
)
const streamResult = await executor.streamText({
...params,
model,
experimental_context: { onChunk: config.onChunk }
})
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else {
// Since no onChunk is provided, the external consumer would not handle error chunk.
// So we need to capture the actual stream error so we can throw it instead of the
// generic NoTextGeneratedError ("No output generated. Check the stream
// for errors.") that AI SDK raises when streamResult.text is accessed
// after a failed stream.
let streamError: unknown = undefined
const streamResult = await executor.streamText({
...params,
model,
onError({ error }) {
streamError = error
}
})
// 强制消费流,不然await streamResult.text会阻塞
await streamResult?.consumeStream({
onError(error) {
if (!streamError) {
streamError = error
}
}
})
try {
const finalText = await streamResult.text
const usage = await streamResult.totalUsage
return {
getText: () => finalText,
usage
}
} catch (error) {
// If we captured the real stream error, throw that instead of the
// generic NoTextGeneratedError so callers get actionable diagnostics.
if (streamError) {
throw streamError
}
throw error
}
}
}
// /**
// * 使用现代化 AI SDK 的图像生成实现,支持流式输出
// * @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能
// */
/*
private async modernImageGeneration(
model: ImageModel,
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
const { onChunk } = config
try {
// 检查 messages 是否存在
if (!params.messages || params.messages.length === 0) {
throw new Error('No messages provided for image generation.')
}
// 从最后一条用户消息中提取 prompt
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
if (!lastUserMessage) {
throw new Error('No user message found for image generation.')
}
// 直接使用消息内容,避免类型转换问题
const prompt =
typeof lastUserMessage.content === 'string'
? lastUserMessage.content
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
if (!prompt) {
throw new Error('No prompt found in user message.')
}
const startTime = Date.now()
// 发送图像生成开始事件
if (onChunk) {
onChunk({ type: ChunkType.IMAGE_CREATED })
}
// 构建图像生成参数
const imageParams = {
prompt,
size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
n: 1,
...(params.abortSignal && { abortSignal: params.abortSignal })
}
// 调用新 AI SDK 的图像生成功能
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
const result = await executor.generateImage({
model,
...imageParams
})
// 转换结果格式
const images: string[] = []
const imageType: 'url' | 'base64' = 'base64'
if (result.images) {
for (const image of result.images) {
if ('base64' in image && image.base64) {
images.push(`data:${image.mediaType};base64,${image.base64}`)
}
}
}
// 发送图像生成完成事件
if (onChunk && images.length > 0) {
onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: { type: imageType, images }
})
}
// 发送块完成事件(类似于 modernCompletions 的处理)
if (onChunk) {
const usage = {
prompt_tokens: prompt.length, // 估算的 token 数量
completion_tokens: 0, // 图像生成没有 completion tokens
total_tokens: prompt.length
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage,
metrics: {
completion_tokens: usage.completion_tokens,
time_first_token_millsec: 0,
time_completion_millsec: Date.now() - startTime
}
}
})
// 发送 LLM 响应完成事件
onChunk({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage,
metrics: {
completion_tokens: usage.completion_tokens,
time_first_token_millsec: 0,
time_completion_millsec: Date.now() - startTime
}
}
})
}
return {
getText: () => '' // 图像生成不返回文本
}
} catch (error) {
// 发送错误事件
if (onChunk) {
onChunk({ type: ChunkType.ERROR, error: error as any })
}
throw error
}
}
*/
// 代理其他方法到原有实现
public async models() {
if (this.actualProvider.id === SystemProviderIds.gateway) {
const gatewayModels = (await gateway.getAvailableModels()).models
return normalizeGatewayModels(this.actualProvider, gatewayModels)
}
const sdkModels = await this.legacyProvider.models()
return normalizeSdkModels(this.actualProvider, sdkModels)
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.legacyProvider.getEmbeddingDimensions(model)
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
// 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) {
try {
// 确保 config 已定义
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// 确保本地provider已创建
if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config)
if (!this.localProvider) {
throw new Error('Local provider not created')
}
}
const result = await this.modernGenerateImage(params)
return result
} catch (error) {
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
// fallback 到传统实现
return this.legacyProvider.generateImage(params)
}
}
// 直接使用传统实现
return this.legacyProvider.generateImage(params)
}
/**
* 使用现代化 AI SDK 的图像生成实现
*/
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
const { model, prompt, imageSize, batchSize, signal } = params
// 转换参数格式
const aiSdkParams = {
prompt,
size: (imageSize || '1024x1024') as `${number}x${number}`,
n: batchSize || 1,
...(signal && { abortSignal: signal })
}
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
const result = await executor.generateImage({
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
...aiSdkParams
})
// 转换结果格式
const images: string[] = []
if (result.images) {
for (const image of result.images) {
if ('base64' in image && image.base64) {
images.push(`data:image/png;base64,${image.base64}`)
}
}
}
return images
}
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()
}
public getApiKey(): string {
return this.legacyProvider.getApiKey()
}
}
// 为了方便调试,导出一些工具函数
export { isModernSdkSupported, providerToAiSdkConfig }

View File

@@ -1,110 +0,0 @@
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { AwsBedrockAPIClient } from './aws/AwsBedrockAPIClient'
import type { BaseApiClient } from './BaseApiClient'
import { CherryAiAPIClient } from './cherryai/CherryAiAPIClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { VertexAPIClient } from './gemini/VertexAPIClient'
import { NewAPIClient } from './newapi/NewAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { OVMSClient } from './ovms/OVMSClient'
import { PoeAPIClient } from './poe/PoeAPIClient'
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
import { ZhipuAPIClient } from './zhipu/ZhipuAPIClient'
const logger = loggerService.withContext('ApiClientFactory')
/**
* Factory for creating ApiClient instances based on provider configuration
* 根据提供者配置创建ApiClient实例的工厂
*/
export class ApiClientFactory {
/**
* Create an ApiClient instance for the given provider
* 为给定的提供者创建ApiClient实例
*/
static create(provider: Provider): BaseApiClient {
logger.debug(`Creating ApiClient for provider:`, {
id: provider.id,
type: provider.type
})
let instance: BaseApiClient
// 首先检查特殊的 Provider ID
if (provider.id === 'cherryai') {
instance = new CherryAiAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'aihubmix') {
logger.debug(`Creating AihubmixAPIClient for provider: ${provider.id}`)
instance = new AihubmixAPIClient(provider) as BaseApiClient
return instance
}
if (isNewApiProvider(provider)) {
logger.debug(`Creating NewAPIClient for provider: ${provider.id}`)
instance = new NewAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'ppio') {
logger.debug(`Creating PPIOAPIClient for provider: ${provider.id}`)
instance = new PPIOAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'zhipu') {
instance = new ZhipuAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'ovms') {
logger.debug(`Creating OVMSClient for provider: ${provider.id}`)
instance = new OVMSClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'poe') {
logger.debug(`Creating PoeAPIClient for provider: ${provider.id}`)
instance = new PoeAPIClient(provider) as BaseApiClient
return instance
}
// 然后检查标准的 Provider Type
switch (provider.type) {
case 'openai':
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
case 'azure-openai':
case 'openai-response':
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
break
case 'gemini':
instance = new GeminiAPIClient(provider) as BaseApiClient
break
case 'vertexai':
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
instance = new VertexAPIClient(provider) as BaseApiClient
break
case 'anthropic':
instance = new AnthropicAPIClient(provider) as BaseApiClient
break
case 'aws-bedrock':
instance = new AwsBedrockAPIClient(provider) as BaseApiClient
break
default:
logger.debug(`Using default OpenAIApiClient for provider: ${provider.id}`)
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
}
return instance
}
}

View File

@@ -1,489 +0,0 @@
import { loggerService } from '@logger'
import {
getModelSupportedVerbosity,
isFunctionCallingModel,
isOpenAIModel,
isSupportFlexServiceTierModel,
isSupportTemperatureModel,
isSupportTopPModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type { RootState } from '@renderer/store'
import type {
Assistant,
GenerateImageParams,
KnowledgeReference,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
MemoryItem,
Model,
Provider,
ToolCallResponse,
WebSearchProviderResponse,
WebSearchResponse
} from '@renderer/types'
import {
FILE_TYPE,
GroqServiceTiers,
isGroqServiceTier,
isOpenAIServiceTier,
OpenAIServiceTiers,
SystemProviderIds
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
import type { CompletionsContext } from '../middleware/types'
import type { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
const logger = loggerService.withContext('BaseApiClient')
/**
* Abstract base class for API clients.
* Provides common functionality and structure for specific client implementations.
*/
export abstract class BaseApiClient<
TSdkInstance extends SdkInstance = SdkInstance,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
{
public provider: Provider
protected host: string
protected sdkInstance?: TSdkInstance
constructor(provider: Provider) {
this.provider = provider
this.host = this.getBaseURL()
}
/**
* Get the current API key with rotation support
* This getter ensures API keys rotate on each access when multiple keys are configured
*/
protected get apiKey(): string {
return this.getApiKey()
}
/**
* 获取客户端的兼容性类型
* 用于判断客户端是否支持特定功能避免instanceof检查的类型收窄问题
* 对于装饰器模式的客户端如AihubmixAPIClient应该返回其内部实际使用的客户端类型
*/
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
public getClientCompatibilityType(_model?: Model): string[] {
// 默认返回类的名称
return [this.constructor.name]
}
// // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
/**
* 核心API Endpoint
**/
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
abstract getEmbeddingDimensions(model?: Model): Promise<number>
abstract listModels(): Promise<SdkModel[]>
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
/**
* 中间件
**/
// 在 CoreRequestToSdkParamsMiddleware中使用
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
// 在RawSdkChunkToGenericChunkMiddleware中使用
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
/**
* 工具转换
**/
// Optional tool conversion methods - implement if needed by the specific provider
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
abstract buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string | undefined,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
abstract estimateMessageTokens(message: TMessageParam): number
abstract convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): TMessageParam | undefined
/**
* 从SDK载荷中提取消息数组用于中间件中的类型安全访问
* 不同的提供商可能使用不同的字段名如messages、history等
*/
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
/**
* 通用函数
**/
public getBaseURL(): string {
return this.provider.apiHost
}
public getApiKey() {
const keys = this.provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${this.provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
}
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
public defaultHeaders() {
return {
...defaultAppHeaders(),
'X-Api-Key': this.apiKey
}
}
public get keepAliveTime() {
return this.provider.id === 'lmstudio' ? getLMStudioKeepAliveTime() : undefined
}
public getTemperature(assistant: Assistant, model: Model): number | undefined {
if (!isSupportTemperatureModel(model)) {
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
}
public getTopP(assistant: Assistant, model: Model): number | undefined {
if (!isSupportTopPModel(model)) {
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined
}
// NOTE: 这个也许可以迁移到OpenAIBaseClient
protected getServiceTier(model: Model) {
const serviceTierSetting = this.provider.serviceTier
if (!isSupportServiceTierProvider(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) {
return undefined
}
// 处理不同供应商需要 fallback 到默认值的情况
if (this.provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
} else {
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
}
return serviceTierSetting
}
protected getVerbosity(model?: Model): OpenAIVerbosity {
try {
const state = window.store?.getState() as RootState
const verbosity = state?.settings?.openAI?.verbosity
// If model is provided, check if the verbosity is supported by the model
if (model) {
const supportedVerbosity = getModelSupportedVerbosity(model)
// Use user's verbosity if supported, otherwise use the first supported option
return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
}
return verbosity
} catch (error) {
logger.warn('Failed to get verbosity from state. Fallback to undefined.', error as Error)
return undefined
}
}
protected getTimeout(model: Model) {
if (isSupportFlexServiceTierModel(model)) {
return 15 * 1000 * 60
}
return DEFAULT_TIMEOUT
}
public async getMessageContent(
message: Message
): Promise<{ textContent: string; imageContents: { fileId: string; fileExt: string }[] }> {
const content = getMainTextContent(message)
if (isEmpty(content)) {
return {
textContent: '',
imageContents: []
}
}
const webSearchReferences = await this.getWebSearchReferencesFromCache(message)
const knowledgeReferences = await this.getKnowledgeBaseReferencesFromCache(message)
const memoryReferences = this.getMemoryReferencesFromCache(message)
const knowledgeTextReferences = knowledgeReferences.filter((k) => k.metadata?.type !== 'image')
const knowledgeImageReferences = knowledgeReferences.filter((k) => k.metadata?.type === 'image')
// 添加偏移量以避免ID冲突
const reindexedKnowledgeReferences = knowledgeTextReferences.map((ref) => ({
...ref,
id: ref.id + webSearchReferences.length // 为知识库引用的ID添加网络搜索引用的数量作为偏移量
}))
const allReferences = [...webSearchReferences, ...reindexedKnowledgeReferences, ...memoryReferences]
logger.debug(`Found ${allReferences.length} references for ID: ${message.id}`, allReferences)
const referenceContent = `\`\`\`json\n${JSON.stringify(allReferences, null, 2)}\n\`\`\``
const imageReferences = knowledgeImageReferences.map((r) => {
return { fileId: r.metadata?.id, fileExt: r.metadata?.ext }
})
return {
textContent: isEmpty(allReferences)
? content
: REFERENCE_PROMPT.replace('{question}', content).replace('{references}', referenceContent),
imageContents: isEmpty(knowledgeImageReferences) ? [] : imageReferences
}
}
/**
* Extract the file content from the message
* @param message - The message
* @returns The file content
*/
protected async extractFileContent(message: Message) {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FILE_TYPE.TEXT, FILE_TYPE.DOCUMENT].some((type) => fb.file.type === type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
private getMemoryReferencesFromCache(message: Message) {
const memories = window.keyv.get(`memory-search-${message.id}`) as MemoryItem[] | undefined
if (memories) {
const memoryReferences: KnowledgeReference[] = memories.map((mem, index) => ({
id: index + 1,
content: `${mem.memory} -- Created at: ${mem.createdAt}`,
sourceUrl: '',
type: 'memory'
}))
return memoryReferences
}
return []
}
private async getWebSearchReferencesFromCache(message: Message) {
const content = getMainTextContent(message)
if (isEmpty(content)) {
return []
}
const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`)
if (webSearch) {
window.keyv.remove(`web-search-${message.id}`)
return (webSearch.results as WebSearchProviderResponse).results.map(
(result, index) =>
({
id: index + 1,
content: result.content,
sourceUrl: result.url,
type: 'url'
}) as KnowledgeReference
)
}
return []
}
/**
* 从缓存中获取知识库引用
*/
private async getKnowledgeBaseReferencesFromCache(message: Message): Promise<KnowledgeReference[]> {
const content = getMainTextContent(message)
if (isEmpty(content)) {
return []
}
const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`)
if (!isEmpty(knowledgeReferences)) {
window.keyv.remove(`knowledge-search-${message.id}`)
logger.debug(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`)
return knowledgeReferences
}
logger.debug(`No knowledge base references found in cache for ID: ${message.id}`)
return []
}
protected getCustomParameters(assistant: Assistant) {
return (
assistant?.settings?.customParameters?.reduce((acc, param) => {
if (!param.name?.trim()) {
return acc
}
// Parse JSON type parameters (Legacy API clients)
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
// The UI stores JSON type params as strings, this function parses them before sending to API
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
return { ...acc, [param.name]: undefined }
}
return { ...acc, [param.name]: isJSON(value) ? parseJSON(value) : value }
}
return {
...acc,
[param.name]: param.value
}
}, {}) || {}
)
}
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
const abortController = new AbortController()
const abortFn = () => abortController.abort()
if (messageId) {
addAbortController(messageId, abortFn)
}
const cleanup = () => {
if (messageId) {
signalPromise.resolve?.(undefined)
removeAbortController(messageId, abortFn)
}
}
const signalPromise: {
resolve: (value: unknown) => void
promise: Promise<unknown>
} = {
resolve: () => {},
promise: Promise.resolve()
}
if (isAddEventListener) {
signalPromise.promise = new Promise((resolve, reject) => {
signalPromise.resolve = resolve
if (abortController.signal.aborted) {
reject(new Error('Request was aborted.'))
}
// 捕获abort事件,有些abort事件必须
abortController.signal.addEventListener('abort', () => {
reject(new Error('Request was aborted.'))
})
})
return {
abortController,
cleanup,
signalPromise
}
}
return {
abortController,
cleanup
}
}
// Setup tools configuration based on provided parameters
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: TSdkSpecificTool[]
} {
const { mcpTools, model, enableToolUse } = params
let tools: TSdkSpecificTool[] = []
// If there are no tools, return an empty array
if (!mcpTools?.length) {
return { tools }
}
// If the model supports function calling and tool usage is enabled
if (isFunctionCallingModel(model) && enableToolUse) {
tools = this.convertMcpToolsToSdkTools(mcpTools)
}
return { tools }
}
}

View File

@@ -1,181 +0,0 @@
import type {
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import type {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import type { CompletionsContext } from '../middleware/types'
import type { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import type { GeminiAPIClient } from './gemini/GeminiAPIClient'
import type { OpenAIAPIClient } from './openai/OpenAIApiClient'
import type { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import type { RequestTransformer, ResponseChunkTransformer } from './types'
/**
* MixedAPIClient - 适用于可能含有多种接口类型的Provider
*/
export abstract class MixedBaseAPIClient extends BaseApiClient {
// 使用联合类型而不是any保持类型安全
protected abstract clients: Map<
string,
AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient
>
protected abstract defaultClient: OpenAIAPIClient
protected abstract currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
}
override getBaseURL(): string {
if (!this.currentClient) {
return this.provider.apiHost
}
return this.currentClient.getBaseURL()
}
/**
* 类型守卫确保client是BaseApiClient的实例
*/
protected isValidClient(client: unknown): client is BaseApiClient {
return (
client !== null &&
client !== undefined &&
typeof client === 'object' &&
'createCompletions' in client &&
'getRequestTransformer' in client &&
'getResponseChunkTransformer' in client
)
}
/**
* 根据模型获取合适的client
*/
protected abstract getClient(model: Model): BaseApiClient
/**
* 根据模型选择合适的client并委托调用
*/
public getClientForModel(model: Model): BaseApiClient {
this.currentClient = this.getClient(model)
return this.currentClient
}
/**
* 重写基类方法,返回内部实际使用的客户端类型
*/
public override getClientCompatibilityType(model?: Model): string[] {
if (!model) {
return [this.constructor.name]
}
const actualClient = this.getClient(model)
return actualClient.getClientCompatibilityType(model)
}
/**
* 从SDK payload中提取模型ID
*/
protected extractModelFromPayload(payload: SdkParams): string | null {
// 不同的SDK可能有不同的字段名
if ('model' in payload && typeof payload.model === 'string') {
return payload.model
}
return null
}
// ============ BaseApiClient 的抽象方法 ============
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
// 尝试从payload中提取模型信息来选择client
const modelId = this.extractModelFromPayload(payload)
if (modelId) {
const modelObj = { id: modelId } as Model
const targetClient = this.getClient(modelObj)
return targetClient.createCompletions(payload, options)
}
// 如果无法从payload中提取模型使用当前设置的client
return this.currentClient.createCompletions(payload, options)
}
async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.currentClient.generateImage(params)
}
async getEmbeddingDimensions(model?: Model): Promise<number> {
const client = model ? this.getClient(model) : this.currentClient
return client.getEmbeddingDimensions(model)
}
async listModels(): Promise<SdkModel[]> {
// 可以聚合所有client的模型或者使用默认client
return this.defaultClient.listModels()
}
async getSdkInstance(): Promise<SdkInstance> {
return this.currentClient.getSdkInstance()
}
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
return this.currentClient.getRequestTransformer()
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer(ctx)
}
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
}
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
}
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
}
buildSdkMessages(
currentReqMessages: SdkMessageParam[],
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls?: SdkToolCall[]
): SdkMessageParam[] {
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
}
estimateMessageTokens(message: SdkMessageParam): number {
return this.currentClient.estimateMessageTokens(message)
}
convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): SdkMessageParam | undefined {
const client = this.getClient(model)
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
}
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
}
}

View File

@@ -1,220 +0,0 @@
import type { Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { AihubmixAPIClient } from '../aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '../ApiClientFactory'
import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { VertexAPIClient } from '../gemini/VertexAPIClient'
import { NewAPIClient } from '../newapi/NewAPIClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
import { PPIOAPIClient } from '../ppio/PPIOAPIClient'
// 为工厂测试创建最小化 provider 的辅助函数
// ApiClientFactory 只使用 'id' 和 'type' 字段来决定创建哪个客户端
// 其他字段会传递给客户端构造函数,但不影响工厂逻辑
const createTestProvider = (id: string, type: string): Provider => ({
id,
type: type as Provider['type'],
name: '',
apiKey: '',
apiHost: '',
models: []
})
// Mock 所有客户端模块
vi.mock('../aihubmix/AihubmixAPIClient', () => ({
AihubmixAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicAPIClient', () => ({
AnthropicAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicVertexClient', () => ({
AnthropicVertexClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/GeminiAPIClient', () => ({
GeminiAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/VertexAPIClient', () => ({
VertexAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../newapi/NewAPIClient', () => ({
NewAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../openai/OpenAIApiClient', () => ({
OpenAIAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../openai/OpenAIResponseAPIClient', () => ({
OpenAIResponseAPIClient: vi.fn().mockImplementation(() => ({
getClient: vi.fn().mockReturnThis()
}))
}))
vi.mock('../ppio/PPIOAPIClient', () => ({
PPIOAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []
},
isOpenAIModel: vi.fn(() => false),
qwenModel: {}
}))
describe('ApiClientFactory', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('create', () => {
// 测试特殊 ID 的客户端创建
it('should create AihubmixAPIClient for aihubmix provider', () => {
const provider = createTestProvider('aihubmix', 'openai')
const client = ApiClientFactory.create(provider)
expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create NewAPIClient for new-api provider', () => {
const provider = createTestProvider('new-api', 'openai')
const client = ApiClientFactory.create(provider)
expect(NewAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create PPIOAPIClient for ppio provider', () => {
const provider = createTestProvider('ppio', 'openai')
const client = ApiClientFactory.create(provider)
expect(PPIOAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// 测试标准类型的客户端创建
it('should create OpenAIAPIClient for openai type', () => {
const provider = createTestProvider('custom-openai', 'openai')
const client = ApiClientFactory.create(provider)
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create OpenAIResponseAPIClient for azure-openai type', () => {
const provider = createTestProvider('azure-openai', 'azure-openai')
const client = ApiClientFactory.create(provider)
expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create OpenAIResponseAPIClient for openai-response type', () => {
const provider = createTestProvider('response', 'openai-response')
const client = ApiClientFactory.create(provider)
expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create GeminiAPIClient for gemini type', () => {
const provider = createTestProvider('gemini', 'gemini')
const client = ApiClientFactory.create(provider)
expect(GeminiAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create VertexAPIClient for vertexai type', () => {
const provider = createTestProvider('vertex', 'vertexai')
const client = ApiClientFactory.create(provider)
expect(VertexAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create AnthropicAPIClient for anthropic type', () => {
const provider = createTestProvider('anthropic', 'anthropic')
const client = ApiClientFactory.create(provider)
expect(AnthropicAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
it('should create AwsBedrockAPIClient for aws-bedrock type', () => {
const provider = createTestProvider('aws-bedrock', 'aws-bedrock')
const client = ApiClientFactory.create(provider)
expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// 测试默认情况
it('should create OpenAIAPIClient as default for unknown type', () => {
const provider = createTestProvider('unknown', 'unknown-type')
const client = ApiClientFactory.create(provider)
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// 测试边界条件
it('should handle provider with minimal configuration', () => {
const provider = createTestProvider('minimal', 'openai')
const client = ApiClientFactory.create(provider)
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
expect(client).toBeDefined()
})
// 测试特殊 ID 优先级高于类型
it('should prioritize special ID over type', () => {
const provider = createTestProvider('aihubmix', 'anthropic') // 即使类型是 anthropic
const client = ApiClientFactory.create(provider)
// 应该创建 AihubmixAPIClient 而不是 AnthropicAPIClient
expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
expect(AnthropicAPIClient).not.toHaveBeenCalled()
expect(client).toBeDefined()
})
})
})

View File

@@ -1,38 +0,0 @@
import { describe, expect, it } from 'vitest'
import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'
describe('normalizeAzureOpenAIEndpoint', () => {
it.each([
{
apiHost: 'https://example.openai.azure.com/openai',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/openai/',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/openai/v1',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/openai/v1/',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/',
expectedEndpoint: 'https://example.openai.azure.com'
},
{
apiHost: 'https://example.openai.azure.com/OPENAI/V1',
expectedEndpoint: 'https://example.openai.azure.com'
}
])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => {
expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint)
})
})

View File

@@ -1,353 +0,0 @@
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient'
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
import type { EndpointType, Model, Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@renderer/config/models', () => ({
SYSTEM_MODELS: {
defaultModel: [
{ id: 'gpt-4', name: 'GPT-4' },
{ id: 'gpt-4', name: 'GPT-4' },
{ id: 'gpt-4', name: 'GPT-4' }
],
zhipu: [],
silicon: [],
openai: [],
anthropic: [],
gemini: []
},
isOpenAIModel: vi.fn().mockReturnValue(true),
isOpenAILLMModel: vi.fn().mockReturnValue(true),
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
isGeminiLLMModel: vi.fn().mockReturnValue(false),
isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false),
isVisionModel: vi.fn().mockReturnValue(false),
isClaudeReasoningModel: vi.fn().mockReturnValue(false),
isReasoningModel: vi.fn().mockReturnValue(false),
isWebSearchModel: vi.fn().mockReturnValue(false),
findTokenLimit: vi.fn().mockReturnValue(4096),
isFunctionCallingModel: vi.fn().mockReturnValue(false),
DEFAULT_MAX_TOKENS: 4096,
qwenModel: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/services/FileManager', () => ({
default: class {
static async read() {
return 'test content'
}
static async write() {
return true
}
}
}))
vi.mock('@renderer/services/TokenService', () => ({
estimateTextTokens: vi.fn().mockReturnValue(100)
}))
vi.mock('@logger', () => ({
loggerService: {
withContext: vi.fn().mockReturnValue({
debug: vi.fn(),
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
silly: vi.fn()
})
}
}))
// 到底是谁想出来的在服务层调用 React Hook ?????????
// Mock additional services and hooks that might be imported
vi.mock('@renderer/hooks/useVertexAI', () => ({
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
getVertexAIProjectId: vi.fn().mockReturnValue('test-project'),
getVertexAIServiceAccount: vi.fn().mockReturnValue({
privateKey: 'test-key',
clientEmail: 'test@example.com'
}),
isVertexAIConfigured: vi.fn().mockReturnValue(true),
isVertexProvider: vi.fn().mockReturnValue(true)
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn().mockReturnValue({}),
useSettings: vi.fn().mockReturnValue([{}, vi.fn()])
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
vi.mock('@renderer/utils/abortController', () => ({
addAbortController: vi.fn(),
removeAbortController: vi.fn()
}))
vi.mock('@anthropic-ai/sdk', () => ({
default: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@anthropic-ai/vertex-sdk', () => ({
default: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('openai', () => ({
default: vi.fn().mockImplementation(() => ({})),
AzureOpenAI: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@google-cloud/vertexai', () => ({
VertexAI: vi.fn().mockImplementation(() => ({}))
}))
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => {
const MockAnthropicVertexClient = vi.fn()
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
return {
AnthropicVertexClient: MockAnthropicVertexClient
}
})
// Helper to create test provider
const createTestProvider = (id: string, type: string): Provider => ({
id,
type: type as Provider['type'],
name: 'Test Provider',
apiKey: 'test-key',
apiHost: 'https://api.test.com',
models: []
})
// Helper to create test model
const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({
id,
name: 'Test Model',
provider: provider || 'test',
type: [],
group: 'test',
endpoint_type: endpointType as EndpointType
})
describe('Client Compatibility Types', () => {
let openaiProvider: Provider
let anthropicProvider: Provider
let geminiProvider: Provider
let azureProvider: Provider
let aihubmixProvider: Provider
let newApiProvider: Provider
let vertexProvider: Provider
beforeEach(() => {
vi.clearAllMocks()
openaiProvider = createTestProvider('openai', 'openai')
anthropicProvider = createTestProvider('anthropic', 'anthropic')
geminiProvider = createTestProvider('gemini', 'gemini')
azureProvider = createTestProvider('azure-openai', 'azure-openai')
aihubmixProvider = createTestProvider('aihubmix', 'openai')
newApiProvider = createTestProvider('new-api', 'openai')
vertexProvider = createTestProvider('vertex', 'vertexai')
})
describe('Direct API Clients', () => {
it('should return correct compatibility type for OpenAIAPIClient', () => {
const client = new OpenAIAPIClient(openaiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
})
it('should return correct compatibility type for AnthropicAPIClient', () => {
const client = new AnthropicAPIClient(anthropicProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['AnthropicAPIClient'])
})
it('should return correct compatibility type for GeminiAPIClient', () => {
const client = new GeminiAPIClient(geminiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['GeminiAPIClient'])
})
})
describe('Decorator Pattern API Clients', () => {
it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => {
const client = new OpenAIResponseAPIClient(azureProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => {
const client = new OpenAIResponseAPIClient(azureProvider)
const testModel = createTestModel('gpt-4', 'azure-openai')
// Get the actual client selected for this model
const actualClient = client.getClient(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return OpenAIResponseAPIClient for non-chat-completion-only models
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
})
it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => {
const client = new AihubmixAPIClient(aihubmixProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['AihubmixAPIClient'])
})
it('should delegate to underlying client for AihubmixAPIClient with model', () => {
const client = new AihubmixAPIClient(aihubmixProvider)
const testModel = createTestModel('gpt-4', 'openai')
// Get the actual client selected for this model
const actualClient = client.getClientForModel(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix)
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should return NewAPIClient for NewAPIClient without model', () => {
const client = new NewAPIClient(newApiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['NewAPIClient'])
})
it('should delegate to underlying client for NewAPIClient with model', () => {
const client = new NewAPIClient(newApiProvider)
const testModel = createTestModel('gpt-4', 'openai', 'openai-response')
// Get the actual client selected for this model
const actualClient = client.getClientForModel(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should return VertexAPIClient for VertexAPIClient without model', () => {
const client = new VertexAPIClient(vertexProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['VertexAPIClient'])
})
it('should delegate to underlying client for VertexAPIClient with model', () => {
const client = new VertexAPIClient(vertexProvider)
const testModel = createTestModel('claude-3-5-sonnet', 'vertexai')
// Get the actual client selected for this model
const actualClient = client.getClient(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient)
expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient'])
})
})
describe('Middleware Compatibility Logic', () => {
it('should correctly identify OpenAI compatible clients', () => {
const openaiClient = new OpenAIAPIClient(openaiProvider)
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
const openaiTypes = openaiClient.getClientCompatibilityType()
const responseTypes = openaiResponseClient.getClientCompatibilityType()
// Test the logic from completions method line 94
const isOpenAICompatible = (types: string[]) =>
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isOpenAICompatible(openaiTypes)).toBe(true)
expect(isOpenAICompatible(responseTypes)).toBe(true)
})
it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => {
const anthropicClient = new AnthropicAPIClient(anthropicProvider)
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
const openaiClient = new OpenAIAPIClient(openaiProvider)
const anthropicTypes = anthropicClient.getClientCompatibilityType()
const responseTypes = openaiResponseClient.getClientCompatibilityType()
const openaiTypes = openaiClient.getClientCompatibilityType()
// Test the logic from completions method line 101
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true)
expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true)
expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false)
})
it('should handle non-compatible clients correctly', () => {
const geminiClient = new GeminiAPIClient(geminiProvider)
const geminiTypes = geminiClient.getClientCompatibilityType()
// Test that Gemini is not OpenAI compatible
const isOpenAICompatible = (types: string[]) =>
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
// Test that Gemini is not Anthropic/OpenAIResponse compatible
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isOpenAICompatible(geminiTypes)).toBe(false)
expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false)
})
})
describe('Factory Integration', () => {
it('should return correct compatibility types for factory-created clients', () => {
const testCases = [
{ provider: openaiProvider, expectedType: 'OpenAIAPIClient' },
{ provider: anthropicProvider, expectedType: 'AnthropicAPIClient' },
{ provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' },
{ provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' },
{ provider: newApiProvider, expectedType: 'NewAPIClient' },
{ provider: vertexProvider, expectedType: 'VertexAPIClient' }
]
testCases.forEach(({ provider, expectedType }) => {
const client = ApiClientFactory.create(provider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toContain(expectedType)
})
})
})
})

View File

@@ -1,96 +0,0 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
import type { BaseApiClient } from '../BaseApiClient'
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
import { MixedBaseAPIClient } from '../MixedBaseApiClient'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
/**
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
* 使用装饰器模式实现在ApiClient层面进行模型路由
*/
export class AihubmixAPIClient extends MixedBaseAPIClient {
// 使用联合类型而不是any保持类型安全
protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
new Map()
protected defaultClient: OpenAIAPIClient
protected currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
const providerExtraHeaders = {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
// 初始化各个client - 现在有类型安全
const claudeClient = new AnthropicAPIClient(providerExtraHeaders)
const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' })
const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders)
const defaultClient = new OpenAIAPIClient(providerExtraHeaders)
this.clients.set('claude', claudeClient)
this.clients.set('gemini', geminiClient)
this.clients.set('openai', openaiClient)
this.clients.set('default', defaultClient)
// 设置默认client
this.defaultClient = defaultClient
this.currentClient = this.defaultClient as BaseApiClient
}
override getBaseURL(): string {
if (!this.currentClient) {
return this.provider.apiHost
}
return this.currentClient.getBaseURL()
}
/**
* 根据模型获取合适的client
*/
protected getClient(model: Model): BaseApiClient {
const id = model.id.toLowerCase()
// claude开头
if (id.startsWith('claude')) {
const client = this.clients.get('claude')
if (!client || !this.isValidClient(client)) {
throw new Error('Claude client not properly initialized')
}
return client
}
// gemini开头 且不以-nothink、-search结尾
if (
(id.startsWith('gemini') || id.startsWith('imagen')) &&
!id.endsWith('-nothink') &&
!id.endsWith('-search') &&
!id.includes('embedding')
) {
const client = this.clients.get('gemini')
if (!client || !this.isValidClient(client)) {
throw new Error('Gemini client not properly initialized')
}
return client
}
// OpenAI系列模型 不包含gpt-oss
if (isOpenAILLMModel(model) && !model.id.includes('gpt-oss')) {
const client = this.clients.get('openai')
if (!client || !this.isValidClient(client)) {
throw new Error('OpenAI client not properly initialized')
}
return client
}
return this.defaultClient as BaseApiClient
}
}

View File

@@ -1,788 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import type {
Base64ImageSource,
ImageBlockParam,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUseBlock,
WebSearchTool20250305
} from '@anthropic-ai/sdk/resources'
import type {
ContentBlock,
ContentBlockParam,
MessageCreateParamsBase,
RedactedThinkingBlockParam,
ServerToolUseBlockParam,
ThinkingBlockParam,
ThinkingConfigParam,
ToolUnion,
ToolUseBlockParam,
WebSearchResultBlock,
WebSearchToolResultBlockParam,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import type AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
Assistant,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import { EFFORT_RATIO, FILE_TYPE, WEB_SEARCH_SOURCE } from '@renderer/types'
import type {
ErrorChunk,
LLMWebSearchCompleteChunk,
LLMWebSearchInProgressChunk,
MCPToolCreatedChunk,
TextDeltaChunk,
TextStartChunk,
ThinkingDeltaChunk,
ThinkingStartChunk
} from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { type Message } from '@renderer/types/newMessage'
import type {
AnthropicSdkMessageParam,
AnthropicSdkParams,
AnthropicSdkRawChunk,
AnthropicSdkRawOutput
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isSupportedToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import type { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
const logger = loggerService.withContext('AnthropicAPIClient')
export class AnthropicAPIClient extends BaseApiClient<
Anthropic | AnthropicVertex,
AnthropicSdkParams,
AnthropicSdkRawOutput,
AnthropicSdkRawChunk,
AnthropicSdkMessageParam,
ToolUseBlock,
ToolUnion
> {
oauthToken: string | undefined = undefined
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
constructor(provider: Provider) {
super(provider)
}
async getSdkInstance(): Promise<Anthropic | AnthropicVertex> {
if (this.sdkInstance) {
return this.sdkInstance
}
if (this.provider.authType === 'oauth') {
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
}
this.sdkInstance = getSdkClient(this.provider, this.oauthToken)
return this.sdkInstance
}
override async createCompletions(
payload: AnthropicSdkParams,
options?: Anthropic.RequestOptions
): Promise<AnthropicSdkRawOutput> {
if (this.provider.authType === 'oauth') {
payload.system = buildClaudeCodeSystemMessage(payload.system)
}
const sdk = (await this.getSdkInstance()) as Anthropic
if (payload.stream) {
return sdk.messages.stream(payload, options)
}
return sdk.messages.create(payload, options)
}
// @ts-ignore sdk未提供
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
return []
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
const sdk = (await this.getSdkInstance()) as Anthropic
// prevent auto appended /v1. It's included in baseUrl.
const response = await sdk.models.list({ path: '/models' })
return response.data
}
// @ts-ignore sdk未提供
override async getEmbeddingDimensions(): Promise<number> {
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return super.getTemperature(assistant, model)
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return super.getTopP(assistant, model)
}
/**
* Get the reasoning effort
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
private static isValidBase64ImageMediaType(mime: string): mime is Base64ImageSource['media_type'] {
return ['image/jpeg', 'image/png', 'image/gif', 'image/webp'].includes(mime)
}
/**
* Get the message parameter
* @param message - The message
* @returns The message parameter
*/
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
const { textContent, imageContents } = await this.getMessageContent(message)
const parts: MessageParam['content'] = [
{
type: 'text',
text: textContent
}
]
if (imageContents.length > 0) {
for (const imageContent of imageContents) {
const base64Data = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
base64Data.mime = base64Data.mime.replace('jpg', 'jpeg')
if (AnthropicAPIClient.isValidBase64ImageMediaType(base64Data.mime)) {
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime,
type: 'base64'
}
})
} else {
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
}
}
}
// Get and process image blocks
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
// Handle uploaded file
const file = imageBlock.file
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
}
// Get and process file blocks
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const { file } = fileBlock
if ([FILE_TYPE.TEXT, FILE_TYPE.DOCUMENT].some((type) => file.type === type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file)
parts.push({
type: 'document',
source: {
type: 'base64',
media_type: 'application/pdf',
data: base64Data
}
})
} else {
const fileContent = await (await window.api.file.read(file.id + file.ext, true)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
return mcpToolsToAnthropicTools(mcpTools)
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): AnthropicSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: mcpToolResponse.toolCallId!,
content: resp.content
.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text || ''
} satisfies TextBlockParam
}
if (item.type === 'image') {
return {
type: 'image',
source: {
data: item.data || '',
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
type: 'base64'
}
} satisfies ImageBlockParam
}
return
})
.filter((n) => typeof n !== 'undefined'),
is_error: resp.isError
} satisfies ToolResultBlockParam
]
}
}
return
}
// Implementing abstract methods from BaseApiClient
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
// This might need adjustment based on how tool calls are specifically handled in the new structure
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
return mcpTool
}
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: toolCall.input as Record<string, unknown>,
status: 'pending'
} as ToolCallResponse
}
override buildSdkMessages(
currentReqMessages: AnthropicSdkMessageParam[],
output: Anthropic.Message,
toolResults: AnthropicSdkMessageParam[]
): AnthropicSdkMessageParam[] {
const assistantMessage: AnthropicSdkMessageParam = {
role: output.role,
content: convertContentBlocksToParams(output.content)
}
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
if (toolResults && toolResults.length > 0) {
newMessages.push(...toolResults)
}
return newMessages
}
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
if (typeof message.content === 'string') {
return estimateTextTokens(message.content)
}
return message.content
.map((content) => {
switch (content.type) {
case 'text':
return estimateTextTokens(content.text)
case 'image':
if (content.source.type === 'base64') {
return estimateTextTokens(content.source.data)
} else {
return estimateTextTokens(content.source.url)
}
case 'tool_use':
return estimateTextTokens(JSON.stringify(content.input))
case 'tool_result':
return estimateTextTokens(JSON.stringify(content.content))
default:
return 0
}
})
.reduce((acc, curr) => acc + curr, 0)
}
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
const messageParam: AnthropicSdkMessageParam = {
role: message.role,
content: convertContentBlocksToParams(message.content)
}
return messageParam
}
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
return sdkPayload.messages || []
}
/**
* Anthropic专用的原始流监听器
* 处理MessageStream对象的特定事件
*/
attachRawStreamListener(
rawOutput: AnthropicSdkRawOutput,
listener: RawStreamListener<AnthropicSdkRawChunk>
): AnthropicSdkRawOutput {
logger.debug(`Attaching stream listener to raw output`)
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
// 检查是否为MessageStream
if (rawOutput instanceof MessageStream) {
logger.debug(`Detected Anthropic MessageStream, attaching specialized listener`)
if (listener.onStart) {
listener.onStart()
}
if (listener.onChunk) {
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
listener.onChunk!(event)
})
}
if (anthropicListener.onContentBlock) {
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
}
if (anthropicListener.onMessage) {
rawOutput.on('finalMessage', anthropicListener.onMessage)
}
if (listener.onEnd) {
rawOutput.on('end', () => {
listener.onEnd!()
})
}
if (listener.onError) {
rawOutput.on('error', (error: Error) => {
listener.onError!(error)
})
}
return rawOutput
}
if (anthropicListener.onMessage) {
anthropicListener.onMessage(rawOutput)
}
// 对于非MessageStream响应
return rawOutput
}
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
if (!isWebSearchModel(model)) {
return undefined
}
return {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
} as WebSearchTool20250305
}
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: AnthropicSdkParams
messages: AnthropicSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. 处理系统消息
const systemPrompt = assistant.prompt
// 2. 设置工具
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isSupportedToolUse(assistant)
})
const systemMessage: TextBlockParam | undefined = systemPrompt
? { type: 'text', text: systemPrompt }
: undefined
// 3. 处理用户消息
const sdkMessages: AnthropicSdkMessageParam[] = []
if (typeof messages === 'string') {
sdkMessages.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
sdkMessages.push(await this.convertMessageToSdkParam(message))
}
}
if (enableWebSearch) {
const webSearchTool = await this.getWebSearchParams(model)
if (webSearchTool) {
tools.push(webSearchTool)
}
}
const commonParams: MessageCreateParamsBase = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: sdkMessages,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
system: systemMessage ? [systemMessage] : undefined,
thinking: this.getBudgetToken(assistant, model),
tools: tools.length > 0 ? tools : undefined,
stream: streamOutput,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
}
const timeout = this.getTimeout(model)
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
return () => {
let accumulatedJson = ''
const toolCalls: Record<number, ToolUseBlock> = {}
return {
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
if (typeof rawChunk === 'string') {
try {
rawChunk = JSON.parse(rawChunk)
} catch (error) {
logger.error('invalid chunk', { rawChunk, error })
throw new Error(t('error.chat.chunk.non_json'))
}
}
switch (rawChunk.type) {
case 'message': {
let i = 0
let hasTextContent = false
let hasThinkingContent = false
for (const content of rawChunk.content) {
switch (content.type) {
case 'text': {
if (!hasTextContent) {
controller.enqueue({
type: ChunkType.TEXT_START
} as TextStartChunk)
hasTextContent = true
}
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: content.text
} as TextDeltaChunk)
break
}
case 'tool_use': {
toolCalls[i] = content
i++
break
}
case 'thinking': {
if (!hasThinkingContent) {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
hasThinkingContent = true
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: content.thinking
} as ThinkingDeltaChunk)
break
}
case 'web_search_tool_result': {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: content.content,
source: WEB_SEARCH_SOURCE.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
break
}
}
}
if (i > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: Object.values(toolCalls)
} as MCPToolCreatedChunk)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
break
}
case 'content_block_start': {
const contentBlock = rawChunk.content_block
switch (contentBlock.type) {
case 'server_tool_use': {
if (contentBlock.name === 'web_search') {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
} as LLMWebSearchInProgressChunk)
}
break
}
case 'web_search_tool_result': {
if (
contentBlock.content &&
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
) {
controller.enqueue({
type: ChunkType.ERROR,
error: {
code: (contentBlock.content as WebSearchToolResultError).error_code,
message: (contentBlock.content as WebSearchToolResultError).error_code
}
} as ErrorChunk)
} else {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: contentBlock.content as Array<WebSearchResultBlock>,
source: WEB_SEARCH_SOURCE.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
}
break
}
case 'tool_use': {
toolCalls[rawChunk.index] = contentBlock
break
}
case 'text': {
controller.enqueue({
type: ChunkType.TEXT_START
} as TextStartChunk)
break
}
case 'thinking':
case 'redacted_thinking': {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
break
}
}
break
}
case 'content_block_delta': {
const messageDelta = rawChunk.delta
switch (messageDelta.type) {
case 'text_delta': {
if (messageDelta.text) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: messageDelta.text
} as TextDeltaChunk)
}
break
}
case 'thinking_delta': {
if (messageDelta.thinking) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: messageDelta.thinking
} as ThinkingDeltaChunk)
}
break
}
case 'input_json_delta': {
if (messageDelta.partial_json) {
accumulatedJson += messageDelta.partial_json
}
break
}
}
break
}
case 'content_block_stop': {
const toolCall = toolCalls[rawChunk.index]
if (toolCall) {
try {
toolCall.input = accumulatedJson ? JSON.parse(accumulatedJson) : {}
logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [toolCall]
} as MCPToolCreatedChunk)
} catch (error) {
logger.error('Error parsing tool call input:', error as Error)
}
}
break
}
case 'message_delta': {
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
}
}
}
}
}
}
}
/**
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
* 去除服务器生成的额外字段只保留发送给API所需的字段
*/
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
return contentBlocks.map((block): ContentBlockParam => {
switch (block.type) {
case 'text':
// TextBlock -> TextBlockParam去除 citations 等服务器字段
return {
type: 'text',
text: block.text
} satisfies TextBlockParam
case 'tool_use':
// ToolUseBlock -> ToolUseBlockParam
return {
type: 'tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ToolUseBlockParam
case 'thinking':
// ThinkingBlock -> ThinkingBlockParam
return {
type: 'thinking',
thinking: block.thinking,
signature: block.signature
} satisfies ThinkingBlockParam
case 'redacted_thinking':
// RedactedThinkingBlock -> RedactedThinkingBlockParam
return {
type: 'redacted_thinking',
data: block.data
} satisfies RedactedThinkingBlockParam
case 'server_tool_use':
// ServerToolUseBlock -> ServerToolUseBlockParam
return {
type: 'server_tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ServerToolUseBlockParam
case 'web_search_tool_result':
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
return {
type: 'web_search_tool_result',
tool_use_id: block.tool_use_id,
content: block.content
} satisfies WebSearchToolResultBlockParam
default:
return block as ContentBlockParam
}
})
}

View File

@@ -1,104 +0,0 @@
import type Anthropic from '@anthropic-ai/sdk'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import type { Provider } from '@renderer/types'
import { isEmpty } from 'lodash'
import { AnthropicAPIClient } from './AnthropicAPIClient'
const logger = loggerService.withContext('AnthropicVertexClient')
export class AnthropicVertexClient extends AnthropicAPIClient {
sdkInstance: AnthropicVertex | undefined = undefined
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
constructor(provider: Provider) {
super(provider)
}
private formatApiHost(host: string): string {
const forceUseOriginalHost = () => {
return host.endsWith('/')
}
if (!host) {
return host
}
return forceUseOriginalHost() ? host : `${host}/v1/`
}
override getBaseURL() {
return this.formatApiHost(this.provider.apiHost)
}
override async getSdkInstance(): Promise<AnthropicVertex> {
if (this.sdkInstance) {
return this.sdkInstance
}
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const location = getVertexAILocation()
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
throw new Error('Vertex AI settings are not configured')
}
const authHeaders = await this.getServiceAccountAuthHeaders()
this.sdkInstance = new AnthropicVertex({
projectId: projectId,
region: location,
dangerouslyAllowBrowser: true,
defaultHeaders: authHeaders,
baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
})
return this.sdkInstance
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
throw new Error('Vertex AI does not support listModels method.')
}
/**
* 获取认证头,如果配置了 service account 则从主进程获取
*/
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
// 检查是否配置了 service account
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
return undefined
}
// 检查是否已有有效的认证头(提前 5 分钟过期)
const now = Date.now()
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
return this.authHeaders
}
try {
// 从主进程获取认证头
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
projectId,
serviceAccount: {
privateKey: serviceAccount.privateKey,
clientEmail: serviceAccount.clientEmail
}
})
// 设置过期时间(通常认证头有效期为 1 小时)
this.authHeadersExpiry = now + 60 * 60 * 1000
return this.authHeaders
} catch (error: any) {
logger.error('Failed to get auth headers:', error)
throw new Error(`Service Account authentication failed: ${error.message}`)
}
}
}

View File

@@ -1,51 +0,0 @@
import type OpenAI from '@cherrystudio/openai'
import type { Provider } from '@renderer/types'
import type { OpenAISdkParams, OpenAISdkRawOutput } from '@renderer/types/sdk'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
export class CherryAiAPIClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(
payload: OpenAISdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAISdkRawOutput> {
const sdk = await this.getSdkInstance()
options = options || {}
options.headers = options.headers || {}
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: payload
})
options.headers = {
...options.headers,
...signature
}
// @ts-ignore - SDK参数可能有额外的字段
return await sdk.chat.completions.create(payload, options)
}
override getClientCompatibilityType(): string[] {
return ['CherryAiAPIClient']
}
public async listModels(): Promise<OpenAI.Models.Model[]> {
const models = ['Qwen/Qwen3-8B', 'Qwen/Qwen3-Next-80B-A3B-Instruct']
const created = Date.now()
return models.map((id) => ({
id,
owned_by: 'cherryai',
object: 'model' as const,
created
}))
}
}

Some files were not shown because too many files have changed in this diff Show More