mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-07-03 12:27:41 +08:00
refactor: migrate to ai sdk v6 Phase 3 (#12235)
Continued from #12227 ## Summary Phase 3 of AI SDK v6 migration: - **Image generation**: Migrate to native AI SDK `generateImage`/`editImage`, remove legacy image middleware - **Embedding**: Migrate to AI SDK `embedMany`, remove legacy embedding clients - **Model listing**: Refactor `ModelListService` to Strategy Registry pattern (`listModels.ts`), consolidate 7 schema files into one `schemas.ts` - **OpenRouter image**: Bump `@openrouter/ai-sdk-provider` to 2.3.3 with native image endpoint support, remove `isNativeImageGenerationProvider` guard - **GitHub Copilot**: Simplify extension by removing `ProviderV2` cast and `wrapProvider` - **Legacy removal**: Delete all legacy API clients, middleware pipeline, and barrel `index.ts` - **Rename**: `index_new.ts` → `AiProvider.ts`, `ModelListService.ts` → `listModels.ts` ## Manual Test Plan ### P0 — Core paths + PR change focus #### 1. Core Chat - [ ] OpenAI model (e.g. GPT-4o): send text, verify streaming output - [ ] Anthropic model (e.g. Claude Sonnet): verify chat works - [ ] Gemini model: verify chat works - [ ] DeepSeek model: verify chat works - [ ] Reasoning mode (o3-mini, DeepSeek R1): verify thinking process displays - [ ] Send message with image (multimodal): verify model recognizes image - [ ] Abort mid-stream: verify clean termination, no errors #### 2. Image Generation (key change area) - [ ] DALL-E 3 / GPT-Image-1: verify image generation works - [ ] OpenRouter image model: verify **native image endpoint** is used (no longer chat completions) - [ ] Image editing: upload image + text prompt, verify editImage path - [ ] Verify generated images display correctly (both base64 and URL formats) #### 3. Model Listing (refactored area) - [ ] Sync OpenAI models: verify names and groups - [ ] Sync Gemini models: verify `models/` prefix stripped - [ ] Sync Ollama models (local): verify display - [ ] Sync OpenRouter models: verify chat + embedding models both appear - [ ] Sync SiliconFlow models: verify grouping (e.g. `deepseek-ai/`) - [ ] Sync GitHub Models - [ ] Sync Together models - [ ] Sync NewAPI service (e.g. CherryIn) - [ ] Sync PPIO models: verify chat + embedding + reranker endpoints merged ### P1 — Affected by refactor #### 4. Provider Extensions - [ ] GitHub Copilot: verify chat works (simplified wrapProvider logic) - [ ] Ollama: verify chat works - [ ] Vertex AI / Bedrock: verify chat works (if configured) #### 5. MCP Tool Calling - [ ] Enable MCP server, send tool-requiring message, verify tool execution - [ ] Verify both prompt tool use and function calling modes #### 6. Auxiliary Functions - [ ] Auto topic title generation (fetchMessagesSummary) - [ ] Note summary (fetchNoteSummary) - [ ] Translation/generation (fetchGenerate) - [ ] API check (click "Check" button in settings) ### P2 — Indirect impact #### 7. Knowledge Base - [ ] Create knowledge base, associate with assistant, verify RAG retrieval and embedding #### 8. Web Search - [ ] Enable native web search (Gemini/Perplexity), verify search results in response #### 9. Tracing - [ ] Enable developer mode, send message, verify trace spans recorded and displayed --------- Signed-off-by: suyao <sy20010504@gmail.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com> Co-authored-by: icarus <eurfelux@gmail.com> Co-authored-by: fullex <106392080+0xfullex@users.noreply.github.com>
This commit is contained in:
11
.changeset/clean-aicore-exports.md
Normal file
11
.changeset/clean-aicore-exports.md
Normal file
@@ -0,0 +1,11 @@
|
||||
---
|
||||
'@cherrystudio/ai-core': minor
|
||||
---
|
||||
|
||||
Remove unused exports, dead types, and over-engineered abstractions from aiCore
|
||||
|
||||
- Remove unused public exports: `createOpenAICompatibleExecutor`, `create*Options`, `mergeProviderOptions`, `PluginManager`, `createContext`, `AI_CORE_VERSION`, `AI_CORE_NAME`, `BUILT_IN_PLUGIN_PREFIX`, `registeredProviderIds`, `ProviderInitializationError`, `ProviderExtensionBuilder`, `createProviderExtension`
|
||||
- Delete dead type definitions: `HookResult`, `PluginManagerConfig`, `AiRequestMetadata`, `ExtractProviderOptions`, `ProviderOptions`, `CoreProviderSettingsMap` (re-added as internal), `ExtractExtensionIds`, `ExtractExtensionSettings`
|
||||
- Remove over-engineered `ExtensionStorage` system: delete `ExtensionStorage`, `StorageAccessor`, `ExtensionContext`, `ExtensionHook`, `LifecycleHooks` types; remove `TStorage` generic parameter from `ProviderExtension` (4 → 3 type params); remove `_storage`, `storage` getter, `createContext`, `executeHook`, `initialStorage`, `hooks` from class and config
|
||||
- Delete `create*Options` convenience functions and inline `createOpenRouterOptions` at its only call site
|
||||
- Delete `DEFAULT_WEB_SEARCH_CONFIG` and plugins `README.md`
|
||||
1749
docs/en/guides/ai-core-architecture.md
Normal file
1749
docs/en/guides/ai-core-architecture.md
Normal file
File diff suppressed because it is too large
Load Diff
2288
docs/zh/guides/ai-core-architecture.md
Normal file
2288
docs/zh/guides/ai-core-architecture.md
Normal file
File diff suppressed because it is too large
Load Diff
47
package.json
47
package.json
@@ -38,7 +38,7 @@
|
||||
"agents:drop": "NODE_ENV='development' drizzle-kit drop --config src/main/services/agents/drizzle.config.ts",
|
||||
"analyze:renderer": "VISUALIZER_RENDERER=true pnpm build",
|
||||
"analyze:main": "VISUALIZER_MAIN=true pnpm build",
|
||||
"typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"",
|
||||
"typecheck": "pnpm --filter @cherrystudio/ai-sdk-provider build && concurrently -n \"node,web,aicore\" -c \"cyan,magenta,yellow\" \"npm run typecheck:node\" \"npm run typecheck:web\" \"pnpm --filter @cherrystudio/ai-core typecheck\"",
|
||||
"typecheck:node": "tsgo --noEmit -p tsconfig.node.json --composite false",
|
||||
"typecheck:web": "tsgo --noEmit -p tsconfig.web.json --composite false",
|
||||
"i18n:check": "dotenv -e .env -- tsx scripts/check-i18n.ts",
|
||||
@@ -110,28 +110,28 @@
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^4.0.67",
|
||||
"@ai-sdk/anthropic": "^3.0.48",
|
||||
"@ai-sdk/azure": "^3.0.37",
|
||||
"@ai-sdk/cerebras": "^2.0.34",
|
||||
"@ai-sdk/gateway": "^3.0.57",
|
||||
"@ai-sdk/google": "^3.0.33",
|
||||
"@ai-sdk/google-vertex": "^4.0.66",
|
||||
"@ai-sdk/huggingface": "^1.0.32",
|
||||
"@ai-sdk/mistral": "^3.0.20",
|
||||
"@ai-sdk/openai": "^3.0.36",
|
||||
"@ai-sdk/perplexity": "^3.0.19",
|
||||
"@ai-sdk/amazon-bedrock": "^4.0.77",
|
||||
"@ai-sdk/anthropic": "^3.0.58",
|
||||
"@ai-sdk/azure": "^3.0.42",
|
||||
"@ai-sdk/cerebras": "^2.0.39",
|
||||
"@ai-sdk/cohere": "^3.0.25",
|
||||
"@ai-sdk/gateway": "^3.0.66",
|
||||
"@ai-sdk/google": "^3.0.43",
|
||||
"@ai-sdk/google-vertex": "^4.0.80",
|
||||
"@ai-sdk/huggingface": "^1.0.37",
|
||||
"@ai-sdk/mistral": "^3.0.24",
|
||||
"@ai-sdk/openai": "^3.0.41",
|
||||
"@ai-sdk/openai-compatible": "^2.0.35",
|
||||
"@ai-sdk/perplexity": "^3.0.23",
|
||||
"@ai-sdk/provider": "^3.0.8",
|
||||
"@ai-sdk/provider-utils": "^4.0.15",
|
||||
"@ai-sdk/provider-utils": "^4.0.19",
|
||||
"@ai-sdk/test-server": "^1.0.3",
|
||||
"@ai-sdk/xai": "^3.0.59",
|
||||
"@ai-sdk/togetherai": "^2.0.39",
|
||||
"@ai-sdk/xai": "^3.0.67",
|
||||
"@ant-design/cssinjs": "1.23.0",
|
||||
"@ant-design/icons": "5.6.1",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "0.11.4",
|
||||
"@aws-sdk/client-bedrock": "^3.998.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.998.0",
|
||||
"@aws-sdk/client-s3": "^3.998.0",
|
||||
"@biomejs/biome": "2.2.4",
|
||||
"@changesets/changelog-github": "^0.5.2",
|
||||
@@ -171,7 +171,7 @@
|
||||
"@eslint-react/eslint-plugin": "^1.36.1",
|
||||
"@eslint/js": "^9.22.0",
|
||||
"@floating-ui/dom": "1.7.3",
|
||||
"@google/genai": "1.0.1",
|
||||
"@google/genai": "^1.46.0",
|
||||
"@hello-pangea/dnd": "^18.0.1",
|
||||
"@iconify-json/material-icon-theme": "^1.2.56",
|
||||
"@iconify/react": "^6.0.2",
|
||||
@@ -185,7 +185,7 @@
|
||||
"@modelcontextprotocol/sdk": "1.27.1",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^2.2.3",
|
||||
"@openrouter/ai-sdk-provider": "^2.3.3",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/context-async-hooks": "2.0.1",
|
||||
"@opentelemetry/core": "2.0.0",
|
||||
@@ -273,7 +273,7 @@
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"adm-zip": "0.4.16",
|
||||
"ai": "^6.0.103",
|
||||
"ai": "^6.0.116",
|
||||
"ansi-to-react": "^6.2.6",
|
||||
"antd": "5.27.0",
|
||||
"archiver": "^7.0.1",
|
||||
@@ -430,6 +430,7 @@
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "npm:rolldown-vite@7.3.0",
|
||||
"vitest": "^3.2.4",
|
||||
"voyage-ai-provider": "^3.0.0",
|
||||
"webdav": "^5.9.0",
|
||||
"winston": "^3.17.0",
|
||||
"winston-daily-rotate-file": "^5.0.0",
|
||||
@@ -473,8 +474,6 @@
|
||||
"patchedDependencies": {
|
||||
"@napi-rs/system-ocr@1.0.2": "patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||
"tesseract.js@6.0.1": "patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
|
||||
"@anthropic-ai/vertex-sdk@0.11.4": "patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
"@google/genai@1.0.1": "patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
|
||||
"@langchain/core@1.0.2": "patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
|
||||
"@langchain/openai@1.0.0": "patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||
"@tiptap/extension-drag-handle@3.2.0": "patches/@tiptap-extension-drag-handle-npm-3.2.0-5a9ebff7c9.patch",
|
||||
@@ -484,9 +483,9 @@
|
||||
"file-stream-rotator@0.6.1": "patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
|
||||
"libsql@0.4.7": "patches/libsql-npm-0.4.7-444e260fb1.patch",
|
||||
"pdf-parse@1.1.1": "patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
||||
"@ai-sdk/openai-compatible@2.0.30": "patches/@ai-sdk__openai-compatible@2.0.30.patch",
|
||||
"@openrouter/ai-sdk-provider": "patches/@openrouter__ai-sdk-provider.patch",
|
||||
"ollama-ai-provider-v2@3.3.1": "patches/ollama-ai-provider-v2@3.3.1.patch",
|
||||
"@ai-sdk/openai-compatible@2.0.35": "patches/@ai-sdk__openai-compatible@2.0.35.patch",
|
||||
"@openrouter/ai-sdk-provider": "patches/@openrouter__ai-sdk-provider.patch",
|
||||
"@opeoginni/github-copilot-openai-compatible@1.0.0": "patches/@opeoginni__github-copilot-openai-compatible@1.0.0.patch"
|
||||
},
|
||||
"onlyBuiltDependencies": [
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
},
|
||||
"type": "module",
|
||||
"main": "dist/index.cjs",
|
||||
"module": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"module": "dist/index.mjs",
|
||||
"types": "dist/index.d.mts",
|
||||
"files": ["dist"],
|
||||
"scripts": {
|
||||
"build": "tsdown",
|
||||
@@ -28,14 +28,14 @@
|
||||
"prepublishOnly": "pnpm build"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@ai-sdk/anthropic": "^3.0.48",
|
||||
"@ai-sdk/google": "^3.0.33",
|
||||
"@ai-sdk/openai": "^3.0.36",
|
||||
"@ai-sdk/openai-compatible": "^2.0.30"
|
||||
"@ai-sdk/anthropic": "^3.0.58",
|
||||
"@ai-sdk/google": "^3.0.43",
|
||||
"@ai-sdk/openai": "^3.0.41",
|
||||
"@ai-sdk/openai-compatible": "^2.0.35"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/provider": "^3.0.8",
|
||||
"@ai-sdk/provider-utils": "^4.0.15"
|
||||
"@ai-sdk/provider-utils": "^4.0.19"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.20.3",
|
||||
@@ -48,10 +48,10 @@
|
||||
},
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js",
|
||||
"types": "./dist/index.d.mts",
|
||||
"import": "./dist/index.mjs",
|
||||
"require": "./dist/index.cjs",
|
||||
"default": "./dist/index.js"
|
||||
"default": "./dist/index.mjs"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,514 +0,0 @@
|
||||
# AI Core 基于 Vercel AI SDK 的技术架构
|
||||
|
||||
## 1. 架构设计理念
|
||||
|
||||
### 1.1 设计目标
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
|
||||
- **动态导入**:通过动态导入实现按需加载,减少打包体积
|
||||
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
|
||||
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
|
||||
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
|
||||
- **轻量级**:专注核心功能,保持包的轻量和高效
|
||||
- **包级独立**:作为独立包管理,便于复用和维护
|
||||
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
|
||||
|
||||
### 1.2 核心优势
|
||||
|
||||
- **标准化**:AI SDK 提供统一的模型接口,减少适配工作
|
||||
- **简化设计**:函数式API,避免过度抽象
|
||||
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
|
||||
- **性能优化**:AI SDK 内置优化和最佳实践
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **可扩展插件**:通用的流转换和参数处理插件系统
|
||||
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
|
||||
|
||||
## 2. 整体架构图
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph "用户应用 (如 Cherry Studio)"
|
||||
UI["用户界面"]
|
||||
Components["应用组件"]
|
||||
end
|
||||
|
||||
subgraph "packages/aiCore (AI Core 包)"
|
||||
subgraph "Runtime Layer (运行时层)"
|
||||
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
|
||||
PluginEngine["PluginEngine (插件引擎)"]
|
||||
RuntimeAPI["Runtime API (便捷函数)"]
|
||||
end
|
||||
|
||||
subgraph "Models Layer (模型层)"
|
||||
ModelFactory["createModel() (模型工厂)"]
|
||||
ProviderCreator["ProviderCreator (提供商创建器)"]
|
||||
end
|
||||
|
||||
subgraph "Core Systems (核心系统)"
|
||||
subgraph "Plugins (插件)"
|
||||
PluginManager["PluginManager (插件管理)"]
|
||||
BuiltInPlugins["Built-in Plugins (内置插件)"]
|
||||
StreamTransforms["Stream Transforms (流转换)"]
|
||||
end
|
||||
|
||||
subgraph "Middleware (中间件)"
|
||||
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
|
||||
end
|
||||
|
||||
subgraph "Providers (提供商)"
|
||||
Registry["Provider Registry (注册表)"]
|
||||
Factory["Provider Factory (工厂)"]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
subgraph "Vercel AI SDK"
|
||||
AICore["ai (核心库)"]
|
||||
OpenAI["@ai-sdk/openai"]
|
||||
Anthropic["@ai-sdk/anthropic"]
|
||||
Google["@ai-sdk/google"]
|
||||
XAI["@ai-sdk/xai"]
|
||||
Others["其他 19+ Providers"]
|
||||
end
|
||||
|
||||
subgraph "Future: OpenAI Agents SDK"
|
||||
AgentSDK["@openai/agents (未来集成)"]
|
||||
AgentExtensions["Agent Extensions (预留)"]
|
||||
end
|
||||
|
||||
UI --> RuntimeAPI
|
||||
Components --> RuntimeExecutor
|
||||
RuntimeAPI --> RuntimeExecutor
|
||||
RuntimeExecutor --> PluginEngine
|
||||
RuntimeExecutor --> ModelFactory
|
||||
PluginEngine --> PluginManager
|
||||
ModelFactory --> ProviderCreator
|
||||
ModelFactory --> MiddlewareWrapper
|
||||
ProviderCreator --> Registry
|
||||
Registry --> Factory
|
||||
Factory --> OpenAI
|
||||
Factory --> Anthropic
|
||||
Factory --> Google
|
||||
Factory --> XAI
|
||||
Factory --> Others
|
||||
|
||||
RuntimeExecutor --> AICore
|
||||
AICore --> streamText
|
||||
AICore --> generateText
|
||||
AICore --> streamObject
|
||||
AICore --> generateObject
|
||||
|
||||
PluginManager --> StreamTransforms
|
||||
PluginManager --> BuiltInPlugins
|
||||
|
||||
%% 未来集成路径
|
||||
RuntimeExecutor -.-> AgentSDK
|
||||
AgentSDK -.-> AgentExtensions
|
||||
```
|
||||
|
||||
## 3. 包结构设计
|
||||
|
||||
### 3.1 新架构文件结构
|
||||
|
||||
```
|
||||
packages/aiCore/
|
||||
├── src/
|
||||
│ ├── core/ # 核心层 - 内部实现
|
||||
│ │ ├── models/ # 模型层 - 模型创建和配置
|
||||
│ │ │ ├── factory.ts # 模型工厂函数 ✅
|
||||
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
|
||||
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
|
||||
│ │ │ ├── types.ts # 模型类型定义 ✅
|
||||
│ │ │ └── index.ts # 模型层导出 ✅
|
||||
│ │ ├── runtime/ # 运行时层 - 执行和用户API
|
||||
│ │ │ ├── executor.ts # 运行时执行器 ✅
|
||||
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
|
||||
│ │ │ ├── types.ts # 运行时类型定义 ✅
|
||||
│ │ │ └── index.ts # 运行时导出 ✅
|
||||
│ │ ├── middleware/ # 中间件系统
|
||||
│ │ │ ├── wrapper.ts # 模型包装器 ✅
|
||||
│ │ │ ├── manager.ts # 中间件管理器 ✅
|
||||
│ │ │ ├── types.ts # 中间件类型 ✅
|
||||
│ │ │ └── index.ts # 中间件导出 ✅
|
||||
│ │ ├── plugins/ # 插件系统
|
||||
│ │ │ ├── types.ts # 插件类型定义 ✅
|
||||
│ │ │ ├── manager.ts # 插件管理器 ✅
|
||||
│ │ │ ├── built-in/ # 内置插件 ✅
|
||||
│ │ │ │ ├── logging.ts # 日志插件 ✅
|
||||
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
|
||||
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
|
||||
│ │ │ │ └── index.ts # 内置插件导出 ✅
|
||||
│ │ │ ├── README.md # 插件文档 ✅
|
||||
│ │ │ └── index.ts # 插件导出 ✅
|
||||
│ │ ├── providers/ # 提供商管理
|
||||
│ │ │ ├── registry.ts # 提供商注册表 ✅
|
||||
│ │ │ ├── factory.ts # 提供商工厂 ✅
|
||||
│ │ │ ├── creator.ts # 提供商创建器 ✅
|
||||
│ │ │ ├── types.ts # 提供商类型 ✅
|
||||
│ │ │ ├── utils.ts # 工具函数 ✅
|
||||
│ │ │ └── index.ts # 提供商导出 ✅
|
||||
│ │ ├── options/ # 配置选项
|
||||
│ │ │ ├── factory.ts # 选项工厂 ✅
|
||||
│ │ │ ├── types.ts # 选项类型 ✅
|
||||
│ │ │ ├── xai.ts # xAI 选项 ✅
|
||||
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
|
||||
│ │ │ ├── examples.ts # 示例配置 ✅
|
||||
│ │ │ └── index.ts # 选项导出 ✅
|
||||
│ │ └── index.ts # 核心层导出 ✅
|
||||
│ ├── types.ts # 全局类型定义 ✅
|
||||
│ └── index.ts # 包主入口文件 ✅
|
||||
├── package.json # 包配置文件 ✅
|
||||
├── tsconfig.json # TypeScript 配置 ✅
|
||||
├── README.md # 包说明文档 ✅
|
||||
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
|
||||
```
|
||||
|
||||
## 4. 架构分层详解
|
||||
|
||||
### 4.1 Models Layer (模型层)
|
||||
|
||||
**职责**:统一的模型创建和配置管理
|
||||
|
||||
**核心文件**:
|
||||
|
||||
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
|
||||
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
|
||||
- `types.ts`: 模型配置类型定义
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 函数式设计,避免不必要的类抽象
|
||||
- 统一的模型配置接口
|
||||
- 自动处理中间件应用
|
||||
- 支持批量模型创建
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 模型配置接口
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
// 核心模型创建函数
|
||||
export async function createModel(config: ModelConfig): Promise<LanguageModel>
|
||||
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
|
||||
```
|
||||
|
||||
### 4.2 Runtime Layer (运行时层)
|
||||
|
||||
**职责**:运行时执行器和用户面向的API接口
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `executor.ts`: 运行时执行器类
|
||||
- `plugin-engine.ts`: 插件引擎(原PluginEnabledAiClient)
|
||||
- `index.ts`: 便捷函数和工厂方法
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 提供三种使用方式:类实例、静态工厂、函数式调用
|
||||
- 自动集成模型创建和插件处理
|
||||
- 完整的类型安全支持
|
||||
- 为 OpenAI Agents SDK 预留扩展接口
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 运行时执行器
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
static create<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T>
|
||||
|
||||
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
|
||||
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
|
||||
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
|
||||
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
|
||||
}
|
||||
|
||||
// 便捷函数式API
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<StreamTextResult>
|
||||
```
|
||||
|
||||
### 4.3 Plugin System (插件系统)
|
||||
|
||||
**职责**:可扩展的插件架构
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `PluginManager`: 插件生命周期管理
|
||||
- `built-in/`: 内置插件集合
|
||||
- 流转换收集和应用
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 借鉴 Rollup 的钩子分类设计
|
||||
- 支持流转换 (`experimental_transform`)
|
||||
- 内置常用插件(日志、计数等)
|
||||
- 完整的生命周期钩子
|
||||
|
||||
**插件接口**:
|
||||
|
||||
```typescript
|
||||
export interface AiPlugin {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理
|
||||
transformStream?: () => TransformStream
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4 Middleware System (中间件系统)
|
||||
|
||||
**职责**:AI SDK原生中间件支持
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `ModelWrapper.ts`: 模型包装函数
|
||||
|
||||
**设计哲学**:
|
||||
|
||||
- 直接使用AI SDK的 `wrapLanguageModel`
|
||||
- 与插件系统分离,职责明确
|
||||
- 函数式设计,简化使用
|
||||
|
||||
```typescript
|
||||
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
|
||||
```
|
||||
|
||||
### 4.5 Provider System (提供商系统)
|
||||
|
||||
**职责**:AI Provider注册表和动态导入
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `registry.ts`: 19+ Provider配置和类型
|
||||
- `factory.ts`: Provider配置工厂
|
||||
|
||||
**支持的Providers**:
|
||||
|
||||
- OpenAI, Anthropic, Google, XAI
|
||||
- Azure OpenAI, Amazon Bedrock, Google Vertex
|
||||
- Groq, Together.ai, Fireworks, DeepSeek
|
||||
- 等19+ AI SDK官方支持的providers
|
||||
|
||||
## 5. 使用方式
|
||||
|
||||
### 5.1 函数式调用 (推荐 - 简单场景)
|
||||
|
||||
```typescript
|
||||
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 直接函数调用
|
||||
const stream = await streamText(
|
||||
'anthropic',
|
||||
{ apiKey: 'your-api-key' },
|
||||
'claude-3',
|
||||
{ messages: [{ role: 'user', content: 'Hello!' }] },
|
||||
[loggingPlugin]
|
||||
)
|
||||
```
|
||||
|
||||
### 5.2 执行器实例 (推荐 - 复杂场景)
|
||||
|
||||
```typescript
|
||||
import { createExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 创建可复用的执行器
|
||||
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
|
||||
|
||||
// 多次使用
|
||||
const stream = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
const result = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'How are you?' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 5.3 静态工厂方法
|
||||
|
||||
```typescript
|
||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 静态创建
|
||||
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
|
||||
await executor.streamText('claude-3', { messages: [...] })
|
||||
```
|
||||
|
||||
### 5.4 直接模型创建 (高级用法)
|
||||
|
||||
```typescript
|
||||
import { createModel } from '@cherrystudio/ai-core/models'
|
||||
import { streamText } from 'ai'
|
||||
|
||||
// 直接创建模型使用
|
||||
const model = await createModel({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4',
|
||||
options: { apiKey: 'your-api-key' },
|
||||
middlewares: [middleware1, middleware2]
|
||||
})
|
||||
|
||||
// 直接使用 AI SDK
|
||||
const result = await streamText({ model, messages: [...] })
|
||||
```
|
||||
|
||||
## 6. 为 OpenAI Agents SDK 预留的设计
|
||||
|
||||
### 6.1 架构兼容性
|
||||
|
||||
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
|
||||
|
||||
```typescript
|
||||
// 当前的模型创建
|
||||
const model = await createModel({
|
||||
providerId: 'anthropic',
|
||||
modelId: 'claude-3',
|
||||
options: { apiKey: 'xxx' }
|
||||
})
|
||||
|
||||
// 将来可以直接用于 OpenAI Agents SDK
|
||||
import { Agent, run } from '@openai/agents'
|
||||
|
||||
const agent = new Agent({
|
||||
model, // ✅ 直接兼容 LanguageModel 接口
|
||||
name: 'Assistant',
|
||||
instructions: '...',
|
||||
tools: [tool1, tool2]
|
||||
})
|
||||
|
||||
const result = await run(agent, 'user input')
|
||||
```
|
||||
|
||||
### 6.2 预留的扩展点
|
||||
|
||||
1. **runtime/agents/** 目录预留
|
||||
2. **AgentExecutor** 类预留
|
||||
3. **Agent工具转换插件** 预留
|
||||
4. **多Agent编排** 预留
|
||||
|
||||
### 6.3 未来架构扩展
|
||||
|
||||
```
|
||||
packages/aiCore/src/core/
|
||||
├── runtime/
|
||||
│ ├── agents/ # 🚀 未来添加
|
||||
│ │ ├── AgentExecutor.ts
|
||||
│ │ ├── WorkflowManager.ts
|
||||
│ │ └── ConversationManager.ts
|
||||
│ ├── executor.ts
|
||||
│ └── index.ts
|
||||
```
|
||||
|
||||
## 7. 架构优势
|
||||
|
||||
### 7.1 简化设计
|
||||
|
||||
- **移除过度抽象**:删除了orchestration层和creation层的复杂包装
|
||||
- **函数式优先**:models层使用函数而非类
|
||||
- **直接明了**:runtime层直接提供用户API
|
||||
|
||||
### 7.2 职责清晰
|
||||
|
||||
- **Models**: 专注模型创建和配置
|
||||
- **Runtime**: 专注执行和用户API
|
||||
- **Plugins**: 专注扩展功能
|
||||
- **Providers**: 专注AI Provider管理
|
||||
|
||||
### 7.3 类型安全
|
||||
|
||||
- 完整的 TypeScript 支持
|
||||
- AI SDK 类型的直接复用
|
||||
- 避免类型重复定义
|
||||
|
||||
### 7.4 灵活使用
|
||||
|
||||
- 三种使用模式满足不同需求
|
||||
- 从简单函数调用到复杂执行器
|
||||
- 支持直接AI SDK使用
|
||||
|
||||
### 7.5 面向未来
|
||||
|
||||
- 为 OpenAI Agents SDK 集成做好准备
|
||||
- 清晰的扩展点和架构边界
|
||||
- 模块化设计便于功能添加
|
||||
|
||||
## 8. 技术决策记录
|
||||
|
||||
### 8.1 为什么选择简化的两层架构?
|
||||
|
||||
- **职责分离**:models专注创建,runtime专注执行
|
||||
- **模块化**:每层都有清晰的边界和职责
|
||||
- **扩展性**:为Agent功能预留了清晰的扩展空间
|
||||
|
||||
### 8.2 为什么选择函数式设计?
|
||||
|
||||
- **简洁性**:避免不必要的类设计
|
||||
- **性能**:减少对象创建开销
|
||||
- **易用性**:函数调用更直观
|
||||
|
||||
### 8.3 为什么分离插件和中间件?
|
||||
|
||||
- **职责明确**: 插件处理应用特定需求
|
||||
- **原生支持**: 中间件使用AI SDK原生功能
|
||||
- **灵活性**: 两套系统可以独立演进
|
||||
|
||||
## 9. 总结
|
||||
|
||||
AI Core架构实现了:
|
||||
|
||||
### 9.1 核心特点
|
||||
|
||||
- ✅ **简化架构**: 2层核心架构,职责清晰
|
||||
- ✅ **函数式设计**: models层完全函数化
|
||||
- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用
|
||||
- ✅ **插件扩展**: 强大的插件系统
|
||||
- ✅ **多种使用方式**: 满足不同复杂度需求
|
||||
- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备
|
||||
|
||||
### 9.2 核心价值
|
||||
|
||||
- **统一接口**: 一套API支持19+ AI providers
|
||||
- **灵活使用**: 函数式、实例式、静态工厂式
|
||||
- **强类型**: 完整的TypeScript支持
|
||||
- **可扩展**: 插件和中间件双重扩展能力
|
||||
- **高性能**: 最小化包装,直接使用AI SDK
|
||||
- **面向未来**: Agent SDK集成架构就绪
|
||||
|
||||
### 9.3 未来发展
|
||||
|
||||
这个架构提供了:
|
||||
|
||||
- **优秀的开发体验**: 简洁的API和清晰的使用模式
|
||||
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
|
||||
- **良好的维护性**: 职责分离明确,代码易于维护
|
||||
- **广泛的适用性**: 既适合简单调用也适合复杂应用
|
||||
@@ -10,6 +10,10 @@
|
||||
"build": "tsdown",
|
||||
"dev": "tsc -w",
|
||||
"typecheck": "tsc --noEmit",
|
||||
"lint": "oxlint --fix && eslint . --fix --cache",
|
||||
"lint:check": "oxlint && eslint . --cache",
|
||||
"format": "biome format --write && biome lint --write",
|
||||
"format:check": "biome format && biome lint",
|
||||
"clean": "rm -rf dist",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
@@ -27,23 +31,27 @@
|
||||
},
|
||||
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
||||
"peerDependencies": {
|
||||
"@ai-sdk/google": "^3.0.33",
|
||||
"@ai-sdk/openai": "^3.0.36",
|
||||
"ai": "^6.0.103"
|
||||
"@ai-sdk/google": "^3.0.43",
|
||||
"@ai-sdk/openai": "^3.0.41",
|
||||
"ai": "^6.0.116"
|
||||
},
|
||||
"dependencies": {
|
||||
"@cherrystudio/ai-sdk-provider": "workspace:*",
|
||||
"@ai-sdk/anthropic": "^3.0.48",
|
||||
"@ai-sdk/azure": "^3.0.37",
|
||||
"@ai-sdk/deepseek": "^2.0.20",
|
||||
"@ai-sdk/openai-compatible": "^2.0.30",
|
||||
"@ai-sdk/anthropic": "^3.0.58",
|
||||
"@ai-sdk/azure": "^3.0.42",
|
||||
"@ai-sdk/deepseek": "^2.0.24",
|
||||
"@ai-sdk/openai-compatible": "^2.0.35",
|
||||
"@openrouter/ai-sdk-provider": "^2.3.3",
|
||||
"@ai-sdk/provider": "^3.0.8",
|
||||
"@ai-sdk/provider-utils": "^4.0.15",
|
||||
"@ai-sdk/xai": "^3.0.59",
|
||||
"@openrouter/ai-sdk-provider": "^2.2.3",
|
||||
"@ai-sdk/provider-utils": "^4.0.19",
|
||||
"@ai-sdk/xai": "^3.0.67",
|
||||
"lru-cache": "^11.2.4",
|
||||
"zod": "^4.1.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@cherrystudio/ai-sdk-provider": "workspace:*",
|
||||
"biome": "^0.3.3",
|
||||
"oxlint": "^1.36.0",
|
||||
"tsdown": "^0.20.3",
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^3.2.4"
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
/**
|
||||
* Test Infrastructure Exports
|
||||
* Central export point for all test utilities, fixtures, and helpers
|
||||
*/
|
||||
|
||||
// Fixtures
|
||||
export * from './fixtures/mock-providers'
|
||||
export * from './fixtures/mock-responses'
|
||||
|
||||
// Helpers
|
||||
export * from './helpers/model-test-utils'
|
||||
export * from './helpers/provider-test-utils'
|
||||
export * from './helpers/test-utils'
|
||||
@@ -1,3 +0,0 @@
|
||||
# @cherryStudio-aiCore
|
||||
|
||||
Core
|
||||
@@ -2,13 +2,7 @@
|
||||
* Core 模块导出
|
||||
* 内部核心功能,供其他模块使用,不直接面向最终调用者
|
||||
*/
|
||||
|
||||
// 中间件系统
|
||||
export type { NamedMiddleware } from './middleware'
|
||||
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
||||
|
||||
// 创建管理
|
||||
export { globalModelResolver, ModelResolver } from './models'
|
||||
// 模型类型
|
||||
export type { ModelConfig as ModelConfigType } from './models/types'
|
||||
|
||||
// 执行管理
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
/**
|
||||
* Middleware 模块导出
|
||||
* 提供通用的中间件管理能力
|
||||
*/
|
||||
|
||||
export { createMiddlewares } from './manager'
|
||||
export type { NamedMiddleware } from './types'
|
||||
export { wrapModelWithMiddlewares } from './wrapper'
|
||||
@@ -1,16 +0,0 @@
|
||||
/**
|
||||
* 中间件管理器
|
||||
* 专注于 AI SDK 中间件的管理,与插件系统分离
|
||||
*/
|
||||
import type { LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 创建中间件列表
|
||||
* 合并用户提供的中间件
|
||||
*/
|
||||
export function createMiddlewares(userMiddlewares: LanguageModelV3Middleware[] = []): LanguageModelV3Middleware[] {
|
||||
// 未来可以在这里添加默认的中间件
|
||||
const defaultMiddlewares: LanguageModelV3Middleware[] = []
|
||||
|
||||
return [...defaultMiddlewares, ...userMiddlewares]
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
/**
|
||||
* 中间件系统类型定义
|
||||
*/
|
||||
import type { LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 具名中间件接口
|
||||
*/
|
||||
export interface NamedMiddleware {
|
||||
name: string
|
||||
middleware: LanguageModelV3Middleware
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
/**
|
||||
* 模型包装工具函数
|
||||
* 用于将中间件应用到LanguageModel上
|
||||
*/
|
||||
import type { LanguageModelV3, LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
import { wrapLanguageModel } from 'ai'
|
||||
|
||||
/**
|
||||
* 使用中间件包装模型
|
||||
*/
|
||||
export function wrapModelWithMiddlewares(
|
||||
model: LanguageModelV3,
|
||||
middlewares: LanguageModelV3Middleware[]
|
||||
): LanguageModelV3 {
|
||||
if (middlewares.length === 0) {
|
||||
return model
|
||||
}
|
||||
|
||||
return wrapLanguageModel({
|
||||
model,
|
||||
middleware: middlewares
|
||||
})
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
/**
|
||||
* 模型解析器 - models模块的核心
|
||||
* 负责将modelId解析为AI SDK的LanguageModel实例
|
||||
* 支持传统格式和命名空间格式
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||
|
||||
export class ModelResolver {
|
||||
/**
|
||||
* 核心方法:解析任意格式的modelId为语言模型
|
||||
*
|
||||
* @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式
|
||||
* @param fallbackProviderId 当modelId为传统格式时使用的providerId
|
||||
* @param providerOptions provider配置选项(用于OpenAI模式选择等)
|
||||
*/
|
||||
async resolveLanguageModel(
|
||||
modelId: string,
|
||||
fallbackProviderId: string,
|
||||
providerOptions?: any
|
||||
): Promise<LanguageModelV3> {
|
||||
let finalProviderId = fallbackProviderId
|
||||
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
|
||||
finalProviderId = `${fallbackProviderId}-chat`
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedModel(modelId)
|
||||
} else {
|
||||
// 传统格式:使用处理后的 providerId + modelId
|
||||
return this.resolveTraditionalModel(finalProviderId, modelId)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析文本嵌入模型
|
||||
*/
|
||||
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV3> {
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedEmbeddingModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析图像模型
|
||||
*/
|
||||
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV3> {
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedImageModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的语言模型
|
||||
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
|
||||
*/
|
||||
private resolveNamespacedModel(modelId: string): LanguageModelV3 {
|
||||
return globalRegistryManagement.languageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的语言模型
|
||||
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV3 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的嵌入模型
|
||||
*/
|
||||
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV3 {
|
||||
return globalRegistryManagement.embeddingModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的嵌入模型
|
||||
*/
|
||||
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV3 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.embeddingModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的图像模型
|
||||
*/
|
||||
private resolveNamespacedImageModel(modelId: string): ImageModelV3 {
|
||||
return globalRegistryManagement.imageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的图像模型
|
||||
*/
|
||||
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV3 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.imageModel(fullModelId as any)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局模型解析器实例
|
||||
*/
|
||||
export const globalModelResolver = new ModelResolver()
|
||||
@@ -1,47 +1,29 @@
|
||||
/**
|
||||
* ModelResolver Comprehensive Tests
|
||||
* Tests model resolution logic for language, embedding, and image models
|
||||
* Covers both traditional and namespaced format resolution
|
||||
* ProviderRegistry Model Resolution Tests
|
||||
* Tests model resolution via AI SDK's createProviderRegistry
|
||||
* The registry routes 'providerId:modelId' to the correct provider
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import {
|
||||
createMockEmbeddingModel,
|
||||
createMockImageModel,
|
||||
createMockLanguageModel,
|
||||
createMockProviderV3
|
||||
} from '@test-utils'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ModelResolver } from '../ModelResolver'
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn(),
|
||||
embeddingModel: vi.fn(),
|
||||
imageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
vi.mock('../../middleware/wrapper', () => ({
|
||||
wrapModelWithMiddlewares: vi.fn((model: LanguageModelV3) => {
|
||||
// Return a wrapped model with a marker
|
||||
return {
|
||||
...model,
|
||||
_wrapped: true
|
||||
} as LanguageModelV3
|
||||
})
|
||||
}))
|
||||
|
||||
describe('ModelResolver', () => {
|
||||
let resolver: ModelResolver
|
||||
describe('ProviderRegistry Model Resolution', () => {
|
||||
let registry: ReturnType<typeof createProviderRegistry>
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockEmbeddingModel: EmbeddingModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
let mockProvider: any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
resolver = new ModelResolver()
|
||||
|
||||
// Create properly typed mock models using global utilities
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-model'
|
||||
@@ -57,361 +39,168 @@ describe('ModelResolver', () => {
|
||||
modelId: 'test-image'
|
||||
})
|
||||
|
||||
// Setup default mock implementations
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||
vi.mocked(globalRegistryManagement.embeddingModel).mockReturnValue(mockEmbeddingModel)
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||
})
|
||||
|
||||
describe('resolveLanguageModel', () => {
|
||||
describe('Traditional Format Resolution', () => {
|
||||
it('should resolve traditional format modelId without separator', async () => {
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should resolve with different provider and modelId combinations', async () => {
|
||||
const testCases: Array<{ modelId: string; providerId: string; expected: string }> = [
|
||||
{ modelId: 'claude-3-5-sonnet', providerId: 'anthropic', expected: 'anthropic|claude-3-5-sonnet' },
|
||||
{ modelId: 'gemini-2.0-flash', providerId: 'google', expected: 'google|gemini-2.0-flash' },
|
||||
{ modelId: 'grok-2-latest', providerId: 'xai', expected: 'xai|grok-2-latest' },
|
||||
{ modelId: 'deepseek-chat', providerId: 'deepseek', expected: 'deepseek|deepseek-chat' }
|
||||
]
|
||||
|
||||
for (const testCase of testCases) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveLanguageModel(testCase.modelId, testCase.providerId)
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(testCase.expected)
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle modelIds with special characters', async () => {
|
||||
const modelIds = ['model-v1.0', 'model_v2', 'model.2024', 'model:free']
|
||||
|
||||
for (const modelId of modelIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveLanguageModel(modelId, 'provider')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`provider${DEFAULT_SEPARATOR}${modelId}`)
|
||||
}
|
||||
})
|
||||
mockProvider = createMockProviderV3({
|
||||
provider: 'test-provider',
|
||||
languageModel: vi.fn(() => mockLanguageModel),
|
||||
embeddingModel: vi.fn(() => mockEmbeddingModel),
|
||||
imageModel: vi.fn(() => mockImageModel)
|
||||
})
|
||||
|
||||
describe('Namespaced Format Resolution', () => {
|
||||
it('should resolve namespaced format with hub', async () => {
|
||||
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}anthropic${DEFAULT_SEPARATOR}claude-3-5-sonnet`
|
||||
|
||||
const result = await resolver.resolveLanguageModel(namespacedId, 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should resolve simple namespaced format', async () => {
|
||||
const namespacedId = `provider${DEFAULT_SEPARATOR}model-id`
|
||||
|
||||
await resolver.resolveLanguageModel(namespacedId, 'fallback-provider')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
|
||||
})
|
||||
|
||||
it('should handle complex namespaced IDs', async () => {
|
||||
const complexIds = [
|
||||
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`,
|
||||
`hub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model-v1.0`,
|
||||
`custom${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4-turbo`
|
||||
]
|
||||
|
||||
for (const id of complexIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveLanguageModel(id, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI Mode Selection', () => {
|
||||
it('should append "-chat" suffix for OpenAI provider with chat mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'chat' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-chat|gpt-4')
|
||||
})
|
||||
|
||||
it('should append "-chat" suffix for Azure provider with chat mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'azure', { mode: 'chat' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('azure-chat|gpt-4')
|
||||
})
|
||||
|
||||
it('should not append suffix for OpenAI with responses mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', { mode: 'responses' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
|
||||
it('should not append suffix for OpenAI without mode', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
|
||||
it('should not append suffix for other providers with chat mode', async () => {
|
||||
await resolver.resolveLanguageModel('claude-3', 'anthropic', { mode: 'chat' })
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3')
|
||||
})
|
||||
|
||||
it('should handle namespaced IDs with OpenAI chat mode', async () => {
|
||||
const namespacedId = `hub${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}gpt-4`
|
||||
|
||||
await resolver.resolveLanguageModel(namespacedId, 'openai', { mode: 'chat' })
|
||||
|
||||
// Should use the namespaced ID directly, not apply mode logic
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(namespacedId)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Options Handling', () => {
|
||||
it('should pass provider options correctly', async () => {
|
||||
const options = { baseURL: 'https://api.example.com', apiKey: 'test-key' }
|
||||
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', options)
|
||||
|
||||
// Provider options are used for mode selection logic
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle empty provider options', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', {})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
|
||||
it('should handle undefined provider options', async () => {
|
||||
await resolver.resolveLanguageModel('gpt-4', 'openai', undefined)
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|gpt-4')
|
||||
})
|
||||
registry = createProviderRegistry({
|
||||
'test-provider': mockProvider
|
||||
})
|
||||
})
|
||||
|
||||
describe('resolveTextEmbeddingModel', () => {
|
||||
describe('Traditional Format', () => {
|
||||
it('should resolve traditional embedding model ID', async () => {
|
||||
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
|
||||
describe('languageModel', () => {
|
||||
it('should resolve modelId via provider registry', () => {
|
||||
const result = registry.languageModel('test-provider:gpt-4')
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith('openai|text-embedding-ada-002')
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should resolve different embedding models', async () => {
|
||||
const testCases = [
|
||||
{ modelId: 'text-embedding-3-small', providerId: 'openai' },
|
||||
{ modelId: 'text-embedding-3-large', providerId: 'openai' },
|
||||
{ modelId: 'embed-english-v3.0', providerId: 'cohere' },
|
||||
{ modelId: 'voyage-2', providerId: 'voyage' }
|
||||
]
|
||||
|
||||
for (const { modelId, providerId } of testCases) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveTextEmbeddingModel(modelId, providerId)
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
|
||||
}
|
||||
})
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
describe('Namespaced Format', () => {
|
||||
it('should resolve namespaced embedding model ID', async () => {
|
||||
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}text-embedding-3-small`
|
||||
it('should pass various modelIds to the correct provider', () => {
|
||||
const modelIds = [
|
||||
'claude-3-5-sonnet',
|
||||
'gemini-2.0-flash',
|
||||
'grok-2-latest',
|
||||
'deepseek-chat',
|
||||
'model-v1.0',
|
||||
'model_v2',
|
||||
'model.2024'
|
||||
]
|
||||
|
||||
const result = await resolver.resolveTextEmbeddingModel(namespacedId, 'openai')
|
||||
for (const modelId of modelIds) {
|
||||
vi.clearAllMocks()
|
||||
registry.languageModel(`test-provider:${modelId}`)
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(namespacedId)
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should handle complex namespaced embedding IDs', async () => {
|
||||
const complexIds = [
|
||||
`hub${DEFAULT_SEPARATOR}cohere${DEFAULT_SEPARATOR}embed-multilingual`,
|
||||
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}embedding-model`
|
||||
]
|
||||
|
||||
for (const id of complexIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveTextEmbeddingModel(id, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('resolveImageModel', () => {
|
||||
describe('Traditional Format', () => {
|
||||
it('should resolve traditional image model ID', async () => {
|
||||
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should resolve different image models', async () => {
|
||||
const testCases = [
|
||||
{ modelId: 'dall-e-2', providerId: 'openai' },
|
||||
{ modelId: 'stable-diffusion-xl', providerId: 'stability' },
|
||||
{ modelId: 'imagen-2', providerId: 'google' },
|
||||
{ modelId: 'midjourney-v6', providerId: 'midjourney' }
|
||||
]
|
||||
|
||||
for (const { modelId, providerId } of testCases) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveImageModel(modelId, providerId)
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${modelId}`)
|
||||
}
|
||||
})
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith(modelId)
|
||||
}
|
||||
})
|
||||
|
||||
describe('Namespaced Format', () => {
|
||||
it('should resolve namespaced image model ID', async () => {
|
||||
const namespacedId = `aihubmix${DEFAULT_SEPARATOR}openai${DEFAULT_SEPARATOR}dall-e-3`
|
||||
|
||||
const result = await resolver.resolveImageModel(namespacedId, 'openai')
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(namespacedId)
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should handle complex namespaced image IDs', async () => {
|
||||
const complexIds = [
|
||||
`hub${DEFAULT_SEPARATOR}stability${DEFAULT_SEPARATOR}sdxl-turbo`,
|
||||
`custom${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}image-gen-v2`
|
||||
]
|
||||
|
||||
for (const id of complexIds) {
|
||||
vi.clearAllMocks()
|
||||
await resolver.resolveImageModel(id, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(id)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases and Error Scenarios', () => {
|
||||
it('should handle empty model IDs', async () => {
|
||||
await resolver.resolveLanguageModel('', 'openai')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai|')
|
||||
})
|
||||
|
||||
it('should handle model IDs with multiple separators', async () => {
|
||||
const multiSeparatorId = `hub${DEFAULT_SEPARATOR}sub${DEFAULT_SEPARATOR}provider${DEFAULT_SEPARATOR}model`
|
||||
|
||||
await resolver.resolveLanguageModel(multiSeparatorId, 'fallback')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(multiSeparatorId)
|
||||
})
|
||||
|
||||
it('should handle model IDs with only separator', async () => {
|
||||
const onlySeparator = DEFAULT_SEPARATOR
|
||||
|
||||
await resolver.resolveLanguageModel(onlySeparator, 'provider')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(onlySeparator)
|
||||
})
|
||||
|
||||
it('should throw if globalRegistryManagement throws', async () => {
|
||||
const error = new Error('Model not found in registry')
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
|
||||
it('should throw if provider throws', () => {
|
||||
const error = new Error('Model not found')
|
||||
vi.mocked(mockProvider.languageModel).mockImplementation(() => {
|
||||
throw error
|
||||
})
|
||||
|
||||
await expect(resolver.resolveLanguageModel('invalid-model', 'openai')).rejects.toThrow(
|
||||
'Model not found in registry'
|
||||
)
|
||||
expect(() => registry.languageModel('test-provider:invalid-model')).toThrow('Model not found')
|
||||
})
|
||||
|
||||
it('should handle concurrent resolution requests', async () => {
|
||||
const promises = [
|
||||
resolver.resolveLanguageModel('gpt-4', 'openai'),
|
||||
resolver.resolveLanguageModel('claude-3', 'anthropic'),
|
||||
resolver.resolveLanguageModel('gemini-2.0', 'google')
|
||||
it('should handle concurrent resolution requests', () => {
|
||||
const results = [
|
||||
registry.languageModel('test-provider:gpt-4'),
|
||||
registry.languageModel('test-provider:claude-3'),
|
||||
registry.languageModel('test-provider:gemini-2.0')
|
||||
]
|
||||
|
||||
const results = await Promise.all(promises)
|
||||
|
||||
expect(results).toHaveLength(3)
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledTimes(3)
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should throw for unknown provider', () => {
|
||||
expect(() => registry.languageModel('unknown:gpt-4' as `${string}:${string}`)).toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('embeddingModel', () => {
|
||||
it('should resolve embedding model ID', () => {
|
||||
const result = registry.embeddingModel('test-provider:text-embedding-ada-002')
|
||||
|
||||
expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-ada-002')
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should resolve different embedding models', () => {
|
||||
const modelIds = ['text-embedding-3-small', 'text-embedding-3-large', 'embed-english-v3.0', 'voyage-2']
|
||||
|
||||
for (const modelId of modelIds) {
|
||||
vi.clearAllMocks()
|
||||
registry.embeddingModel(`test-provider:${modelId}`)
|
||||
|
||||
expect(mockProvider.embeddingModel).toHaveBeenCalledWith(modelId)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('imageModel', () => {
|
||||
it('should resolve image model ID', () => {
|
||||
const result = registry.imageModel('test-provider:dall-e-3')
|
||||
|
||||
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should resolve different image models', () => {
|
||||
const modelIds = ['dall-e-2', 'stable-diffusion-xl', 'imagen-2', 'grok-2-image']
|
||||
|
||||
for (const modelId of modelIds) {
|
||||
vi.clearAllMocks()
|
||||
registry.imageModel(`test-provider:${modelId}`)
|
||||
|
||||
expect(mockProvider.imageModel).toHaveBeenCalledWith(modelId)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should return properly typed LanguageModelV3', async () => {
|
||||
const result = await resolver.resolveLanguageModel('gpt-4', 'openai')
|
||||
it('should return properly typed LanguageModelV3', () => {
|
||||
const result = registry.languageModel('test-provider:gpt-4')
|
||||
|
||||
// Type assertions
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
expect(result).toHaveProperty('doStream')
|
||||
})
|
||||
|
||||
it('should return properly typed EmbeddingModelV3', async () => {
|
||||
const result = await resolver.resolveTextEmbeddingModel('text-embedding-ada-002', 'openai')
|
||||
it('should return properly typed EmbeddingModelV3', () => {
|
||||
const result = registry.embeddingModel('test-provider:text-embedding-ada-002')
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doEmbed')
|
||||
})
|
||||
|
||||
it('should return properly typed ImageModelV3', async () => {
|
||||
const result = await resolver.resolveImageModel('dall-e-3', 'openai')
|
||||
it('should return properly typed ImageModelV3', () => {
|
||||
const result = registry.imageModel('test-provider:dall-e-3')
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Global ModelResolver Instance', () => {
|
||||
it('should have a global instance available', async () => {
|
||||
const { globalModelResolver } = await import('../ModelResolver')
|
||||
describe('Multi-provider registry', () => {
|
||||
it('should route to correct provider in multi-provider registry', () => {
|
||||
const mockProvider2 = createMockProviderV3({
|
||||
provider: 'second-provider',
|
||||
languageModel: vi.fn(() =>
|
||||
createMockLanguageModel({
|
||||
provider: 'second-provider',
|
||||
modelId: 'other-model'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
expect(globalModelResolver).toBeInstanceOf(ModelResolver)
|
||||
const multiRegistry = createProviderRegistry({
|
||||
first: mockProvider,
|
||||
second: mockProvider2
|
||||
})
|
||||
|
||||
multiRegistry.languageModel('first:gpt-4')
|
||||
multiRegistry.languageModel('second:other-model')
|
||||
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockProvider2.languageModel).toHaveBeenCalledWith('other-model')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Integration with Different Provider Types', () => {
|
||||
it('should work with OpenAI compatible providers', async () => {
|
||||
await resolver.resolveLanguageModel('custom-model', 'openai-compatible')
|
||||
describe('All model types for same provider', () => {
|
||||
it('should handle all model types correctly', () => {
|
||||
registry.languageModel('test-provider:gpt-4')
|
||||
registry.embeddingModel('test-provider:text-embedding-3-small')
|
||||
registry.imageModel('test-provider:dall-e-3')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('openai-compatible|custom-model')
|
||||
})
|
||||
|
||||
it('should work with hub providers', async () => {
|
||||
const hubId = `aihubmix${DEFAULT_SEPARATOR}custom${DEFAULT_SEPARATOR}model-v1`
|
||||
|
||||
await resolver.resolveLanguageModel(hubId, 'aihubmix')
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(hubId)
|
||||
})
|
||||
|
||||
it('should handle all model types for same provider', async () => {
|
||||
const providerId = 'openai'
|
||||
const languageModel = 'gpt-4'
|
||||
const embeddingModel = 'text-embedding-3-small'
|
||||
const imageModel = 'dall-e-3'
|
||||
|
||||
await resolver.resolveLanguageModel(languageModel, providerId)
|
||||
await resolver.resolveTextEmbeddingModel(embeddingModel, providerId)
|
||||
await resolver.resolveImageModel(imageModel, providerId)
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith(`${providerId}|${languageModel}`)
|
||||
expect(globalRegistryManagement.embeddingModel).toHaveBeenCalledWith(`${providerId}|${embeddingModel}`)
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith(`${providerId}|${imageModel}`)
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
|
||||
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
* Models 模块统一导出 - 简化版
|
||||
*/
|
||||
|
||||
// 核心模型解析器
|
||||
export { globalModelResolver, ModelResolver } from './ModelResolver'
|
||||
|
||||
// 保留的类型定义(可能被其他地方使用)
|
||||
export type { ModelConfig as ModelConfigType } from './types'
|
||||
|
||||
|
||||
@@ -3,12 +3,21 @@
|
||||
*/
|
||||
import type { JSONObject, LanguageModelV3Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
|
||||
import type { CoreProviderSettingsMap, ProviderId } from '../providers/types'
|
||||
|
||||
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||
/**
|
||||
* 模型配置
|
||||
*
|
||||
* @typeParam T - Provider ID 类型
|
||||
* @typeParam TSettingsMap - Provider Settings Map(默认 CoreProviderSettingsMap)
|
||||
*/
|
||||
export interface ModelConfig<
|
||||
T extends ProviderId = ProviderId,
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap
|
||||
> {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||
providerSettings: TSettingsMap[T & keyof TSettingsMap]
|
||||
middlewares?: LanguageModelV3Middleware[]
|
||||
extraModelConfig?: JSONObject
|
||||
}
|
||||
|
||||
@@ -1,109 +1,87 @@
|
||||
import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { createOpenAIOptions, createOpenRouterOptions, mergeProviderOptions } from '../factory'
|
||||
import { mergeProviderOptions } from '../factory'
|
||||
import type { TypedProviderOptions } from '../types'
|
||||
|
||||
// Helper to build typed options for tests without verbose casts at each call site
|
||||
const opts = (o: Record<string, Record<string, unknown>>): Partial<TypedProviderOptions> =>
|
||||
o as Partial<TypedProviderOptions>
|
||||
|
||||
describe('mergeProviderOptions', () => {
|
||||
it('deep merges provider options for the same provider', () => {
|
||||
const reasoningOptions = createOpenRouterOptions({
|
||||
reasoning: {
|
||||
enabled: true,
|
||||
effort: 'medium'
|
||||
}
|
||||
})
|
||||
const webSearchOptions = createOpenRouterOptions({
|
||||
plugins: [{ id: 'web', max_results: 5 }]
|
||||
})
|
||||
const reasoningOptions: Partial<TypedProviderOptions> = {
|
||||
openrouter: { reasoning: { enabled: true, effort: 'medium' } } as OpenRouterProviderOptions
|
||||
}
|
||||
const webSearchOptions = opts({ openrouter: { plugins: [{ id: 'web', max_results: 5 }] } })
|
||||
|
||||
const merged = mergeProviderOptions(reasoningOptions, webSearchOptions)
|
||||
|
||||
expect(merged.openrouter).toEqual({
|
||||
reasoning: {
|
||||
enabled: true,
|
||||
effort: 'medium'
|
||||
},
|
||||
reasoning: { enabled: true, effort: 'medium' },
|
||||
plugins: [{ id: 'web', max_results: 5 }]
|
||||
})
|
||||
})
|
||||
|
||||
it('preserves options from other providers while merging', () => {
|
||||
const openRouter = createOpenRouterOptions({
|
||||
reasoning: { enabled: true }
|
||||
})
|
||||
const openAI = createOpenAIOptions({
|
||||
reasoningEffort: 'low'
|
||||
})
|
||||
const openRouter: Partial<TypedProviderOptions> = {
|
||||
openrouter: { reasoning: { enabled: true, effort: 'medium' } } as OpenRouterProviderOptions
|
||||
}
|
||||
const openAI: Partial<TypedProviderOptions> = { openai: { reasoningEffort: 'low' } }
|
||||
const merged = mergeProviderOptions(openRouter, openAI)
|
||||
|
||||
expect(merged.openrouter).toEqual({ reasoning: { enabled: true } })
|
||||
expect(merged.openrouter).toEqual({ reasoning: { enabled: true, effort: 'medium' } })
|
||||
expect(merged.openai).toEqual({ reasoningEffort: 'low' })
|
||||
})
|
||||
|
||||
it('overwrites primitive values with later values', () => {
|
||||
const first = createOpenAIOptions({
|
||||
reasoningEffort: 'low',
|
||||
user: 'user-123'
|
||||
})
|
||||
const second = createOpenAIOptions({
|
||||
reasoningEffort: 'high',
|
||||
maxToolCalls: 5
|
||||
})
|
||||
const first: Partial<TypedProviderOptions> = { openai: { reasoningEffort: 'low', user: 'user-123' } }
|
||||
const second: Partial<TypedProviderOptions> = { openai: { reasoningEffort: 'high', maxToolCalls: 5 } }
|
||||
|
||||
const merged = mergeProviderOptions(first, second)
|
||||
|
||||
expect(merged.openai).toEqual({
|
||||
reasoningEffort: 'high', // overwritten by second
|
||||
user: 'user-123', // preserved from first
|
||||
maxToolCalls: 5 // added from second
|
||||
reasoningEffort: 'high',
|
||||
user: 'user-123',
|
||||
maxToolCalls: 5
|
||||
})
|
||||
})
|
||||
|
||||
it('overwrites arrays with later values instead of merging', () => {
|
||||
const first = createOpenRouterOptions({
|
||||
models: ['gpt-4', 'gpt-3.5-turbo']
|
||||
})
|
||||
const second = createOpenRouterOptions({
|
||||
models: ['claude-3-opus', 'claude-3-sonnet']
|
||||
})
|
||||
const first = opts({ openrouter: { models: ['gpt-4', 'gpt-3.5-turbo'] } })
|
||||
const second = opts({ openrouter: { models: ['claude-3-opus', 'claude-3-sonnet'] } })
|
||||
|
||||
const merged = mergeProviderOptions(first, second)
|
||||
|
||||
// Array is completely replaced, not merged
|
||||
expect(merged.openrouter?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
|
||||
expect((merged.openrouter as Record<string, unknown>)?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
|
||||
})
|
||||
|
||||
it('deeply merges nested objects while overwriting primitives', () => {
|
||||
const first = createOpenRouterOptions({
|
||||
reasoning: {
|
||||
enabled: true,
|
||||
effort: 'low'
|
||||
},
|
||||
user: 'user-123'
|
||||
const first = opts({
|
||||
openrouter: {
|
||||
reasoning: { enabled: true, effort: 'low' },
|
||||
user: 'user-123'
|
||||
}
|
||||
})
|
||||
const second = createOpenRouterOptions({
|
||||
reasoning: {
|
||||
effort: 'high',
|
||||
max_tokens: 500
|
||||
},
|
||||
user: 'user-456'
|
||||
const second = opts({
|
||||
openrouter: {
|
||||
reasoning: { effort: 'high', max_tokens: 500 },
|
||||
user: 'user-456'
|
||||
}
|
||||
})
|
||||
|
||||
const merged = mergeProviderOptions(first, second)
|
||||
|
||||
expect(merged.openrouter).toEqual({
|
||||
reasoning: {
|
||||
enabled: true, // preserved from first
|
||||
effort: 'high', // overwritten by second
|
||||
max_tokens: 500 // added from second
|
||||
},
|
||||
user: 'user-456' // overwritten by second
|
||||
reasoning: { enabled: true, effort: 'high', max_tokens: 500 },
|
||||
user: 'user-456'
|
||||
})
|
||||
})
|
||||
|
||||
it('replaces arrays instead of merging them', () => {
|
||||
const first = createOpenRouterOptions({ plugins: [{ id: 'old' }] })
|
||||
const second = createOpenRouterOptions({ plugins: [{ id: 'new' }] })
|
||||
const first = opts({ openrouter: { plugins: [{ id: 'old' }] } })
|
||||
const second = opts({ openrouter: { plugins: [{ id: 'new' }] } })
|
||||
const merged = mergeProviderOptions(first, second)
|
||||
// @ts-expect-error type-check for openrouter options is skipped. see function signature of createOpenRouterOptions
|
||||
expect(merged.openrouter?.plugins).toEqual([{ id: 'new' }])
|
||||
expect((merged.openrouter as Record<string, unknown>)?.plugins).toEqual([{ id: 'new' }])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,30 +1,4 @@
|
||||
import type { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
|
||||
|
||||
/**
|
||||
* 创建特定供应商的选项
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商特定的选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
|
||||
provider: T,
|
||||
options: ExtractProviderOptions<T>
|
||||
): Record<T, ExtractProviderOptions<T>> {
|
||||
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任意供应商的选项(包括未知供应商)
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createGenericProviderOptions<T extends string>(
|
||||
provider: T,
|
||||
options: Record<string, any>
|
||||
): Record<T, Record<string, any>> {
|
||||
return { [provider]: options } as Record<T, Record<string, any>>
|
||||
}
|
||||
import type { TypedProviderOptions } from './types'
|
||||
|
||||
type PlainObject = Record<string, any>
|
||||
|
||||
@@ -86,31 +60,3 @@ export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions
|
||||
return acc
|
||||
}, {} as TypedProviderOptions)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
|
||||
return createProviderOptions('openai', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Anthropic供应商选项的便捷函数
|
||||
*/
|
||||
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
|
||||
return createProviderOptions('anthropic', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Google供应商选项的便捷函数
|
||||
*/
|
||||
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
|
||||
return createProviderOptions('google', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenRouter供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'> | Record<string, any>) {
|
||||
return createProviderOptions('openrouter', options)
|
||||
}
|
||||
|
||||
@@ -5,12 +5,11 @@ import { type SharedV3ProviderMetadata } from '@ai-sdk/provider'
|
||||
import { type XaiProviderOptions } from '@ai-sdk/xai'
|
||||
import { type OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
|
||||
|
||||
export type ProviderOptions<T extends keyof SharedV3ProviderMetadata> = SharedV3ProviderMetadata[T]
|
||||
|
||||
/**
|
||||
* 供应商选项类型,如果map中没有,说明没有约束
|
||||
* Known provider options map for type-safe providerOptions construction.
|
||||
* Providers not listed here accept arbitrary Record<string, any>.
|
||||
*/
|
||||
export type ProviderOptionsMap = {
|
||||
type ProviderOptionsMap = {
|
||||
openai: OpenAIResponsesProviderOptions
|
||||
anthropic: AnthropicProviderOptions
|
||||
google: GoogleGenerativeAIProviderOptions
|
||||
@@ -18,12 +17,9 @@ export type ProviderOptionsMap = {
|
||||
xai: XaiProviderOptions
|
||||
}
|
||||
|
||||
// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型
|
||||
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
|
||||
|
||||
/**
|
||||
* 类型安全的ProviderOptions
|
||||
* 对于已知供应商使用严格类型,对于未知供应商允许任意Record<string, JSONValue>
|
||||
* Type-safe ProviderOptions.
|
||||
* Known providers use strict types; unknown providers allow Record<string, any>.
|
||||
*/
|
||||
export type TypedProviderOptions = {
|
||||
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
# AI Core 插件系统
|
||||
|
||||
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**。
|
||||
|
||||
## 🎯 设计理念
|
||||
|
||||
- **语义清晰**:不同钩子有不同的执行语义
|
||||
- **类型安全**:TypeScript 完整支持
|
||||
- **性能优化**:First 短路、Parallel 并发、Sequential 链式
|
||||
- **易于扩展**:`enforce` 排序 + 功能分类
|
||||
|
||||
## 📋 钩子类型
|
||||
|
||||
### 🥇 First 钩子 - 首个有效结果
|
||||
|
||||
```typescript
|
||||
// 只执行第一个返回值的插件,用于解析和查找
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
|
||||
```
|
||||
|
||||
### 🔄 Sequential 钩子 - 链式数据转换
|
||||
|
||||
```typescript
|
||||
// 按顺序链式执行,每个插件可以修改数据
|
||||
transformParams?: (params: any, context: AiRequestContext) => any
|
||||
transformResult?: (result: any, context: AiRequestContext) => any
|
||||
```
|
||||
|
||||
### ⚡ Parallel 钩子 - 并行副作用
|
||||
|
||||
```typescript
|
||||
// 并发执行,用于日志、监控等副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void
|
||||
onError?: (error: Error, context: AiRequestContext) => void
|
||||
```
|
||||
|
||||
### 🌊 Stream 钩子 - 流处理
|
||||
|
||||
```typescript
|
||||
// 直接使用 AI SDK 的 TransformStream
|
||||
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
|
||||
```
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const pluginManager = new PluginManager()
|
||||
|
||||
// 添加插件
|
||||
pluginManager.use({
|
||||
name: 'my-plugin',
|
||||
async transformParams(params, context) {
|
||||
return { ...params, temperature: 0.7 }
|
||||
}
|
||||
})
|
||||
|
||||
// 使用插件
|
||||
const context = createContext('openai', 'gpt-4', { messages: [] })
|
||||
const transformedParams = await pluginManager.executeSequential(
|
||||
'transformParams',
|
||||
{ messages: [{ role: 'user', content: 'Hello' }] },
|
||||
context
|
||||
)
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
```typescript
|
||||
import {
|
||||
PluginManager,
|
||||
ModelAliasPlugin,
|
||||
LoggingPlugin,
|
||||
ParamsValidationPlugin,
|
||||
createContext
|
||||
} from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const manager = new PluginManager([
|
||||
ModelAliasPlugin, // 模型别名解析
|
||||
ParamsValidationPlugin, // 参数验证
|
||||
LoggingPlugin // 日志记录
|
||||
])
|
||||
|
||||
// AI 请求流程
|
||||
async function aiRequest(providerId: string, modelId: string, params: any) {
|
||||
const context = createContext(providerId, modelId, params)
|
||||
|
||||
try {
|
||||
// 1. 【并行】触发请求开始事件
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 【首个】解析模型别名
|
||||
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
|
||||
context.modelId = resolvedModel || modelId
|
||||
|
||||
// 3. 【串行】转换请求参数
|
||||
const transformedParams = await manager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 【流处理】收集流转换器(AI SDK 原生支持数组)
|
||||
const streamTransforms = manager.collectStreamTransforms()
|
||||
|
||||
// 5. 调用 AI SDK(这里省略具体实现)
|
||||
const result = await callAiSdk(transformedParams, streamTransforms)
|
||||
|
||||
// 6. 【串行】转换响应结果
|
||||
const transformedResult = await manager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 7. 【并行】触发请求完成事件
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 8. 【并行】触发错误事件
|
||||
await manager.executeParallel('onError', context, undefined, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 🔧 自定义插件
|
||||
|
||||
### 模型别名插件
|
||||
|
||||
```typescript
|
||||
const ModelAliasPlugin = definePlugin({
|
||||
name: 'model-alias',
|
||||
enforce: 'pre', // 最先执行
|
||||
|
||||
async resolveModel(modelId) {
|
||||
const aliases = {
|
||||
gpt4: 'gpt-4-turbo-preview',
|
||||
claude: 'claude-3-sonnet-20240229'
|
||||
}
|
||||
return aliases[modelId] || null
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 参数验证插件
|
||||
|
||||
```typescript
|
||||
const ValidationPlugin = definePlugin({
|
||||
name: 'validation',
|
||||
|
||||
async transformParams(params) {
|
||||
if (!params.messages) {
|
||||
throw new Error('messages is required')
|
||||
}
|
||||
|
||||
return {
|
||||
...params,
|
||||
temperature: params.temperature ?? 0.7,
|
||||
max_tokens: params.max_tokens ?? 4096
|
||||
}
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 监控插件
|
||||
|
||||
```typescript
|
||||
const MonitoringPlugin = definePlugin({
|
||||
name: 'monitoring',
|
||||
enforce: 'post', // 最后执行
|
||||
|
||||
async onRequestEnd(context, result) {
|
||||
const duration = Date.now() - context.startTime
|
||||
console.log(`请求耗时: ${duration}ms`)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 内容过滤插件
|
||||
|
||||
```typescript
|
||||
const FilterPlugin = definePlugin({
|
||||
name: 'content-filter',
|
||||
|
||||
transformStream() {
|
||||
return () =>
|
||||
new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type === 'text-delta') {
|
||||
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
|
||||
controller.enqueue({ ...chunk, textDelta: filtered })
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## 📊 执行顺序
|
||||
|
||||
### 插件排序
|
||||
|
||||
```
|
||||
enforce: 'pre' → normal → enforce: 'post'
|
||||
```
|
||||
|
||||
### 钩子执行流程
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[请求开始] --> B[onRequestStart 并行执行]
|
||||
B --> C[resolveModel 首个有效]
|
||||
C --> D[loadTemplate 首个有效]
|
||||
D --> E[transformParams 串行执行]
|
||||
E --> F[collectStreamTransforms]
|
||||
F --> G[AI SDK 调用]
|
||||
G --> H[transformResult 串行执行]
|
||||
H --> I[onRequestEnd 并行执行]
|
||||
|
||||
G --> J[异常处理]
|
||||
J --> K[onError 并行执行]
|
||||
```
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
1. **功能单一**:每个插件专注一个功能
|
||||
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
|
||||
3. **错误处理**:插件内部处理异常,不要让异常向上传播
|
||||
4. **性能优化**:使用合适的钩子类型(First vs Sequential vs Parallel)
|
||||
5. **命名规范**:使用语义化的插件名称
|
||||
|
||||
## 🔍 调试工具
|
||||
|
||||
```typescript
|
||||
// 查看插件统计信息
|
||||
const stats = manager.getStats()
|
||||
console.log('插件统计:', stats)
|
||||
|
||||
// 查看所有插件
|
||||
const plugins = manager.getPlugins()
|
||||
console.log(
|
||||
'已注册插件:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
```
|
||||
|
||||
## ⚡ 性能优势
|
||||
|
||||
- **First 钩子**:一旦找到结果立即停止,避免无效计算
|
||||
- **Parallel 钩子**:真正并发执行,不阻塞主流程
|
||||
- **Sequential 钩子**:保证数据转换的顺序性
|
||||
- **Stream 钩子**:直接集成 AI SDK,零开销
|
||||
|
||||
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。
|
||||
@@ -1,51 +0,0 @@
|
||||
import { google } from '@ai-sdk/google'
|
||||
import type { ToolSet } from 'ai'
|
||||
|
||||
import { type AiPlugin, definePlugin, type StreamTextParams, type StreamTextResult } from '../../'
|
||||
|
||||
const toolNameMap = {
|
||||
googleSearch: 'google_search',
|
||||
urlContext: 'url_context',
|
||||
codeExecution: 'code_execution'
|
||||
} as const
|
||||
|
||||
type ToolConfigKey = keyof typeof toolNameMap
|
||||
type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean }
|
||||
|
||||
export const googleToolsPlugin = (config?: ToolConfig): AiPlugin<StreamTextParams, StreamTextResult> =>
|
||||
definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'googleToolsPlugin',
|
||||
transformParams: (params, context): Partial<StreamTextParams> => {
|
||||
const { providerId } = context
|
||||
|
||||
// 只在 Google provider 且有配置时才修改参数
|
||||
if (providerId !== 'google' || !config) {
|
||||
return {} // 返回空 Partial,表示不修改
|
||||
}
|
||||
|
||||
if (typeof params !== 'object' || params === null) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// 构建 tools 对象,确保类型兼容
|
||||
const hasTools = (Object.keys(config) as ToolConfigKey[]).some(
|
||||
(key) => config[key] && key in toolNameMap && key in google.tools
|
||||
)
|
||||
|
||||
if (!hasTools) {
|
||||
return {} // 返回空 Partial,表示不修改
|
||||
}
|
||||
|
||||
// 构建符合 AI SDK 的 tools 对象
|
||||
const tools: ToolSet = {}
|
||||
|
||||
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
|
||||
if (config[key] && key in toolNameMap && key in google.tools) {
|
||||
const toolName = toolNameMap[key]
|
||||
tools[toolName] = google.tools[key]({}) as ToolSet[string]
|
||||
}
|
||||
})
|
||||
|
||||
return { tools: { ...params.tools, ...tools } }
|
||||
}
|
||||
})
|
||||
@@ -1,10 +1,4 @@
|
||||
/**
|
||||
* 内置插件命名空间
|
||||
* 所有内置插件都以 'built-in:' 为前缀
|
||||
*/
|
||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||
|
||||
export * from './googleToolsPlugin'
|
||||
export * from './providerToolPlugin'
|
||||
export * from './toolUsePlugin/promptToolUsePlugin'
|
||||
export * from './toolUsePlugin/type'
|
||||
export * from './webSearchPlugin'
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
/**
|
||||
* 内置插件:日志记录
|
||||
* 记录AI调用的关键信息,支持性能监控和调试
|
||||
*/
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
|
||||
export interface LoggingConfig {
|
||||
// 日志级别
|
||||
level?: 'debug' | 'info' | 'warn' | 'error'
|
||||
// 是否记录参数
|
||||
logParams?: boolean
|
||||
// 是否记录结果
|
||||
logResult?: boolean
|
||||
// 是否记录性能数据
|
||||
logPerformance?: boolean
|
||||
// 自定义日志函数
|
||||
logger?: (level: string, message: string, data?: any) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建日志插件
|
||||
*/
|
||||
export function createLoggingPlugin(config: LoggingConfig = {}) {
|
||||
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
|
||||
|
||||
const startTimes = new Map<string, number>()
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:logging',
|
||||
|
||||
onRequestStart: (context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
startTimes.set(requestId, Date.now())
|
||||
|
||||
logger(level, `🚀 AI Request Started`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
originalParams: logParams ? context.originalParams : '[hidden]'
|
||||
})
|
||||
},
|
||||
|
||||
onRequestEnd: (context: AiRequestContext, result: any) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
const logData: any = {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId
|
||||
}
|
||||
|
||||
if (logPerformance && duration) {
|
||||
logData.duration = `${duration}ms`
|
||||
}
|
||||
|
||||
if (logResult) {
|
||||
logData.result = result
|
||||
}
|
||||
|
||||
logger(level, `✅ AI Request Completed`, logData)
|
||||
},
|
||||
|
||||
onError: (error: Error, context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
logger('error', `❌ AI Request Failed`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
duration: duration ? `${duration}ms` : undefined,
|
||||
error: {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* 通用 provider 工具注入插件
|
||||
*
|
||||
* 查找 extensionRegistry 中声明的 toolFactory,
|
||||
* 将返回的 ToolFactoryPatch(tools / providerOptions)合并到 params。
|
||||
*/
|
||||
|
||||
import { mergeProviderOptions } from '../../options'
|
||||
import { extensionRegistry } from '../../providers'
|
||||
import type { ToolCapability } from '../../providers/types/toolFactory'
|
||||
import { definePlugin } from '../'
|
||||
export const providerToolPlugin = (capability: ToolCapability, config: Record<string, any> = {}) =>
|
||||
definePlugin({
|
||||
name: capability,
|
||||
enforce: 'pre',
|
||||
|
||||
transformParams: async (params: any, context) => {
|
||||
const { providerId } = context
|
||||
|
||||
const modelProvider =
|
||||
context.model && typeof context.model !== 'string' && 'provider' in context.model
|
||||
? context.model.provider
|
||||
: undefined
|
||||
|
||||
const resolved = await extensionRegistry.resolveToolCapability(providerId, capability, modelProvider)
|
||||
if (!resolved) return params
|
||||
|
||||
const userConfig = config[providerId] ?? {}
|
||||
const patch = resolved.factory(resolved.provider)(userConfig)
|
||||
|
||||
if (patch.tools) {
|
||||
params.tools = { ...params.tools, ...patch.tools }
|
||||
}
|
||||
if (patch.providerOptions) {
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, patch.providerOptions)
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
})
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
|
||||
import { createMockContext, createMockTool } from '@test-utils'
|
||||
import type {
|
||||
EmbeddingModelUsage,
|
||||
ImageModelUsage,
|
||||
@@ -10,7 +11,6 @@ import type {
|
||||
import { simulateReadableStream } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockContext, createMockTool } from '../../../../../__tests__'
|
||||
import { StreamEventManager } from '../StreamEventManager'
|
||||
import type { StreamController } from '../ToolExecutor'
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '@test-utils'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import { simulateReadableStream } from 'ai'
|
||||
import { convertReadableStreamToArray } from 'ai/test'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '../../../../../__tests__'
|
||||
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
|
||||
|
||||
describe('promptToolUsePlugin', () => {
|
||||
|
||||
@@ -348,7 +348,7 @@ export const createPromptToolUsePlugin = (
|
||||
return new TransformStream()
|
||||
}
|
||||
|
||||
// 从 context 中获取或初始化 usage 累加器
|
||||
// 初始化 usage 累加器和工具执行状态
|
||||
if (!context.accumulatedUsage) {
|
||||
context.accumulatedUsage = {
|
||||
inputTokens: 0,
|
||||
@@ -358,17 +358,15 @@ export const createPromptToolUsePlugin = (
|
||||
cachedInputTokens: 0
|
||||
}
|
||||
}
|
||||
if (context.hasExecutedToolsInCurrentStep === undefined) {
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
}
|
||||
|
||||
// 创建工具执行器、流事件管理器和标签提取器
|
||||
const toolExecutor = new ToolExecutor()
|
||||
const streamEventManager = new StreamEventManager()
|
||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
|
||||
// 在context中初始化工具执行状态,避免递归调用时状态丢失
|
||||
if (!context.hasExecutedToolsInCurrentStep) {
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
}
|
||||
|
||||
// 用于hold text-start事件,直到确认有非工具标签内容
|
||||
let pendingTextStart: TextStreamPart<TOOLS> | null = null
|
||||
let hasStartedText = false
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
import { xai } from '@ai-sdk/xai'
|
||||
import type { InferToolInput, InferToolOutput } from 'ai'
|
||||
import { type Tool } from 'ai'
|
||||
|
||||
import { createOpenRouterOptions, mergeProviderOptions } from '../../../options'
|
||||
import type { AiRequestContext } from '../../'
|
||||
import type { OpenRouterSearchConfig } from './openrouter'
|
||||
|
||||
/**
|
||||
* 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。
|
||||
*/
|
||||
export type OpenAISearchConfig = NonNullable<Parameters<typeof openai.tools.webSearch>[0]>
|
||||
export type OpenAISearchPreviewConfig = NonNullable<Parameters<typeof openai.tools.webSearchPreview>[0]>
|
||||
export type AnthropicSearchConfig = NonNullable<Parameters<typeof anthropic.tools.webSearch_20250305>[0]>
|
||||
export type GoogleSearchConfig = NonNullable<Parameters<typeof google.tools.googleSearch>[0]>
|
||||
export type XAIWebSearchConfig = NonNullable<Parameters<typeof xai.tools.webSearch>[0]>
|
||||
export type XAIXSearchConfig = NonNullable<Parameters<typeof xai.tools.xSearch>[0]>
|
||||
|
||||
type NormalizeTool<T> = T extends Tool<infer INPUT, infer OUTPUT> ? Tool<INPUT, OUTPUT> : Tool<any, any>
|
||||
|
||||
type AnthropicWebSearchTool = NormalizeTool<ReturnType<typeof anthropic.tools.webSearch_20250305>>
|
||||
type OpenAIWebSearchTool = NormalizeTool<ReturnType<typeof openai.tools.webSearch>>
|
||||
type OpenAIChatWebSearchTool = NormalizeTool<ReturnType<typeof openai.tools.webSearchPreview>>
|
||||
type GoogleWebSearchTool = NormalizeTool<ReturnType<typeof google.tools.googleSearch>>
|
||||
type XAIWebSearchTool = NormalizeTool<ReturnType<typeof xai.tools.webSearch>>
|
||||
type XAIXSearchTool = NormalizeTool<ReturnType<typeof xai.tools.xSearch>>
|
||||
|
||||
/**
|
||||
* 插件初始化时接收的完整配置对象
|
||||
*
|
||||
* 其结构与 ProviderOptions 保持一致,方便上游统一管理配置
|
||||
*/
|
||||
export interface WebSearchPluginConfig {
|
||||
openai?: OpenAISearchConfig
|
||||
'openai-chat'?: OpenAISearchPreviewConfig
|
||||
anthropic?: AnthropicSearchConfig
|
||||
xai?: XAIWebSearchConfig
|
||||
'xai-xsearch'?: XAIXSearchConfig
|
||||
google?: GoogleSearchConfig
|
||||
openrouter?: OpenRouterSearchConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件的默认配置
|
||||
*/
|
||||
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
google: {},
|
||||
openai: {},
|
||||
'openai-chat': {},
|
||||
xai: {
|
||||
enableImageUnderstanding: true
|
||||
},
|
||||
'xai-xsearch': {
|
||||
enableImageUnderstanding: true
|
||||
},
|
||||
anthropic: {
|
||||
maxUses: 5
|
||||
},
|
||||
openrouter: {
|
||||
plugins: [
|
||||
{
|
||||
id: 'web',
|
||||
max_results: 5
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
export type WebSearchToolOutputSchema = {
|
||||
// Anthropic 工具 - 手动定义
|
||||
anthropic: InferToolOutput<AnthropicWebSearchTool>
|
||||
|
||||
// OpenAI 工具 - 基于实际输出
|
||||
// TODO: 上游定义不规范,是unknown
|
||||
// openai: InferToolOutput<ReturnType<typeof openai.tools.webSearch>>
|
||||
openai: {
|
||||
status: 'completed' | 'failed'
|
||||
}
|
||||
'openai-chat': {
|
||||
status: 'completed' | 'failed'
|
||||
}
|
||||
// Google 工具
|
||||
// TODO: 上游定义不规范,是unknown
|
||||
// google: InferToolOutput<ReturnType<typeof google.tools.googleSearch>>
|
||||
google: {
|
||||
webSearchQueries?: string[]
|
||||
groundingChunks?: Array<{
|
||||
web?: { uri: string; title: string }
|
||||
}>
|
||||
}
|
||||
// xAI 工具
|
||||
xai: InferToolOutput<XAIWebSearchTool>
|
||||
'xai-xsearch': InferToolOutput<XAIXSearchTool>
|
||||
}
|
||||
|
||||
export type WebSearchToolInputSchema = {
|
||||
anthropic: InferToolInput<AnthropicWebSearchTool>
|
||||
openai: InferToolInput<OpenAIWebSearchTool>
|
||||
google: InferToolInput<GoogleWebSearchTool>
|
||||
'openai-chat': InferToolInput<OpenAIChatWebSearchTool>
|
||||
xai: InferToolInput<XAIWebSearchTool>
|
||||
'xai-xsearch': InferToolInput<XAIXSearchTool>
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to ensure params.tools object exists
|
||||
*/
|
||||
const ensureToolsObject = (params: any) => {
|
||||
if (!params.tools) params.tools = {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to apply tool-based web search configuration
|
||||
*/
|
||||
const applyToolBasedSearch = (params: any, toolName: string, toolInstance: any) => {
|
||||
ensureToolsObject(params)
|
||||
params.tools[toolName] = toolInstance
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to apply provider options-based web search configuration
|
||||
*/
|
||||
const applyProviderOptionsSearch = (params: any, searchOptions: any) => {
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
|
||||
export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any, context?: AiRequestContext) => {
|
||||
const providerId = context?.providerId
|
||||
|
||||
// Provider-specific configuration map
|
||||
const providerHandlers: Record<string, () => void> = {
|
||||
openai: () => {
|
||||
const cfg = config.openai ?? DEFAULT_WEB_SEARCH_CONFIG.openai
|
||||
applyToolBasedSearch(params, 'web_search', openai.tools.webSearch(cfg))
|
||||
},
|
||||
'openai-chat': () => {
|
||||
const cfg = (config['openai-chat'] ?? DEFAULT_WEB_SEARCH_CONFIG['openai-chat']) as OpenAISearchPreviewConfig
|
||||
applyToolBasedSearch(params, 'web_search_preview', openai.tools.webSearchPreview(cfg))
|
||||
},
|
||||
anthropic: () => {
|
||||
const cfg = config.anthropic ?? DEFAULT_WEB_SEARCH_CONFIG.anthropic
|
||||
applyToolBasedSearch(params, 'web_search', anthropic.tools.webSearch_20250305(cfg))
|
||||
},
|
||||
google: () => {
|
||||
const cfg = (config.google ?? DEFAULT_WEB_SEARCH_CONFIG.google) as GoogleSearchConfig
|
||||
applyToolBasedSearch(params, 'web_search', google.tools.googleSearch(cfg))
|
||||
},
|
||||
xai: () => {
|
||||
const cfg = config.xai ?? DEFAULT_WEB_SEARCH_CONFIG.xai
|
||||
applyToolBasedSearch(params, 'web_search', xai.tools.webSearch(cfg))
|
||||
const xSearchCfg = config['xai-xsearch'] ?? DEFAULT_WEB_SEARCH_CONFIG['xai-xsearch']
|
||||
applyToolBasedSearch(params, 'x_search', xai.tools.xSearch(xSearchCfg))
|
||||
},
|
||||
openrouter: () => {
|
||||
const cfg = (config.openrouter ?? DEFAULT_WEB_SEARCH_CONFIG.openrouter) as OpenRouterSearchConfig
|
||||
const searchOptions = createOpenRouterOptions(cfg)
|
||||
applyProviderOptionsSearch(params, searchOptions)
|
||||
}
|
||||
}
|
||||
|
||||
// Try provider-specific handler first
|
||||
const handler = providerId && providerHandlers[providerId]
|
||||
if (handler) {
|
||||
handler()
|
||||
return params
|
||||
}
|
||||
|
||||
// Fallback: apply based on available config keys (prioritized order)
|
||||
const fallbackOrder: Array<keyof WebSearchPluginConfig> = [
|
||||
'openai',
|
||||
'openai-chat',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'openrouter'
|
||||
]
|
||||
|
||||
for (const key of fallbackOrder) {
|
||||
if (config[key]) {
|
||||
providerHandlers[key]()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
@@ -1,44 +1,40 @@
|
||||
/**
|
||||
* Web Search Plugin
|
||||
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||
*/
|
||||
import type { WebSearchToolConfigMap } from '../../../providers'
|
||||
|
||||
import { definePlugin } from '../../'
|
||||
import type { WebSearchPluginConfig } from './helper'
|
||||
import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper'
|
||||
export type OpenRouterSearchConfig = {
|
||||
plugins?: Array<{
|
||||
id: 'web'
|
||||
/**
|
||||
* Maximum number of search results to include (default: 5)
|
||||
*/
|
||||
max_results?: number
|
||||
/**
|
||||
* Custom search prompt to guide the search query
|
||||
*/
|
||||
search_prompt?: string
|
||||
}>
|
||||
/**
|
||||
* Built-in web search options for models that support native web search
|
||||
*/
|
||||
web_search_options?: {
|
||||
/**
|
||||
* Maximum number of search results to include
|
||||
*/
|
||||
max_results?: number
|
||||
/**
|
||||
* Custom search prompt to guide the search query
|
||||
*/
|
||||
search_prompt?: string
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 网络搜索插件
|
||||
* 插件初始化时接收的完整配置对象
|
||||
*
|
||||
* @param config - 在插件初始化时传入的静态配置
|
||||
* key = provider ID,value = 该 provider 的搜索配置
|
||||
*
|
||||
* - 大部分类型从 coreExtensions 的 toolFactories 声明中自动提取(WebSearchToolConfigMap)
|
||||
* - OpenRouter 使用自定义配置(非 SDK .tools 模式),从 openrouter.ts 导入
|
||||
*/
|
||||
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
transformParams: async (params: any, context) => {
|
||||
let { providerId } = context
|
||||
|
||||
// For cherryin providers, extract the actual provider from the model's provider string
|
||||
// Expected format: "cherryin.{actualProvider}" (e.g., "cherryin.gemini")
|
||||
if (providerId === 'cherryin' || providerId === 'cherryin-chat') {
|
||||
const provider = params.model?.provider
|
||||
if (provider && typeof provider === 'string' && provider.includes('.')) {
|
||||
const extractedProviderId = provider.split('.')[1]
|
||||
if (extractedProviderId) {
|
||||
providerId = extractedProviderId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switchWebSearchTool(config, params, { ...context, providerId })
|
||||
return params
|
||||
}
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export * from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
|
||||
export type WebSearchPluginConfig = WebSearchToolConfigMap & {
|
||||
openrouter?: OpenRouterSearchConfig
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
export type OpenRouterSearchConfig = {
|
||||
plugins?: Array<{
|
||||
id: 'web'
|
||||
/**
|
||||
* Maximum number of search results to include (default: 5)
|
||||
*/
|
||||
max_results?: number
|
||||
/**
|
||||
* Custom search prompt to guide the search query
|
||||
*/
|
||||
search_prompt?: string
|
||||
}>
|
||||
/**
|
||||
* Built-in web search options for models that support native web search
|
||||
*/
|
||||
web_search_options?: {
|
||||
/**
|
||||
* Maximum number of search results to include
|
||||
*/
|
||||
max_results?: number
|
||||
/**
|
||||
* Custom search prompt to guide the search query
|
||||
*/
|
||||
search_prompt?: string
|
||||
}
|
||||
}
|
||||
@@ -2,12 +2,8 @@
|
||||
export type {
|
||||
AiPlugin,
|
||||
AiRequestContext,
|
||||
AiRequestMetadata,
|
||||
GenerateTextParams,
|
||||
GenerateTextResult,
|
||||
HookResult,
|
||||
PluginManagerConfig,
|
||||
RecursiveCallFn,
|
||||
StreamTextParams,
|
||||
StreamTextResult
|
||||
} from './types'
|
||||
|
||||
@@ -23,7 +23,6 @@ export interface AiRequestMetadata {
|
||||
enableGenerateImage?: boolean
|
||||
isPromptToolUse?: boolean
|
||||
isSupportedToolUse?: boolean
|
||||
isImageGenerationEndpoint?: boolean
|
||||
// 自定义元数据,使用 JSONValue 确保类型安全
|
||||
custom?: JSONObject
|
||||
}
|
||||
@@ -32,7 +31,7 @@ export interface AiRequestMetadata {
|
||||
* 递归调用函数类型
|
||||
* 泛型化以保持类型推导
|
||||
*/
|
||||
export type RecursiveCallFn<TParams = unknown, TResult = unknown> = (newParams: Partial<TParams>) => Promise<TResult>
|
||||
type RecursiveCallFn<TParams = unknown, TResult = unknown> = (newParams: Partial<TParams>) => Promise<TResult>
|
||||
|
||||
/**
|
||||
* AI 请求上下文
|
||||
@@ -107,20 +106,3 @@ export interface AiPlugin<TParams = unknown, TResult = unknown> {
|
||||
stopStream: () => void
|
||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件管理器配置
|
||||
*/
|
||||
export interface PluginManagerConfig<TParams = unknown, TResult = unknown> {
|
||||
plugins: AiPlugin<TParams, TResult>[]
|
||||
context: Partial<AiRequestContext<TParams, TResult>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子执行结果
|
||||
* 泛型参数指定返回值类型
|
||||
*/
|
||||
export interface HookResult<T = unknown> {
|
||||
value: T
|
||||
stop?: boolean
|
||||
}
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
/**
|
||||
* Hub Provider - 支持路由到多个底层provider
|
||||
*
|
||||
* 支持格式: hubId:providerId:modelId
|
||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||
*/
|
||||
|
||||
import type {
|
||||
EmbeddingModelV3,
|
||||
ImageModelV3,
|
||||
LanguageModelV3,
|
||||
ProviderV3,
|
||||
RerankingModelV3,
|
||||
SpeechModelV3,
|
||||
TranscriptionModelV3
|
||||
} from '@ai-sdk/provider'
|
||||
import { customProvider, wrapProvider } from 'ai'
|
||||
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from './RegistryManagement'
|
||||
import type { AiSdkProvider } from './types'
|
||||
|
||||
export interface HubProviderConfig {
|
||||
/** Hub的唯一标识符 */
|
||||
hubId: string
|
||||
/** 是否启用调试日志 */
|
||||
debug?: boolean
|
||||
}
|
||||
|
||||
export class HubProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly hubId: string,
|
||||
public readonly providerId?: string,
|
||||
public readonly originalError?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'HubProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析Hub模型ID
|
||||
*/
|
||||
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
|
||||
const parts = modelId.split(DEFAULT_SEPARATOR)
|
||||
if (parts.length !== 2) {
|
||||
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
|
||||
}
|
||||
return {
|
||||
provider: parts[0],
|
||||
actualModelId: parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Hub Provider
|
||||
*/
|
||||
export function createHubProvider(config: HubProviderConfig): AiSdkProvider {
|
||||
const { hubId } = config
|
||||
|
||||
function getTargetProvider(providerId: string): ProviderV3 {
|
||||
// 从全局注册表获取provider实例
|
||||
try {
|
||||
const provider = globalRegistryManagement.getProvider(providerId)
|
||||
if (!provider) {
|
||||
throw new HubProviderError(
|
||||
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
||||
hubId,
|
||||
providerId
|
||||
)
|
||||
}
|
||||
// 使用 wrapProvider 确保返回的是 V3 provider
|
||||
// 这样可以自动处理 V2 provider 到 V3 的转换
|
||||
return wrapProvider({ provider, languageModelMiddleware: [] })
|
||||
} catch (error) {
|
||||
throw new HubProviderError(
|
||||
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
hubId,
|
||||
providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建符合 ProviderV3 规范的 fallback provider
|
||||
const hubFallbackProvider = {
|
||||
specificationVersion: 'v3' as const,
|
||||
|
||||
languageModel: (modelId: string): LanguageModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
return targetProvider.languageModel(actualModelId)
|
||||
},
|
||||
|
||||
embeddingModel: (modelId: string): EmbeddingModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
return targetProvider.embeddingModel(actualModelId)
|
||||
},
|
||||
|
||||
imageModel: (modelId: string): ImageModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
return targetProvider.imageModel(actualModelId)
|
||||
},
|
||||
|
||||
transcriptionModel: (modelId: string): TranscriptionModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.transcriptionModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.transcriptionModel(actualModelId)
|
||||
},
|
||||
|
||||
speechModel: (modelId: string): SpeechModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.speechModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.speechModel(actualModelId)
|
||||
},
|
||||
rerankingModel: (modelId: string): RerankingModelV3 => {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.rerankingModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support reranking models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.rerankingModel(actualModelId)
|
||||
}
|
||||
}
|
||||
|
||||
return customProvider({
|
||||
fallbackProvider: hubFallbackProvider
|
||||
})
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
/**
|
||||
* Provider 注册表管理器
|
||||
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
|
||||
* 基于 AI SDK 原生的 createProviderRegistry
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
|
||||
|
||||
type PROVIDERS = Record<string, ProviderV3>
|
||||
|
||||
export const DEFAULT_SEPARATOR = '|'
|
||||
|
||||
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
|
||||
private providers: PROVIDERS = {}
|
||||
private aliases: Set<string> = new Set() // 记录哪些key是别名
|
||||
private separator: SEPARATOR
|
||||
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
|
||||
|
||||
constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
|
||||
this.separator = options.separator
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册已配置好的 provider 实例
|
||||
*/
|
||||
registerProvider(id: string, provider: ProviderV3, aliases?: string[]): this {
|
||||
// 注册主provider
|
||||
this.providers[id] = provider
|
||||
|
||||
// 注册别名(都指向同一个provider实例)
|
||||
if (aliases) {
|
||||
aliases.forEach((alias) => {
|
||||
this.providers[alias] = provider // 直接存储引用
|
||||
this.aliases.add(alias) // 标记为别名
|
||||
})
|
||||
}
|
||||
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的provider实例
|
||||
*/
|
||||
getProvider(id: string): ProviderV3 | undefined {
|
||||
return this.providers[id]
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册 providers
|
||||
*/
|
||||
registerProviders(providers: Record<string, ProviderV3>): this {
|
||||
Object.assign(this.providers, providers)
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除 provider(同时清理相关别名)
|
||||
*/
|
||||
unregisterProvider(id: string): this {
|
||||
const provider = this.providers[id]
|
||||
if (!provider) return this
|
||||
|
||||
// 如果移除的是真实ID,需要清理所有指向它的别名
|
||||
if (!this.aliases.has(id)) {
|
||||
// 找到所有指向此provider的别名并删除
|
||||
const aliasesToRemove: string[] = []
|
||||
this.aliases.forEach((alias) => {
|
||||
if (this.providers[alias] === provider) {
|
||||
aliasesToRemove.push(alias)
|
||||
}
|
||||
})
|
||||
|
||||
aliasesToRemove.forEach((alias) => {
|
||||
delete this.providers[alias]
|
||||
this.aliases.delete(alias)
|
||||
})
|
||||
} else {
|
||||
// 如果移除的是别名,只删除别名记录
|
||||
this.aliases.delete(id)
|
||||
}
|
||||
|
||||
delete this.providers[id]
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 立即重建 registry - 每次变更都重建
|
||||
*/
|
||||
private rebuildRegistry(): void {
|
||||
if (Object.keys(this.providers).length === 0) {
|
||||
this.registry = null
|
||||
return
|
||||
}
|
||||
|
||||
this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
|
||||
separator: this.separator
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取语言模型 - AI SDK 原生方法
|
||||
*/
|
||||
languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV3 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.languageModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取文本嵌入模型 - AI SDK 原生方法
|
||||
*/
|
||||
embeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV3 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.embeddingModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取图像模型 - AI SDK 原生方法
|
||||
*/
|
||||
imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV3 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.imageModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取转录模型 - AI SDK 原生方法
|
||||
*/
|
||||
transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.transcriptionModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取语音模型 - AI SDK 原生方法
|
||||
*/
|
||||
speechModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.speechModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的 provider 列表
|
||||
*/
|
||||
getRegisteredProviders(): string[] {
|
||||
return Object.keys(this.providers)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有已注册的 providers
|
||||
*/
|
||||
hasProviders(): boolean {
|
||||
return Object.keys(this.providers).length > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除所有 providers
|
||||
*/
|
||||
clear(): this {
|
||||
this.providers = {}
|
||||
this.aliases.clear()
|
||||
this.registry = null
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的Provider ID(供getAiSdkProviderId使用)
|
||||
* 如果传入的是别名,返回真实的Provider ID
|
||||
* 如果传入的是真实ID,直接返回
|
||||
*/
|
||||
resolveProviderId(id: string): string {
|
||||
if (!this.aliases.has(id)) return id // 不是别名,直接返回
|
||||
|
||||
// 是别名,找到真实ID
|
||||
const targetProvider = this.providers[id]
|
||||
for (const [realId, provider] of Object.entries(this.providers)) {
|
||||
if (provider === targetProvider && !this.aliases.has(realId)) {
|
||||
return realId
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为别名
|
||||
*/
|
||||
isAlias(id: string): boolean {
|
||||
return this.aliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有别名映射关系
|
||||
*/
|
||||
getAllAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
this.aliases.forEach((alias) => {
|
||||
result[alias] = this.resolveProviderId(alias)
|
||||
})
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局注册表管理器实例
|
||||
* 使用 | 作为分隔符,因为 : 会和 :free 等suffix冲突
|
||||
*/
|
||||
export const globalRegistryManagement = new RegistryManagement()
|
||||
@@ -0,0 +1,925 @@
|
||||
/**
|
||||
* ExtensionRegistry 单元测试
|
||||
*/
|
||||
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createMockProviderV3 } from '@test-utils'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ExtensionRegistry } from '../core/ExtensionRegistry'
|
||||
import { ProviderExtension } from '../core/ProviderExtension'
|
||||
import { ProviderCreationError } from '../core/utils'
|
||||
|
||||
describe('ExtensionRegistry', () => {
|
||||
let registry: ExtensionRegistry
|
||||
|
||||
beforeEach(() => {
|
||||
registry = new ExtensionRegistry()
|
||||
})
|
||||
|
||||
describe('register', () => {
|
||||
it('should register an extension', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
expect(registry.has('test-provider')).toBe(true)
|
||||
expect(registry.get('test-provider')).toBe(extension)
|
||||
})
|
||||
|
||||
it('should register aliases', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openrouter',
|
||||
aliases: ['or', 'open-router'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
expect(registry.has('openrouter')).toBe(true)
|
||||
expect(registry.has('or')).toBe(true)
|
||||
expect(registry.has('open-router')).toBe(true)
|
||||
|
||||
// 别名应该指向同一个 extension
|
||||
expect(registry.get('or')).toBe(extension)
|
||||
expect(registry.get('open-router')).toBe(extension)
|
||||
})
|
||||
|
||||
it('should be idempotent when name already registered', () => {
|
||||
const ext1 = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
const ext2 = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(ext1)
|
||||
registry.register(ext2) // should not throw
|
||||
|
||||
// original extension is preserved
|
||||
expect(registry.get('test-provider')).toBe(ext1)
|
||||
})
|
||||
|
||||
it('should throw error if alias already registered', () => {
|
||||
const ext1 = new ProviderExtension({
|
||||
name: 'provider1',
|
||||
aliases: ['shared-alias'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
const ext2 = new ProviderExtension({
|
||||
name: 'provider2',
|
||||
aliases: ['shared-alias'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(ext1)
|
||||
|
||||
expect(() => registry.register(ext2)).toThrow('already registered')
|
||||
})
|
||||
|
||||
it('should support method chaining', () => {
|
||||
const ext1 = new ProviderExtension({
|
||||
name: 'provider1',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
const ext2 = new ProviderExtension({
|
||||
name: 'provider2',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
const result = registry.register(ext1).register(ext2)
|
||||
|
||||
expect(result).toBe(registry)
|
||||
expect(registry.has('provider1')).toBe(true)
|
||||
expect(registry.has('provider2')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('registerAll', () => {
|
||||
it('should register multiple extensions', () => {
|
||||
const extensions = [
|
||||
new ProviderExtension({ name: 'provider1', create: createMockProviderV3 }),
|
||||
new ProviderExtension({ name: 'provider2', create: createMockProviderV3 }),
|
||||
new ProviderExtension({ name: 'provider3', create: createMockProviderV3 })
|
||||
]
|
||||
|
||||
registry.registerAll(extensions)
|
||||
|
||||
expect(registry.has('provider1')).toBe(true)
|
||||
expect(registry.has('provider2')).toBe(true)
|
||||
expect(registry.has('provider3')).toBe(true)
|
||||
})
|
||||
|
||||
it('should support method chaining', () => {
|
||||
const extensions = [new ProviderExtension({ name: 'test', create: createMockProviderV3 })]
|
||||
|
||||
const result = registry.registerAll(extensions)
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
})
|
||||
|
||||
describe('unregister', () => {
|
||||
it('should remove extension', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
expect(registry.has('test-provider')).toBe(true)
|
||||
|
||||
const result = registry.unregister('test-provider')
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(registry.has('test-provider')).toBe(false)
|
||||
})
|
||||
|
||||
it('should remove aliases', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
aliases: ['alias1', 'alias2'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
registry.unregister('test-provider')
|
||||
|
||||
expect(registry.has('alias1')).toBe(false)
|
||||
expect(registry.has('alias2')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false if extension not found', () => {
|
||||
const result = registry.unregister('non-existent')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('get', () => {
|
||||
it('should get extension by name', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
expect(registry.get('test-provider')).toBe(extension)
|
||||
})
|
||||
|
||||
it('should get extension by alias', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
aliases: ['test-alias'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
expect(registry.get('test-alias')).toBe(extension)
|
||||
})
|
||||
|
||||
it('should return undefined for non-existent ID', () => {
|
||||
expect(registry.get('non-existent')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAll', () => {
|
||||
it('should return all registered extensions', () => {
|
||||
const ext1 = new ProviderExtension({ name: 'provider1', create: createMockProviderV3 })
|
||||
const ext2 = new ProviderExtension({ name: 'provider2', create: createMockProviderV3 })
|
||||
|
||||
registry.register(ext1).register(ext2)
|
||||
|
||||
const all = registry.getAll()
|
||||
|
||||
expect(all).toHaveLength(2)
|
||||
expect(all).toContain(ext1)
|
||||
expect(all).toContain(ext2)
|
||||
})
|
||||
|
||||
it('should return empty array when no extensions registered', () => {
|
||||
expect(registry.getAll()).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAllProviderIds', () => {
|
||||
it('should return all provider IDs including variants', () => {
|
||||
const ext1 = new ProviderExtension({
|
||||
name: 'openai',
|
||||
aliases: ['oai'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const ext2 = new ProviderExtension({
|
||||
name: 'azure',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
registry.register(ext1).register(ext2)
|
||||
|
||||
const ids = registry.getAllProviderIds()
|
||||
|
||||
expect(ids).toContain('openai')
|
||||
expect(ids).toContain('oai')
|
||||
expect(ids).toContain('openai-chat')
|
||||
expect(ids).toContain('azure')
|
||||
})
|
||||
})
|
||||
|
||||
describe('clear', () => {
|
||||
it('should remove all extensions', () => {
|
||||
registry.register(new ProviderExtension({ name: 'provider1', create: createMockProviderV3 }))
|
||||
registry.register(new ProviderExtension({ name: 'provider2', create: createMockProviderV3 }))
|
||||
|
||||
registry.clear()
|
||||
|
||||
expect(registry.getAll()).toEqual([])
|
||||
expect(registry.getAllProviderIds()).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('createProvider', () => {
|
||||
it('should create provider using create function', async () => {
|
||||
const mockProvider = createMockProviderV3()
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: () => mockProvider
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
const provider = await registry.createProvider('test-provider')
|
||||
|
||||
expect(provider).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should merge default options with user settings', async () => {
|
||||
let receivedSettings: any
|
||||
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
defaultOptions: { apiKey: 'default-key', timeout: 5000 },
|
||||
create: ((settings: any) => {
|
||||
receivedSettings = settings
|
||||
return createMockProviderV3()
|
||||
}) as any
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
await registry.createProvider('test-provider', { baseURL: 'https://api.test.com' })
|
||||
|
||||
expect(receivedSettings).toEqual({
|
||||
apiKey: 'default-key',
|
||||
timeout: 5000,
|
||||
baseURL: 'https://api.test.com'
|
||||
})
|
||||
})
|
||||
|
||||
it('should create provider using dynamic import', async () => {
|
||||
const mockProvider = createMockProviderV3()
|
||||
|
||||
const extension = new ProviderExtension({
|
||||
name: 'lazy-provider',
|
||||
import: async () => ({
|
||||
createLazyProvider: () => mockProvider
|
||||
}),
|
||||
creatorFunctionName: 'createLazyProvider'
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
const provider = await registry.createProvider('lazy-provider')
|
||||
|
||||
expect(provider).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should throw error if extension not found', async () => {
|
||||
await expect(registry.createProvider('non-existent')).rejects.toThrow('not found')
|
||||
})
|
||||
|
||||
it('should throw error if creator function not found in imported module', async () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'bad-import',
|
||||
import: async () => ({}),
|
||||
creatorFunctionName: 'nonExistentFunction'
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
try {
|
||||
await registry.createProvider('bad-import')
|
||||
expect.fail('Should have thrown')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(ProviderCreationError)
|
||||
expect((error as ProviderCreationError).cause.message).toContain('not found in imported module')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Caching', () => {
|
||||
it('should cache provider instances based on settings', async () => {
|
||||
const createSpy = vi.fn(createMockProviderV3)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createSpy
|
||||
})
|
||||
)
|
||||
|
||||
// First call - should create
|
||||
const provider1 = await registry.createProvider('test-provider', { apiKey: 'same-key' })
|
||||
expect(createSpy).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Second call with same settings - should use cache
|
||||
const provider2 = await registry.createProvider('test-provider', { apiKey: 'same-key' })
|
||||
expect(createSpy).toHaveBeenCalledTimes(1) // Still 1
|
||||
expect(provider2).toBe(provider1) // Same instance
|
||||
|
||||
// Third call with different settings - should create new
|
||||
const provider3 = await registry.createProvider('test-provider', { apiKey: 'different-key' })
|
||||
expect(createSpy).toHaveBeenCalledTimes(2)
|
||||
expect(provider3).not.toBe(provider1)
|
||||
})
|
||||
|
||||
it('should deep merge settings before generating cache key', async () => {
|
||||
let firstSettings: any
|
||||
let secondSettings: any
|
||||
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
defaultOptions: {
|
||||
apiKey: 'default-key',
|
||||
headers: { 'X-Default': 'value' }
|
||||
},
|
||||
create: (settings) => {
|
||||
if (!firstSettings) {
|
||||
firstSettings = settings
|
||||
} else {
|
||||
secondSettings = settings
|
||||
}
|
||||
return createMockProviderV3()
|
||||
}
|
||||
})
|
||||
|
||||
registry.register(extension)
|
||||
|
||||
await registry.createProvider('test-provider', { headers: { 'X-Custom': 'custom' } })
|
||||
await registry.createProvider('test-provider', { headers: { 'X-Custom': 'custom' } })
|
||||
|
||||
// Should use cache - only created once
|
||||
expect(secondSettings).toBeUndefined()
|
||||
|
||||
// Verify deep merge happened
|
||||
expect(firstSettings).toEqual({
|
||||
apiKey: 'default-key',
|
||||
headers: {
|
||||
'X-Default': 'value',
|
||||
'X-Custom': 'custom'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ProviderCreationError', () => {
|
||||
it('should wrap errors in ProviderCreationError', async () => {
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: () => {
|
||||
throw new Error('Creation failed')
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
try {
|
||||
await registry.createProvider('test-provider', { apiKey: 'key' })
|
||||
expect.fail('Should have thrown')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(ProviderCreationError)
|
||||
expect((error as ProviderCreationError).providerId).toBe('test-provider')
|
||||
expect((error as ProviderCreationError).cause.message).toBe('Creation failed')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('resolveProviderIdWithMode', () => {
|
||||
beforeEach(() => {
|
||||
// 注册带变体的 extension
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'openai',
|
||||
aliases: ['oai'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'OpenAI Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'azure',
|
||||
aliases: ['azure-openai'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Azure Responses',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'google',
|
||||
aliases: ['gemini'],
|
||||
create: createMockProviderV3
|
||||
// 没有 variants
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve base ID + mode to variant ID', () => {
|
||||
expect(registry.resolveProviderIdWithMode('openai', 'chat')).toBe('openai-chat')
|
||||
expect(registry.resolveProviderIdWithMode('azure', 'responses')).toBe('azure-responses')
|
||||
})
|
||||
|
||||
it('should support aliases in base ID', () => {
|
||||
expect(registry.resolveProviderIdWithMode('oai', 'chat')).toBe('openai-chat')
|
||||
expect(registry.resolveProviderIdWithMode('azure-openai', 'responses')).toBe('azure-responses')
|
||||
})
|
||||
|
||||
it('should return null if extension has no matching variant', () => {
|
||||
expect(registry.resolveProviderIdWithMode('openai', 'responses')).toBeNull()
|
||||
expect(registry.resolveProviderIdWithMode('azure', 'chat')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null if extension has no variants at all', () => {
|
||||
expect(registry.resolveProviderIdWithMode('google', 'chat')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null if extension not found', () => {
|
||||
expect(registry.resolveProviderIdWithMode('non-existent', 'chat')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return resolved base ID when mode is not provided', () => {
|
||||
expect(registry.resolveProviderIdWithMode('openai')).toBe('openai')
|
||||
expect(registry.resolveProviderIdWithMode('oai')).toBe('openai')
|
||||
expect(registry.resolveProviderIdWithMode('gemini')).toBe('google')
|
||||
})
|
||||
|
||||
it('should return null when mode is not provided and extension not found', () => {
|
||||
expect(registry.resolveProviderIdWithMode('non-existent')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('parseProviderId', () => {
|
||||
beforeEach(() => {
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'openai',
|
||||
aliases: ['oai'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'OpenAI Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'azure',
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Azure Responses',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'google',
|
||||
aliases: ['gemini'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should parse variant ID to base ID + mode', () => {
|
||||
expect(registry.parseProviderId('openai-chat')).toEqual({
|
||||
baseId: 'openai',
|
||||
mode: 'chat',
|
||||
isVariant: true
|
||||
})
|
||||
|
||||
expect(registry.parseProviderId('azure-responses')).toEqual({
|
||||
baseId: 'azure',
|
||||
mode: 'responses',
|
||||
isVariant: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should parse base ID without mode', () => {
|
||||
expect(registry.parseProviderId('openai')).toEqual({
|
||||
baseId: 'openai',
|
||||
isVariant: false
|
||||
})
|
||||
|
||||
expect(registry.parseProviderId('azure')).toEqual({
|
||||
baseId: 'azure',
|
||||
isVariant: false
|
||||
})
|
||||
|
||||
expect(registry.parseProviderId('google')).toEqual({
|
||||
baseId: 'google',
|
||||
isVariant: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should resolve aliases to base ID', () => {
|
||||
expect(registry.parseProviderId('oai')).toEqual({
|
||||
baseId: 'openai',
|
||||
isVariant: false
|
||||
})
|
||||
|
||||
expect(registry.parseProviderId('gemini')).toEqual({
|
||||
baseId: 'google',
|
||||
isVariant: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should return null for unknown provider ID', () => {
|
||||
expect(registry.parseProviderId('non-existent')).toBeNull()
|
||||
expect(registry.parseProviderId('unknown-variant')).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle multiple variants of same extension', () => {
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'multi-variant',
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
},
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Responses',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
},
|
||||
{
|
||||
suffix: 'completions',
|
||||
name: 'Completions',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
expect(registry.parseProviderId('multi-variant-chat')).toEqual({
|
||||
baseId: 'multi-variant',
|
||||
mode: 'chat',
|
||||
isVariant: true
|
||||
})
|
||||
|
||||
expect(registry.parseProviderId('multi-variant-responses')).toEqual({
|
||||
baseId: 'multi-variant',
|
||||
mode: 'responses',
|
||||
isVariant: true
|
||||
})
|
||||
|
||||
expect(registry.parseProviderId('multi-variant-completions')).toEqual({
|
||||
baseId: 'multi-variant',
|
||||
mode: 'completions',
|
||||
isVariant: true
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Variant Query Methods', () => {
|
||||
beforeEach(() => {
|
||||
// 注册带变体的 extensions
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'openai',
|
||||
aliases: ['oai'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'OpenAI Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'azure',
|
||||
aliases: ['azure-openai'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Azure Responses',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'google',
|
||||
aliases: ['gemini'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Google Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'xai',
|
||||
create: createMockProviderV3
|
||||
// 没有 variants
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
describe('isVariant', () => {
|
||||
it('should return true for variant IDs', () => {
|
||||
expect(registry.isVariant('openai-chat')).toBe(true)
|
||||
expect(registry.isVariant('azure-responses')).toBe(true)
|
||||
expect(registry.isVariant('google-chat')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for base provider IDs', () => {
|
||||
expect(registry.isVariant('openai')).toBe(false)
|
||||
expect(registry.isVariant('azure')).toBe(false)
|
||||
expect(registry.isVariant('google')).toBe(false)
|
||||
expect(registry.isVariant('xai')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for aliases', () => {
|
||||
expect(registry.isVariant('oai')).toBe(false)
|
||||
expect(registry.isVariant('gemini')).toBe(false)
|
||||
expect(registry.isVariant('azure-openai')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for unknown IDs', () => {
|
||||
expect(registry.isVariant('unknown')).toBe(false)
|
||||
expect(registry.isVariant('non-existent-variant')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getBaseProviderId', () => {
|
||||
it('should return base ID for variant IDs', () => {
|
||||
expect(registry.getBaseProviderId('openai-chat')).toBe('openai')
|
||||
expect(registry.getBaseProviderId('azure-responses')).toBe('azure')
|
||||
expect(registry.getBaseProviderId('google-chat')).toBe('google')
|
||||
})
|
||||
|
||||
it('should return base ID for base provider IDs (identity)', () => {
|
||||
expect(registry.getBaseProviderId('openai')).toBe('openai')
|
||||
expect(registry.getBaseProviderId('azure')).toBe('azure')
|
||||
expect(registry.getBaseProviderId('google')).toBe('google')
|
||||
expect(registry.getBaseProviderId('xai')).toBe('xai')
|
||||
})
|
||||
|
||||
it('should return base ID for aliases', () => {
|
||||
expect(registry.getBaseProviderId('oai')).toBe('openai')
|
||||
expect(registry.getBaseProviderId('gemini')).toBe('google')
|
||||
expect(registry.getBaseProviderId('azure-openai')).toBe('azure')
|
||||
})
|
||||
|
||||
it('should return null for unknown IDs', () => {
|
||||
expect(registry.getBaseProviderId('unknown')).toBeNull()
|
||||
expect(registry.getBaseProviderId('non-existent')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getVariantMode', () => {
|
||||
it('should return mode/suffix for variant IDs', () => {
|
||||
expect(registry.getVariantMode('openai-chat')).toBe('chat')
|
||||
expect(registry.getVariantMode('azure-responses')).toBe('responses')
|
||||
expect(registry.getVariantMode('google-chat')).toBe('chat')
|
||||
})
|
||||
|
||||
it('should return null for base provider IDs', () => {
|
||||
expect(registry.getVariantMode('openai')).toBeNull()
|
||||
expect(registry.getVariantMode('azure')).toBeNull()
|
||||
expect(registry.getVariantMode('google')).toBeNull()
|
||||
expect(registry.getVariantMode('xai')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null for aliases', () => {
|
||||
expect(registry.getVariantMode('oai')).toBeNull()
|
||||
expect(registry.getVariantMode('gemini')).toBeNull()
|
||||
expect(registry.getVariantMode('azure-openai')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null for unknown IDs', () => {
|
||||
expect(registry.getVariantMode('unknown')).toBeNull()
|
||||
expect(registry.getVariantMode('non-existent-variant')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getVariants', () => {
|
||||
it('should return variant IDs for providers with variants', () => {
|
||||
expect(registry.getVariants('openai')).toEqual(['openai-chat'])
|
||||
expect(registry.getVariants('azure')).toEqual(['azure-responses'])
|
||||
expect(registry.getVariants('google')).toEqual(['google-chat'])
|
||||
})
|
||||
|
||||
it('should return empty array for providers without variants', () => {
|
||||
expect(registry.getVariants('xai')).toEqual([])
|
||||
})
|
||||
|
||||
it('should support aliases in base ID', () => {
|
||||
expect(registry.getVariants('oai')).toEqual(['openai-chat'])
|
||||
expect(registry.getVariants('gemini')).toEqual(['google-chat'])
|
||||
expect(registry.getVariants('azure-openai')).toEqual(['azure-responses'])
|
||||
})
|
||||
|
||||
it('should return empty array for unknown IDs', () => {
|
||||
expect(registry.getVariants('unknown')).toEqual([])
|
||||
expect(registry.getVariants('non-existent')).toEqual([])
|
||||
})
|
||||
|
||||
it('should return all variants for providers with multiple variants', () => {
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'multi-variant',
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Chat',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
},
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Responses',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
},
|
||||
{
|
||||
suffix: 'completions',
|
||||
name: 'Completions',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
)
|
||||
|
||||
const variants = registry.getVariants('multi-variant')
|
||||
expect(variants).toHaveLength(3)
|
||||
expect(variants).toContain('multi-variant-chat')
|
||||
expect(variants).toContain('multi-variant-responses')
|
||||
expect(variants).toContain('multi-variant-completions')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Integration: All methods working together', () => {
|
||||
it('should provide consistent information about a variant', () => {
|
||||
const variantId = 'openai-chat'
|
||||
// isVariant should confirm it's a variant
|
||||
expect(registry.isVariant(variantId)).toBe(true)
|
||||
|
||||
// getBaseProviderId should extract base ID
|
||||
expect(registry.getBaseProviderId(variantId)).toBe('openai')
|
||||
|
||||
// getVariantMode should extract mode
|
||||
expect(registry.getVariantMode(variantId)).toBe('chat')
|
||||
// getVariants should include this variant when querying base ID
|
||||
const baseId = registry.getBaseProviderId(variantId)!
|
||||
expect(registry.getVariants(baseId)).toContain(variantId)
|
||||
})
|
||||
|
||||
it('should provide consistent information about a base provider', () => {
|
||||
const baseId = 'openai'
|
||||
|
||||
// isVariant should return false
|
||||
expect(registry.isVariant(baseId)).toBe(false)
|
||||
|
||||
// getBaseProviderId should return itself
|
||||
expect(registry.getBaseProviderId(baseId)).toBe(baseId)
|
||||
|
||||
// getVariantMode should return null
|
||||
expect(registry.getVariantMode(baseId)).toBeNull()
|
||||
|
||||
// getVariants should return its variants
|
||||
expect(registry.getVariants(baseId)).toEqual(['openai-chat'])
|
||||
})
|
||||
|
||||
it('should provide consistent information about an alias', () => {
|
||||
const aliasId = 'oai'
|
||||
|
||||
// isVariant should return false
|
||||
expect(registry.isVariant(aliasId)).toBe(false)
|
||||
|
||||
// getBaseProviderId should resolve to base ID
|
||||
expect(registry.getBaseProviderId(aliasId)).toBe('openai')
|
||||
|
||||
// getVariantMode should return null
|
||||
expect(registry.getVariantMode(aliasId)).toBeNull()
|
||||
|
||||
// getVariants should work with alias
|
||||
expect(registry.getVariants(aliasId)).toEqual(['openai-chat'])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getTyped()', () => {
|
||||
it('should return typed extension for registered providers', () => {
|
||||
// Register extensions
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
)
|
||||
|
||||
const ext = registry.getTyped('openai')
|
||||
expect(ext).toBeDefined()
|
||||
expect(ext?.config.name).toBe('openai')
|
||||
})
|
||||
|
||||
it('should return undefined for unregistered providers', () => {
|
||||
const ext = registry.getTyped('unknown' as any)
|
||||
expect(ext).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should preserve type information (compile-time check)', () => {
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
)
|
||||
|
||||
// This test primarily validates compile-time type inference
|
||||
// Runtime behavior is the same as get()
|
||||
const ext = registry.getTyped('openai')
|
||||
expect(ext).toBeDefined()
|
||||
|
||||
// Type should be inferred as ProviderExtension<OpenAIProviderSettings, any, any>
|
||||
// but we can't test types at runtime, only compile-time
|
||||
})
|
||||
|
||||
it('should work with aliases', () => {
|
||||
registry.register(
|
||||
new ProviderExtension({
|
||||
name: 'openai',
|
||||
aliases: ['oai'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
)
|
||||
|
||||
const ext = registry.getTyped('oai' as any)
|
||||
expect(ext).toBeDefined()
|
||||
expect(ext?.config.name).toBe('openai')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,526 +0,0 @@
|
||||
/**
|
||||
* HubProvider Comprehensive Tests
|
||||
* Tests hub provider routing, model resolution, and error handling
|
||||
* Covers multi-provider routing with namespaced model IDs
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { customProvider, wrapProvider } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
|
||||
import { createHubProvider, type HubProviderConfig, HubProviderError } from '../HubProvider'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../RegistryManagement'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
getProvider: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
vi.mock('ai', () => ({
|
||||
customProvider: vi.fn((config) => config.fallbackProvider),
|
||||
wrapProvider: vi.fn((config) => config.provider),
|
||||
jsonSchema: vi.fn((schema) => schema)
|
||||
}))
|
||||
|
||||
describe('HubProvider', () => {
|
||||
let mockOpenAIProvider: ProviderV3
|
||||
let mockAnthropicProvider: ProviderV3
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockEmbeddingModel: EmbeddingModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create mock models using global utilities
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
mockEmbeddingModel = createMockEmbeddingModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-embedding'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-image'
|
||||
})
|
||||
|
||||
// Create mock providers
|
||||
mockOpenAIProvider = {
|
||||
specificationVersion: 'v3',
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel)
|
||||
} as ProviderV3
|
||||
|
||||
mockAnthropicProvider = {
|
||||
specificationVersion: 'v3',
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel)
|
||||
} as ProviderV3
|
||||
|
||||
// Setup default mock implementation
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockImplementation((id) => {
|
||||
if (id === 'openai') return mockOpenAIProvider
|
||||
if (id === 'anthropic') return mockAnthropicProvider
|
||||
return undefined
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Creation', () => {
|
||||
it('should create hub provider with basic config', () => {
|
||||
const config: HubProviderConfig = {
|
||||
hubId: 'test-hub'
|
||||
}
|
||||
|
||||
const provider = createHubProvider(config)
|
||||
|
||||
expect(provider).toBeDefined()
|
||||
expect(customProvider).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should create provider with debug flag', () => {
|
||||
const config: HubProviderConfig = {
|
||||
hubId: 'test-hub',
|
||||
debug: true
|
||||
}
|
||||
|
||||
const provider = createHubProvider(config)
|
||||
|
||||
expect(provider).toBeDefined()
|
||||
})
|
||||
|
||||
it('should return ProviderV3 specification', () => {
|
||||
const config: HubProviderConfig = {
|
||||
hubId: 'aihubmix'
|
||||
}
|
||||
|
||||
const provider = createHubProvider(config)
|
||||
|
||||
expect(provider).toHaveProperty('specificationVersion', 'v3')
|
||||
expect(provider).toHaveProperty('languageModel')
|
||||
expect(provider).toHaveProperty('embeddingModel')
|
||||
expect(provider).toHaveProperty('imageModel')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Model ID Parsing', () => {
|
||||
it('should parse valid hub model ID format', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4`
|
||||
|
||||
const result = provider.languageModel(modelId)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should throw error for invalid model ID format', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const invalidId = 'invalid-id-without-separator'
|
||||
|
||||
expect(() => provider.languageModel(invalidId)).toThrow(HubProviderError)
|
||||
})
|
||||
|
||||
it('should throw error for model ID with multiple separators', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const multiSeparatorId = `provider${DEFAULT_SEPARATOR}extra${DEFAULT_SEPARATOR}model`
|
||||
|
||||
expect(() => provider.languageModel(multiSeparatorId)).toThrow(HubProviderError)
|
||||
})
|
||||
|
||||
it('should throw error for empty model ID', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel('')).toThrow(HubProviderError)
|
||||
})
|
||||
|
||||
it('should throw error for model ID with only separator', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel(DEFAULT_SEPARATOR)).toThrow(HubProviderError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Language Model Resolution', () => {
|
||||
it('should route to correct provider for language model', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should route different providers correctly', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('anthropic')
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledWith('claude-3')
|
||||
})
|
||||
|
||||
it('should wrap provider with wrapProvider', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(wrapProvider).toHaveBeenCalledWith({
|
||||
provider: mockOpenAIProvider,
|
||||
languageModelMiddleware: []
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw HubProviderError if provider not initialized', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
|
||||
expect(() => provider.languageModel(`uninitialized${DEFAULT_SEPARATOR}model`)).toThrow(/not initialized/)
|
||||
})
|
||||
|
||||
it('should include provider ID in error message', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
try {
|
||||
provider.languageModel(`missing${DEFAULT_SEPARATOR}model`)
|
||||
expect.fail('Should have thrown HubProviderError')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(HubProviderError)
|
||||
const hubError = error as HubProviderError
|
||||
expect(hubError.providerId).toBe('missing')
|
||||
expect(hubError.hubId).toBe('test-hub')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Embedding Model Resolution', () => {
|
||||
it('should route to correct provider for embedding model', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}text-embedding-3-small`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('text-embedding-3-small')
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should handle different embedding providers', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
|
||||
provider.embeddingModel(`anthropic${DEFAULT_SEPARATOR}embed-v1`)
|
||||
|
||||
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
|
||||
expect(mockAnthropicProvider.embeddingModel).toHaveBeenCalledWith('embed-v1')
|
||||
})
|
||||
|
||||
it('should throw error for uninitialized embedding provider', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(undefined)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.embeddingModel(`missing${DEFAULT_SEPARATOR}embed`)).toThrow(HubProviderError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Image Model Resolution', () => {
|
||||
it('should route to correct provider for image model', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledWith('openai')
|
||||
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should handle different image providers', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
|
||||
provider.imageModel(`anthropic${DEFAULT_SEPARATOR}image-gen`)
|
||||
|
||||
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
expect(mockAnthropicProvider.imageModel).toHaveBeenCalledWith('image-gen')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Special Model Types', () => {
|
||||
it('should support transcription models', () => {
|
||||
const mockTranscriptionModel = {
|
||||
specificationVersion: 'v3',
|
||||
doTranscribe: vi.fn()
|
||||
}
|
||||
|
||||
const providerWithTranscription = {
|
||||
...mockOpenAIProvider,
|
||||
transcriptionModel: vi.fn().mockReturnValue(mockTranscriptionModel)
|
||||
} as ProviderV3
|
||||
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithTranscription)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper-1`)
|
||||
|
||||
expect(providerWithTranscription.transcriptionModel).toHaveBeenCalledWith('whisper-1')
|
||||
expect(result).toBe(mockTranscriptionModel)
|
||||
})
|
||||
|
||||
it('should throw error if provider does not support transcription', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(HubProviderError)
|
||||
expect(() => provider.transcriptionModel!(`openai${DEFAULT_SEPARATOR}whisper`)).toThrow(
|
||||
/does not support transcription/
|
||||
)
|
||||
})
|
||||
|
||||
it('should support speech models', () => {
|
||||
const mockSpeechModel = {
|
||||
specificationVersion: 'v3',
|
||||
doGenerate: vi.fn()
|
||||
}
|
||||
|
||||
const providerWithSpeech = {
|
||||
...mockOpenAIProvider,
|
||||
speechModel: vi.fn().mockReturnValue(mockSpeechModel)
|
||||
} as ProviderV3
|
||||
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithSpeech)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)
|
||||
|
||||
expect(providerWithSpeech.speechModel).toHaveBeenCalledWith('tts-1')
|
||||
expect(result).toBe(mockSpeechModel)
|
||||
})
|
||||
|
||||
it('should throw error if provider does not support speech', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(HubProviderError)
|
||||
expect(() => provider.speechModel!(`openai${DEFAULT_SEPARATOR}tts-1`)).toThrow(/does not support speech/)
|
||||
})
|
||||
|
||||
it('should support reranking models', () => {
|
||||
const mockRerankingModel = {
|
||||
specificationVersion: 'v3',
|
||||
doRerank: vi.fn()
|
||||
}
|
||||
|
||||
const providerWithReranking = {
|
||||
...mockOpenAIProvider,
|
||||
rerankingModel: vi.fn().mockReturnValue(mockRerankingModel)
|
||||
} as ProviderV3
|
||||
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(providerWithReranking)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank-v1`)
|
||||
|
||||
expect(providerWithReranking.rerankingModel).toHaveBeenCalledWith('rerank-v1')
|
||||
expect(result).toBe(mockRerankingModel)
|
||||
})
|
||||
|
||||
it('should throw error if provider does not support reranking', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(HubProviderError)
|
||||
expect(() => provider.rerankingModel!(`openai${DEFAULT_SEPARATOR}rerank`)).toThrow(/does not support reranking/)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should create HubProviderError with all properties', () => {
|
||||
const originalError = new Error('Original error')
|
||||
const error = new HubProviderError('Test message', 'test-hub', 'test-provider', originalError)
|
||||
|
||||
expect(error.message).toBe('Test message')
|
||||
expect(error.hubId).toBe('test-hub')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.originalError).toBe(originalError)
|
||||
expect(error.name).toBe('HubProviderError')
|
||||
})
|
||||
|
||||
it('should create HubProviderError without optional parameters', () => {
|
||||
const error = new HubProviderError('Test message', 'test-hub')
|
||||
|
||||
expect(error.message).toBe('Test message')
|
||||
expect(error.hubId).toBe('test-hub')
|
||||
expect(error.providerId).toBeUndefined()
|
||||
expect(error.originalError).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should wrap provider errors in HubProviderError', () => {
|
||||
const providerError = new Error('Provider failed')
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockImplementation(() => {
|
||||
throw providerError
|
||||
})
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
try {
|
||||
provider.languageModel(`failing${DEFAULT_SEPARATOR}model`)
|
||||
expect.fail('Should have thrown HubProviderError')
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(HubProviderError)
|
||||
const hubError = error as HubProviderError
|
||||
expect(hubError.originalError).toBe(providerError)
|
||||
expect(hubError.message).toContain('Failed to get provider')
|
||||
}
|
||||
})
|
||||
|
||||
it('should handle null provider from registry', () => {
|
||||
vi.mocked(globalRegistryManagement.getProvider).mockReturnValue(null as any)
|
||||
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
expect(() => provider.languageModel(`null-provider${DEFAULT_SEPARATOR}model`)).toThrow(HubProviderError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multi-Provider Scenarios', () => {
|
||||
it('should handle sequential calls to different providers', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.languageModel(`anthropic${DEFAULT_SEPARATOR}claude-3`)
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledTimes(2)
|
||||
expect(mockAnthropicProvider.languageModel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle mixed model types from same provider', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada-002`)
|
||||
provider.imageModel(`openai${DEFAULT_SEPARATOR}dall-e-3`)
|
||||
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(3)
|
||||
expect(mockOpenAIProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockOpenAIProvider.embeddingModel).toHaveBeenCalledWith('ada-002')
|
||||
expect(mockOpenAIProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
})
|
||||
|
||||
it('should cache provider lookups', () => {
|
||||
const config: HubProviderConfig = { hubId: 'aihubmix' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-3.5`)
|
||||
|
||||
// Should call getProvider twice (once per model call)
|
||||
expect(globalRegistryManagement.getProvider).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Wrapping', () => {
|
||||
it('should wrap all providers with empty middleware', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(wrapProvider).toHaveBeenCalledWith({
|
||||
provider: mockOpenAIProvider,
|
||||
languageModelMiddleware: []
|
||||
})
|
||||
})
|
||||
|
||||
it('should wrap providers for all model types', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
|
||||
provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
|
||||
|
||||
expect(wrapProvider).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should return properly typed LanguageModelV3', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.languageModel(`openai${DEFAULT_SEPARATOR}gpt-4`)
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
expect(result).toHaveProperty('doStream')
|
||||
})
|
||||
|
||||
it('should return properly typed EmbeddingModelV3', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.embeddingModel(`openai${DEFAULT_SEPARATOR}ada`)
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doEmbed')
|
||||
})
|
||||
|
||||
it('should return properly typed ImageModelV3', () => {
|
||||
const config: HubProviderConfig = { hubId: 'test-hub' }
|
||||
const provider = createHubProvider(config) as ProviderV3
|
||||
|
||||
const result = provider.imageModel(`openai${DEFAULT_SEPARATOR}dalle`)
|
||||
|
||||
expect(result.specificationVersion).toBe('v3')
|
||||
expect(result).toHaveProperty('doGenerate')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,805 @@
|
||||
/**
|
||||
* ProviderExtension 单元测试
|
||||
*/
|
||||
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createMockProviderV3 } from '@test-utils'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ProviderExtension } from '../core/ProviderExtension'
|
||||
|
||||
describe('ProviderExtension', () => {
|
||||
describe('Static create() Method', () => {
|
||||
it('should create extension with config object', () => {
|
||||
const extension = ProviderExtension.create({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension).toBeInstanceOf(ProviderExtension)
|
||||
expect(extension.config.name).toBe('test-provider')
|
||||
})
|
||||
|
||||
it('should create extension with config function', () => {
|
||||
const configFn = vi.fn(() => ({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3,
|
||||
defaultOptions: { apiKey: 'test-key' }
|
||||
}))
|
||||
|
||||
const extension = ProviderExtension.create(configFn)
|
||||
|
||||
expect(configFn).toHaveBeenCalledOnce()
|
||||
expect(extension).toBeInstanceOf(ProviderExtension)
|
||||
expect(extension.config.name).toBe('test-provider')
|
||||
expect(extension.config.defaultOptions).toEqual({ apiKey: 'test-key' })
|
||||
})
|
||||
|
||||
it('should support type inference with generics', () => {
|
||||
interface TestSettings {
|
||||
apiKey: string
|
||||
baseURL?: string
|
||||
name: string
|
||||
}
|
||||
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any, // Type assertion needed as mock has different signature
|
||||
defaultOptions: {
|
||||
apiKey: 'test-key'
|
||||
}
|
||||
})
|
||||
|
||||
expect(extension.config.name).toBe('test-provider')
|
||||
})
|
||||
|
||||
it('should allow delayed config resolution with function', () => {
|
||||
let envVariable = 'initial-key'
|
||||
|
||||
const extension = ProviderExtension.create(() => ({
|
||||
name: 'dynamic-provider',
|
||||
create: createMockProviderV3,
|
||||
defaultOptions: {
|
||||
apiKey: envVariable // Captured at creation time
|
||||
}
|
||||
}))
|
||||
|
||||
expect(extension.config.defaultOptions).toEqual({ apiKey: 'initial-key' })
|
||||
|
||||
// Changing variable doesn't affect already created extension
|
||||
envVariable = 'changed-key'
|
||||
expect(extension.config.defaultOptions).toEqual({ apiKey: 'initial-key' })
|
||||
})
|
||||
|
||||
it('should validate config from function same as from object', async () => {
|
||||
expect(() => {
|
||||
ProviderExtension.create(() => ({
|
||||
name: '', // Invalid
|
||||
create: createMockProviderV3
|
||||
}))
|
||||
}).toThrow('name is required')
|
||||
|
||||
// Note: create/import validation happens at runtime in createProvider(), not in constructor
|
||||
// Extension can be created without create/import, but createProvider() will throw
|
||||
const extension = ProviderExtension.create(
|
||||
() =>
|
||||
({
|
||||
name: 'test-provider'
|
||||
// Missing create
|
||||
}) as any
|
||||
)
|
||||
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Constructor Validation', () => {
|
||||
it('should throw error if name is missing', () => {
|
||||
expect(() => {
|
||||
new ProviderExtension({
|
||||
name: '',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
}).toThrow('name is required')
|
||||
})
|
||||
|
||||
it('should throw error at runtime if neither create nor import is provided', async () => {
|
||||
// Constructor doesn't validate create/import - validation happens at runtime
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider'
|
||||
} as any)
|
||||
|
||||
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
|
||||
})
|
||||
|
||||
it('should throw error at runtime if import is provided without creatorFunctionName', async () => {
|
||||
// Constructor doesn't validate creatorFunctionName - validation happens at runtime
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
import: async () => ({})
|
||||
} as any)
|
||||
|
||||
await expect(extension.createProvider()).rejects.toThrow('cannot create provider')
|
||||
})
|
||||
|
||||
it('should create extension with valid config', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.config.name).toBe('test-provider')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Configure Method', () => {
|
||||
it('should return new instance with merged settings', () => {
|
||||
const original = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any,
|
||||
defaultOptions: { apiKey: 'original-key' }
|
||||
})
|
||||
|
||||
const configured = original.configure({ baseURL: 'https://api.test.com' })
|
||||
|
||||
// 原实例不变
|
||||
expect(original.config.defaultOptions).toEqual({ apiKey: 'original-key' })
|
||||
|
||||
// 新实例合并配置
|
||||
expect(configured.config.defaultOptions).toEqual({
|
||||
apiKey: 'original-key',
|
||||
baseURL: 'https://api.test.com'
|
||||
})
|
||||
|
||||
// 是新实例
|
||||
expect(configured).not.toBe(original)
|
||||
})
|
||||
|
||||
it('should override existing options', () => {
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any,
|
||||
defaultOptions: { apiKey: 'old-key', timeout: 5000 }
|
||||
})
|
||||
|
||||
const configured = extension.configure({ apiKey: 'new-key' })
|
||||
|
||||
expect(configured.config.defaultOptions).toEqual({
|
||||
apiKey: 'new-key',
|
||||
timeout: 5000
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getProviderIds', () => {
|
||||
it('should return only main ID when no aliases or variants', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.getProviderIds()).toEqual(['openai'])
|
||||
})
|
||||
|
||||
it('should include aliases', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openrouter',
|
||||
aliases: ['or', 'open-router'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.getProviderIds()).toEqual(['openrouter', 'or', 'open-router'])
|
||||
})
|
||||
|
||||
it('should include variant IDs', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'OpenAI Chat',
|
||||
transform: (provider) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
expect(extension.getProviderIds()).toEqual(['openai', 'openai-chat'])
|
||||
})
|
||||
|
||||
it('should include both aliases and variant IDs', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'azure',
|
||||
aliases: ['az'],
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Azure Chat',
|
||||
transform: (provider) => provider
|
||||
},
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Azure Responses',
|
||||
transform: (provider) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
expect(extension.getProviderIds()).toEqual(['azure', 'az', 'azure-chat', 'azure-responses'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasProviderId', () => {
|
||||
it('should return true for main ID', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.hasProviderId('openai')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for alias', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openrouter',
|
||||
aliases: ['or'],
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.hasProviderId('or')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for variant ID', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Chat',
|
||||
transform: (provider) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
expect(extension.hasProviderId('openai-chat')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for non-existent ID', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'openai',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.hasProviderId('anthropic')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getVariant', () => {
|
||||
it('should return variant by suffix', () => {
|
||||
const chatVariant = {
|
||||
suffix: 'chat',
|
||||
name: 'Chat Mode',
|
||||
transform: (provider: ProviderV3) => provider
|
||||
}
|
||||
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test',
|
||||
create: createMockProviderV3,
|
||||
variants: [chatVariant]
|
||||
})
|
||||
|
||||
expect(extension.getVariant('chat')).toEqual(chatVariant)
|
||||
})
|
||||
|
||||
it('should return undefined for non-existent variant', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test',
|
||||
create: createMockProviderV3,
|
||||
variants: []
|
||||
})
|
||||
|
||||
expect(extension.getVariant('chat')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined when no variants configured', () => {
|
||||
const extension = new ProviderExtension({
|
||||
name: 'test',
|
||||
create: createMockProviderV3
|
||||
})
|
||||
|
||||
expect(extension.getVariant('chat')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
interface TestSettings {
|
||||
apiKey: string
|
||||
baseURL?: string
|
||||
timeout?: number
|
||||
}
|
||||
|
||||
it('should maintain type safety with generics', () => {
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'typed-provider',
|
||||
create: ((settings: any) => {
|
||||
// TypeScript should infer settings as TestSettings
|
||||
expect(settings?.apiKey).toBeDefined()
|
||||
return createMockProviderV3()
|
||||
}) as any,
|
||||
defaultOptions: {
|
||||
apiKey: 'test-key',
|
||||
timeout: 5000
|
||||
}
|
||||
})
|
||||
|
||||
const configured = extension.configure({
|
||||
baseURL: 'https://api.test.com'
|
||||
// TypeScript should catch invalid properties here
|
||||
})
|
||||
|
||||
expect(configured.config.defaultOptions?.baseURL).toBe('https://api.test.com')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Options Getter', () => {
|
||||
it('should return readonly frozen options', () => {
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any,
|
||||
defaultOptions: { apiKey: 'test-key', timeout: 5000 }
|
||||
})
|
||||
|
||||
const options = extension.options
|
||||
|
||||
expect(options).toEqual({ apiKey: 'test-key', timeout: 5000 })
|
||||
expect(Object.isFrozen(options)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Deep Merge in Configure', () => {
|
||||
it('should deep merge nested objects', () => {
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any,
|
||||
defaultOptions: {
|
||||
apiKey: 'key1',
|
||||
headers: {
|
||||
'X-Custom-Header': 'value1',
|
||||
Authorization: 'Bearer token1'
|
||||
},
|
||||
retry: {
|
||||
maxAttempts: 3,
|
||||
backoff: 1000
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const configured = extension.configure({
|
||||
headers: {
|
||||
Authorization: 'Bearer new-token'
|
||||
},
|
||||
retry: {
|
||||
maxAttempts: 5
|
||||
}
|
||||
})
|
||||
|
||||
expect(configured.config.defaultOptions).toEqual({
|
||||
apiKey: 'key1',
|
||||
headers: {
|
||||
'X-Custom-Header': 'value1',
|
||||
Authorization: 'Bearer new-token'
|
||||
},
|
||||
retry: {
|
||||
maxAttempts: 5,
|
||||
backoff: 1000
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should not mutate original extension', () => {
|
||||
const original = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any,
|
||||
defaultOptions: {
|
||||
nested: { value: 'original' }
|
||||
}
|
||||
})
|
||||
|
||||
const configured = original.configure({
|
||||
nested: { value: 'modified' }
|
||||
})
|
||||
|
||||
expect(original.config.defaultOptions).toEqual({
|
||||
nested: { value: 'original' }
|
||||
})
|
||||
expect(configured.config.defaultOptions).toEqual({
|
||||
nested: { value: 'modified' }
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Instance Caching (Phase 1)', () => {
|
||||
interface TestSettings {
|
||||
apiKey: string
|
||||
baseURL?: string
|
||||
}
|
||||
|
||||
it('should cache and reuse instance with same settings', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const settings = { apiKey: 'test-key', baseURL: 'https://api.test.com' }
|
||||
|
||||
const instance1 = await extension.createProvider(settings)
|
||||
const instance2 = await extension.createProvider(settings)
|
||||
|
||||
// Should return the same instance
|
||||
expect(instance1).toBe(instance2)
|
||||
// Should only call create once
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should create new instance with different settings', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const settings1 = { apiKey: 'key1' }
|
||||
const settings2 = { apiKey: 'key2' }
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
// Should return different instances
|
||||
expect(instance1).not.toBe(instance2)
|
||||
// Should call create twice
|
||||
expect(createFn).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should handle undefined settings correctly', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const instance1 = await extension.createProvider()
|
||||
const instance2 = await extension.createProvider()
|
||||
const instance3 = await extension.createProvider(undefined)
|
||||
|
||||
// All should be the same instance (undefined settings)
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(instance1).toBe(instance3)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should compute stable hash for same settings in different order', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
// Same settings but different property order
|
||||
const settings1 = { apiKey: 'key', baseURL: 'url', timeout: 5000 }
|
||||
const settings2 = { timeout: 5000, apiKey: 'key', baseURL: 'url' }
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
// Should recognize as same settings
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should merge with default options before hashing', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any,
|
||||
defaultOptions: {
|
||||
apiKey: 'default-key',
|
||||
timeout: 5000
|
||||
}
|
||||
})
|
||||
|
||||
const instance1 = await extension.createProvider({ baseURL: 'url' })
|
||||
const instance2 = await extension.createProvider({ baseURL: 'url', apiKey: 'default-key', timeout: 5000 })
|
||||
|
||||
// Should be same after merging with defaults
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle nested objects in settings', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const settings1 = {
|
||||
apiKey: 'key',
|
||||
headers: { Authorization: 'Bearer token', 'X-Custom': 'value' }
|
||||
}
|
||||
const settings2 = {
|
||||
apiKey: 'key',
|
||||
headers: { 'X-Custom': 'value', Authorization: 'Bearer token' }
|
||||
}
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
// Should recognize as same (order doesn't matter)
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should support variant suffix parameter', async () => {
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'test-provider',
|
||||
create: createMockProviderV3 as any,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'Test Chat',
|
||||
transform: (provider) => provider
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const settings = { apiKey: 'test-key' }
|
||||
|
||||
// Should work when providing a valid variant suffix
|
||||
await expect(extension.createProvider(settings, 'chat')).resolves.toBeDefined()
|
||||
|
||||
// Should throw for unknown variant suffix
|
||||
await expect(extension.createProvider(settings, 'unknown')).rejects.toThrow('variant "unknown" not found')
|
||||
})
|
||||
|
||||
it('should support dynamic import providers', async () => {
|
||||
const mockModule = {
|
||||
createProvider: vi.fn(createMockProviderV3)
|
||||
}
|
||||
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'lazy-provider',
|
||||
import: async () => mockModule,
|
||||
creatorFunctionName: 'createProvider'
|
||||
})
|
||||
|
||||
const instance1 = await extension.createProvider({ apiKey: 'key' })
|
||||
const instance2 = await extension.createProvider({ apiKey: 'key' })
|
||||
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(mockModule.createProvider).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should throw error if creatorFunctionName not found in module', async () => {
|
||||
const mockModule = {
|
||||
wrongName: vi.fn(createMockProviderV3)
|
||||
}
|
||||
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'lazy-provider',
|
||||
import: async () => mockModule,
|
||||
creatorFunctionName: 'createProvider'
|
||||
})
|
||||
|
||||
await expect(extension.createProvider({ apiKey: 'key' })).rejects.toThrow(
|
||||
'creatorFunctionName "createProvider" not found'
|
||||
)
|
||||
})
|
||||
|
||||
it('should deduplicate concurrent requests with same settings', async () => {
|
||||
const createFn = vi.fn(async () => {
|
||||
// Simulate async delay
|
||||
await new Promise((resolve) => setTimeout(resolve, 10))
|
||||
return createMockProviderV3()
|
||||
})
|
||||
|
||||
const extension = new ProviderExtension<TestSettings>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const settings = { apiKey: 'test-key' }
|
||||
|
||||
// Fire multiple concurrent requests
|
||||
const [instance1, instance2, instance3] = await Promise.all([
|
||||
extension.createProvider(settings),
|
||||
extension.createProvider(settings),
|
||||
extension.createProvider(settings)
|
||||
])
|
||||
|
||||
// All concurrent requests should return the same instance
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(instance2).toBe(instance3)
|
||||
// Creator should only be called once
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Verify subsequent sequential calls also use cache
|
||||
const instance4 = await extension.createProvider(settings)
|
||||
expect(instance4).toBe(instance1)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle arrays in settings', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const settings1 = { apiKey: 'key', tags: ['a', 'b', 'c'] }
|
||||
const settings2 = { apiKey: 'key', tags: ['a', 'b', 'c'] }
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should differentiate settings with different array values', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const settings1 = { apiKey: 'key', tags: ['a', 'b'] }
|
||||
const settings2 = { apiKey: 'key', tags: ['a', 'c'] }
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
expect(instance1).not.toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Cache Key Correctness (no hash collisions)', () => {
|
||||
it('should differentiate settings with different API keys', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const instance1 = await extension.createProvider({ apiKey: 'sk-key-A', baseURL: 'https://api.example.com' })
|
||||
const instance2 = await extension.createProvider({ apiKey: 'sk-key-B', baseURL: 'https://api.example.com' })
|
||||
|
||||
expect(instance1).not.toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should differentiate settings with different base URLs', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const instance1 = await extension.createProvider({ apiKey: 'same-key', baseURL: 'https://api-a.example.com' })
|
||||
const instance2 = await extension.createProvider({ apiKey: 'same-key', baseURL: 'https://api-b.example.com' })
|
||||
|
||||
expect(instance1).not.toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should treat structurally identical settings as the same regardless of construction', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
// Construct settings in completely different ways
|
||||
const base = { apiKey: 'key', baseURL: 'url' }
|
||||
const settings1 = { ...base, headers: { Authorization: 'Bearer tok' } }
|
||||
const settings2 = JSON.parse(JSON.stringify(settings1))
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should differentiate same-base-provider with different variant suffixes', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'azure',
|
||||
create: createFn as any,
|
||||
variants: [
|
||||
{ suffix: 'chat', name: 'Chat', transform: (p: any) => ({ ...p, _variant: 'chat' }) },
|
||||
{ suffix: 'responses', name: 'Responses', transform: (p: any) => ({ ...p, _variant: 'responses' }) }
|
||||
]
|
||||
})
|
||||
|
||||
const settings = { apiKey: 'key' }
|
||||
|
||||
const chatInstance = await extension.createProvider(settings, 'chat')
|
||||
const responsesInstance = await extension.createProvider(settings, 'responses')
|
||||
const baseInstance = await extension.createProvider(settings)
|
||||
|
||||
// Variant instances should be different from each other
|
||||
expect(chatInstance).not.toBe(responsesInstance)
|
||||
// Base provider is cached when first variant is created, so baseInstance
|
||||
// is the same object as the unwrapped base (reused across variants)
|
||||
expect(chatInstance).not.toBe(baseInstance)
|
||||
expect(responsesInstance).not.toBe(baseInstance)
|
||||
// createFn called once for 'chat' variant (also caches base), once for 'responses' (reuses cached base)
|
||||
// base provider request reuses the cached instance from the first variant creation
|
||||
expect(createFn).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should handle settings with functions by treating them uniformly', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
const fetchFn = () => Promise.resolve(new Response())
|
||||
const settings1 = { apiKey: 'key', fetch: fetchFn }
|
||||
const settings2 = { apiKey: 'key', fetch: fetchFn }
|
||||
|
||||
const instance1 = await extension.createProvider(settings1)
|
||||
const instance2 = await extension.createProvider(settings2)
|
||||
|
||||
// Same function reference → same serialization → same cache hit
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should distinguish settings with null vs missing keys', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
// null/undefined serialize identically via stableStringify, so same cache key
|
||||
const instance1 = await extension.createProvider({ apiKey: 'key', extra: null })
|
||||
const instance2 = await extension.createProvider({ apiKey: 'key', extra: null })
|
||||
|
||||
expect(instance1).toBe(instance2)
|
||||
expect(createFn).toHaveBeenCalledTimes(1)
|
||||
|
||||
// But a truly different value should create a new instance
|
||||
const instance3 = await extension.createProvider({ apiKey: 'key', extra: 'value' })
|
||||
expect(instance3).not.toBe(instance1)
|
||||
expect(createFn).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should not collide on similarly-structured but different settings', async () => {
|
||||
const createFn = vi.fn(createMockProviderV3)
|
||||
const extension = new ProviderExtension<any>({
|
||||
name: 'test-provider',
|
||||
create: createFn as any
|
||||
})
|
||||
|
||||
// These have similar structure but different values - old DJB2 hash could collide
|
||||
const settingsA = { apiKey: 'aaaa1111', baseURL: 'https://host-a.com', timeout: 3000 }
|
||||
const settingsB = { apiKey: 'bbbb2222', baseURL: 'https://host-b.com', timeout: 3000 }
|
||||
const settingsC = { apiKey: 'cccc3333', baseURL: 'https://host-c.com', timeout: 3000 }
|
||||
|
||||
const instanceA = await extension.createProvider(settingsA)
|
||||
const instanceB = await extension.createProvider(settingsB)
|
||||
const instanceC = await extension.createProvider(settingsC)
|
||||
|
||||
expect(instanceA).not.toBe(instanceB)
|
||||
expect(instanceB).not.toBe(instanceC)
|
||||
expect(instanceA).not.toBe(instanceC)
|
||||
expect(createFn).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,562 +0,0 @@
|
||||
/**
|
||||
* RegistryManagement Comprehensive Tests
|
||||
* Tests provider registry management, model resolution, and alias handling
|
||||
* Covers registration, retrieval, and cleanup operations
|
||||
*/
|
||||
|
||||
import type { EmbeddingModelV3, ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockEmbeddingModel, createMockImageModel, createMockLanguageModel } from '../../../__tests__'
|
||||
import { DEFAULT_SEPARATOR, RegistryManagement } from '../RegistryManagement'
|
||||
|
||||
// Mock AI SDK
|
||||
vi.mock('ai', () => ({
|
||||
createProviderRegistry: vi.fn(),
|
||||
jsonSchema: vi.fn((schema) => schema)
|
||||
}))
|
||||
|
||||
describe('RegistryManagement', () => {
|
||||
let registry: RegistryManagement
|
||||
let mockProvider: ProviderV3
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockEmbeddingModel: EmbeddingModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create mock models using global utilities
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
mockEmbeddingModel = createMockEmbeddingModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-embedding'
|
||||
})
|
||||
|
||||
mockImageModel = createMockImageModel({
|
||||
provider: 'test',
|
||||
modelId: 'test-image'
|
||||
})
|
||||
|
||||
// Create mock provider
|
||||
mockProvider = {
|
||||
specificationVersion: 'v3',
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel),
|
||||
transcriptionModel: vi.fn(),
|
||||
speechModel: vi.fn()
|
||||
} as ProviderV3
|
||||
|
||||
// Setup mock registry
|
||||
const mockRegistry = {
|
||||
languageModel: vi.fn().mockReturnValue(mockLanguageModel),
|
||||
embeddingModel: vi.fn().mockReturnValue(mockEmbeddingModel),
|
||||
imageModel: vi.fn().mockReturnValue(mockImageModel),
|
||||
transcriptionModel: vi.fn(),
|
||||
speechModel: vi.fn()
|
||||
}
|
||||
|
||||
vi.mocked(createProviderRegistry).mockReturnValue(mockRegistry as any)
|
||||
|
||||
registry = new RegistryManagement()
|
||||
})
|
||||
|
||||
describe('Constructor and Initialization', () => {
|
||||
it('should create registry with default separator', () => {
|
||||
const reg = new RegistryManagement()
|
||||
|
||||
expect(reg).toBeInstanceOf(RegistryManagement)
|
||||
expect(reg.hasProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('should create registry with custom separator', () => {
|
||||
const customSeparator = ':'
|
||||
const reg = new RegistryManagement({ separator: customSeparator })
|
||||
|
||||
expect(reg).toBeInstanceOf(RegistryManagement)
|
||||
})
|
||||
|
||||
it('should start with empty provider list', () => {
|
||||
expect(registry.getRegisteredProviders()).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Registration', () => {
|
||||
it('should register a provider', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.hasProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('should register multiple providers', () => {
|
||||
const provider2 = { ...mockProvider }
|
||||
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.registerProvider('anthropic', provider2)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.getProvider('anthropic')).toBe(provider2)
|
||||
})
|
||||
|
||||
it('should return this for chaining', () => {
|
||||
const result = registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
|
||||
it('should rebuild registry after registration', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
openai: mockProvider
|
||||
}),
|
||||
{ separator: DEFAULT_SEPARATOR }
|
||||
)
|
||||
})
|
||||
|
||||
it('should register provider with aliases', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.getProvider('gpt')).toBe(mockProvider)
|
||||
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should track aliases separately', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
|
||||
expect(registry.isAlias('gpt')).toBe(true)
|
||||
expect(registry.isAlias('openai')).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle multiple aliases for same provider', () => {
|
||||
const aliases = ['alias1', 'alias2', 'alias3']
|
||||
registry.registerProvider('provider', mockProvider, aliases)
|
||||
|
||||
aliases.forEach((alias) => {
|
||||
expect(registry.getProvider(alias)).toBe(mockProvider)
|
||||
expect(registry.isAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Bulk Registration', () => {
|
||||
it('should register multiple providers at once', () => {
|
||||
const providers = {
|
||||
openai: mockProvider,
|
||||
anthropic: { ...mockProvider },
|
||||
google: { ...mockProvider }
|
||||
}
|
||||
|
||||
registry.registerProviders(providers)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(providers.openai)
|
||||
expect(registry.getProvider('anthropic')).toBe(providers.anthropic)
|
||||
expect(registry.getProvider('google')).toBe(providers.google)
|
||||
})
|
||||
|
||||
it('should return this for chaining', () => {
|
||||
const result = registry.registerProviders({ openai: mockProvider })
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Retrieval', () => {
|
||||
beforeEach(() => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
})
|
||||
|
||||
it('should retrieve registered provider', () => {
|
||||
const provider = registry.getProvider('openai')
|
||||
|
||||
expect(provider).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should return undefined for unregistered provider', () => {
|
||||
const provider = registry.getProvider('nonexistent')
|
||||
|
||||
expect(provider).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should retrieve provider by alias', () => {
|
||||
registry.registerProvider('anthropic', mockProvider, ['claude'])
|
||||
|
||||
const provider = registry.getProvider('claude')
|
||||
|
||||
expect(provider).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should get list of all registered providers', () => {
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
registry.registerProvider('google', mockProvider, ['gemini'])
|
||||
|
||||
const providers = registry.getRegisteredProviders()
|
||||
|
||||
expect(providers).toContain('openai')
|
||||
expect(providers).toContain('anthropic')
|
||||
expect(providers).toContain('google')
|
||||
expect(providers).toContain('gemini') // Aliases included
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider Unregistration', () => {
|
||||
it('should unregister provider', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(registry.getProvider('openai')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should unregister provider with all its aliases', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(registry.getProvider('openai')).toBeUndefined()
|
||||
expect(registry.getProvider('gpt')).toBeUndefined()
|
||||
expect(registry.getProvider('chatgpt')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should unregister only alias when alias is removed', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
registry.unregisterProvider('gpt')
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(mockProvider)
|
||||
expect(registry.getProvider('gpt')).toBeUndefined()
|
||||
expect(registry.getProvider('chatgpt')).toBe(mockProvider)
|
||||
})
|
||||
|
||||
it('should handle unregistering non-existent provider', () => {
|
||||
expect(() => registry.unregisterProvider('nonexistent')).not.toThrow()
|
||||
})
|
||||
|
||||
it('should return this for chaining', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const result = registry.unregisterProvider('openai')
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
|
||||
it('should rebuild registry after unregistration', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
vi.clearAllMocks()
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
// Should rebuild with empty providers
|
||||
expect(createProviderRegistry).not.toHaveBeenCalled() // No rebuild when empty
|
||||
})
|
||||
})
|
||||
|
||||
describe('Model Resolution', () => {
|
||||
beforeEach(() => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
})
|
||||
|
||||
it('should resolve language model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}gpt-4` as any
|
||||
|
||||
const result = registry.languageModel(modelId)
|
||||
|
||||
expect(result).toBe(mockLanguageModel)
|
||||
})
|
||||
|
||||
it('should resolve embedding model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}text-embedding-3-small` as any
|
||||
|
||||
const result = registry.embeddingModel(modelId)
|
||||
|
||||
expect(result).toBe(mockEmbeddingModel)
|
||||
})
|
||||
|
||||
it('should resolve image model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}dall-e-3` as any
|
||||
|
||||
const result = registry.imageModel(modelId)
|
||||
|
||||
expect(result).toBe(mockImageModel)
|
||||
})
|
||||
|
||||
it('should resolve transcription model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}whisper-1` as any
|
||||
|
||||
registry.transcriptionModel(modelId)
|
||||
|
||||
// Verify it calls through to the mock registry
|
||||
expect(createProviderRegistry).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should resolve speech model', () => {
|
||||
const modelId = `openai${DEFAULT_SEPARATOR}tts-1` as any
|
||||
|
||||
registry.speechModel(modelId)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should throw error when no providers registered', () => {
|
||||
const emptyRegistry = new RegistryManagement()
|
||||
|
||||
expect(() => emptyRegistry.languageModel('openai|gpt-4' as any)).toThrow('No providers registered')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Alias Management', () => {
|
||||
it('should resolve provider ID from alias', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
|
||||
const realId = registry.resolveProviderId('gpt')
|
||||
|
||||
expect(realId).toBe('openai')
|
||||
})
|
||||
|
||||
it('should return same ID if not an alias', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const realId = registry.resolveProviderId('openai')
|
||||
|
||||
expect(realId).toBe('openai')
|
||||
})
|
||||
|
||||
it('should check if ID is alias', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
|
||||
expect(registry.isAlias('gpt')).toBe(true)
|
||||
expect(registry.isAlias('openai')).toBe(false)
|
||||
})
|
||||
|
||||
it('should get all alias mappings', () => {
|
||||
const provider2 = { ...mockProvider }
|
||||
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
registry.registerProvider('anthropic', provider2, ['claude'])
|
||||
|
||||
const aliases = registry.getAllAliases()
|
||||
|
||||
// Check that all aliases are present
|
||||
expect(aliases['gpt']).toBe('openai')
|
||||
expect(aliases['chatgpt']).toBe('openai')
|
||||
expect(aliases['claude']).toBe('anthropic')
|
||||
})
|
||||
|
||||
it('should return empty object when no aliases', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const aliases = registry.getAllAliases()
|
||||
|
||||
expect(aliases).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry State', () => {
|
||||
it('should check if has providers', () => {
|
||||
expect(registry.hasProviders()).toBe(false)
|
||||
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(registry.hasProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('should clear all providers', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt'])
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
|
||||
registry.clear()
|
||||
|
||||
expect(registry.hasProviders()).toBe(false)
|
||||
expect(registry.getRegisteredProviders()).toEqual([])
|
||||
expect(registry.getAllAliases()).toEqual({})
|
||||
})
|
||||
|
||||
it('should return this after clear for chaining', () => {
|
||||
const result = registry.clear()
|
||||
|
||||
expect(result).toBe(registry)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry Rebuilding', () => {
|
||||
it('should rebuild registry when provider added', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should rebuild registry when provider removed', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
|
||||
vi.clearAllMocks()
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should set registry to null when all providers removed', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(() => registry.languageModel('any|model' as any)).toThrow('No providers registered')
|
||||
})
|
||||
|
||||
it('should rebuild with correct separator', () => {
|
||||
const customRegistry = new RegistryManagement({ separator: ':' })
|
||||
customRegistry.registerProvider('openai', mockProvider)
|
||||
|
||||
expect(createProviderRegistry).toHaveBeenCalledWith(expect.any(Object), { separator: ':' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('Global Registry Instance', () => {
|
||||
it('should have a global instance with default separator', async () => {
|
||||
const module = await import('../RegistryManagement')
|
||||
|
||||
expect(module.globalRegistryManagement).toBeInstanceOf(RegistryManagement)
|
||||
})
|
||||
|
||||
it('should have DEFAULT_SEPARATOR exported', () => {
|
||||
expect(DEFAULT_SEPARATOR).toBe('|')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle registering same provider twice', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const provider2 = { ...mockProvider }
|
||||
registry.registerProvider('openai', provider2)
|
||||
|
||||
expect(registry.getProvider('openai')).toBe(provider2)
|
||||
})
|
||||
|
||||
it('should handle alias conflicts (first wins)', () => {
|
||||
registry.registerProvider('provider1', mockProvider, ['shared-alias'])
|
||||
registry.registerProvider('provider2', mockProvider, ['shared-alias'])
|
||||
|
||||
// First registered alias wins (the implementation doesn't override)
|
||||
expect(registry.resolveProviderId('shared-alias')).toBe('provider1')
|
||||
})
|
||||
|
||||
it('should handle empty alias array', () => {
|
||||
registry.registerProvider('openai', mockProvider, [])
|
||||
|
||||
expect(registry.getAllAliases()).toEqual({})
|
||||
})
|
||||
|
||||
it('should handle null registry operations gracefully', () => {
|
||||
const emptyRegistry = new RegistryManagement()
|
||||
|
||||
expect(() => emptyRegistry.languageModel('test|model' as any)).toThrow('No providers registered')
|
||||
expect(() => emptyRegistry.embeddingModel('test|embed' as any)).toThrow('No providers registered')
|
||||
expect(() => emptyRegistry.imageModel('test|image' as any)).toThrow('No providers registered')
|
||||
})
|
||||
|
||||
it('should handle special characters in provider IDs', () => {
|
||||
const specialIds = ['provider-1', 'provider_2', 'provider.3', 'provider:4']
|
||||
|
||||
specialIds.forEach((id) => {
|
||||
registry.registerProvider(id, mockProvider)
|
||||
expect(registry.getProvider(id)).toBe(mockProvider)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Concurrent Operations', () => {
|
||||
it('should handle concurrent registrations', () => {
|
||||
const promises = [
|
||||
Promise.resolve(registry.registerProvider('provider1', mockProvider)),
|
||||
Promise.resolve(registry.registerProvider('provider2', mockProvider)),
|
||||
Promise.resolve(registry.registerProvider('provider3', mockProvider))
|
||||
]
|
||||
|
||||
return Promise.all(promises).then(() => {
|
||||
expect(registry.getRegisteredProviders()).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle mixed operations', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
registry.registerProvider('anthropic', mockProvider)
|
||||
|
||||
const provider1 = registry.getProvider('openai')
|
||||
registry.unregisterProvider('anthropic')
|
||||
const provider2 = registry.getProvider('openai')
|
||||
|
||||
expect(provider1).toBe(provider2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should enforce model ID format with template literal types', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
// These should be type-safe
|
||||
const validId = 'openai|gpt-4' as `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||
|
||||
expect(() => registry.languageModel(validId)).not.toThrow()
|
||||
})
|
||||
|
||||
it('should return properly typed LanguageModelV3', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const model = registry.languageModel('openai|gpt-4' as any)
|
||||
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
expect(model).toHaveProperty('doGenerate')
|
||||
expect(model).toHaveProperty('doStream')
|
||||
})
|
||||
|
||||
it('should return properly typed EmbeddingModelV3', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const model = registry.embeddingModel('openai|ada-002' as any)
|
||||
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
expect(model).toHaveProperty('doEmbed')
|
||||
})
|
||||
|
||||
it('should return properly typed ImageModelV3', () => {
|
||||
registry.registerProvider('openai', mockProvider)
|
||||
|
||||
const model = registry.imageModel('openai|dall-e-3' as any)
|
||||
|
||||
expect(model.specificationVersion).toBe('v3')
|
||||
expect(model).toHaveProperty('doGenerate')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Memory Management', () => {
|
||||
it('should properly clean up on clear', () => {
|
||||
registry.registerProvider('p1', mockProvider, ['a1'])
|
||||
registry.registerProvider('p2', mockProvider, ['a2'])
|
||||
|
||||
registry.clear()
|
||||
|
||||
expect(registry.getRegisteredProviders()).toHaveLength(0)
|
||||
expect(Object.keys(registry.getAllAliases())).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should properly clean up on unregister', () => {
|
||||
registry.registerProvider('openai', mockProvider, ['gpt', 'chatgpt'])
|
||||
|
||||
registry.unregisterProvider('openai')
|
||||
|
||||
expect(registry.getProvider('openai')).toBeUndefined()
|
||||
expect(registry.isAlias('gpt')).toBe(false)
|
||||
expect(registry.isAlias('chatgpt')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,662 +0,0 @@
|
||||
/**
|
||||
* 测试真正的 AiProviderRegistry 功能
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// 模拟 AI SDK
|
||||
vi.mock('@ai-sdk/openai', () => ({
|
||||
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/anthropic', () => ({
|
||||
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/azure', () => ({
|
||||
createAzure: vi.fn(() => ({
|
||||
name: 'azure-mock',
|
||||
languageModel: vi.fn((modelId: string) => ({ mode: 'default', modelId })),
|
||||
chat: vi.fn((modelId: string) => ({ mode: 'chat', modelId })),
|
||||
responses: vi.fn((modelId: string) => ({ mode: 'responses', modelId })),
|
||||
embeddingModel: vi.fn(),
|
||||
imageModel: vi.fn(),
|
||||
transcriptionModel: vi.fn(),
|
||||
speechModel: vi.fn()
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/deepseek', () => ({
|
||||
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/google', () => ({
|
||||
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/openai-compatible', () => ({
|
||||
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/xai', () => ({
|
||||
createXai: vi.fn(() => ({ name: 'xai-mock' }))
|
||||
}))
|
||||
|
||||
import {
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
hasInitializedProviders,
|
||||
hasProviderConfig,
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
ProviderInitializationError,
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from '../registry'
|
||||
import type { ProviderConfig } from '../schemas'
|
||||
|
||||
describe('Provider Registry 功能测试', () => {
|
||||
beforeEach(() => {
|
||||
// 清理状态
|
||||
cleanup()
|
||||
})
|
||||
|
||||
describe('基础功能', () => {
|
||||
it('能够获取支持的 providers 列表', () => {
|
||||
const providers = getSupportedProviders()
|
||||
expect(Array.isArray(providers)).toBe(true)
|
||||
expect(providers.length).toBeGreaterThan(0)
|
||||
|
||||
// 检查返回的数据结构
|
||||
providers.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
})
|
||||
|
||||
// 包含基础 providers
|
||||
const providerIds = providers.map((p) => p.id)
|
||||
expect(providerIds).toContain('openai')
|
||||
expect(providerIds).toContain('anthropic')
|
||||
expect(providerIds).toContain('google')
|
||||
})
|
||||
|
||||
it('能够获取已初始化的 providers', () => {
|
||||
// 初始状态下没有已初始化的 providers
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('能够访问全局注册管理器', () => {
|
||||
expect(providerRegistry).toBeDefined()
|
||||
expect(typeof providerRegistry.clear).toBe('function')
|
||||
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
|
||||
expect(typeof providerRegistry.hasProviders).toBe('function')
|
||||
})
|
||||
|
||||
it('能够获取语言模型', () => {
|
||||
// 在没有注册 provider 的情况下,这个函数应该会抛出错误
|
||||
expect(() => getLanguageModel('non-existent')).toThrow('No providers registered')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 配置注册', () => {
|
||||
it('能够注册自定义 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(() => ({ name: 'custom' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfig('custom-provider')).toBe(true)
|
||||
expect(getProviderConfig('custom-provider')).toEqual(config)
|
||||
})
|
||||
|
||||
it('能够注册带别名的 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider-with-aliases',
|
||||
name: 'Custom Provider with Aliases',
|
||||
creator: vi.fn(() => ({ name: 'custom-aliased' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'alias-2']
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
|
||||
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
|
||||
})
|
||||
|
||||
it('拒绝无效的配置', () => {
|
||||
// 缺少必要字段
|
||||
const invalidConfig = {
|
||||
id: 'invalid-provider'
|
||||
// 缺少 name, creator 等
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(invalidConfig as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('能够批量注册 provider 配置', () => {
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-1',
|
||||
name: 'Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'provider-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider-2',
|
||||
name: 'Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'provider-2' })),
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2) // 只有前两个成功
|
||||
|
||||
expect(hasProviderConfig('provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('能够获取所有配置和别名信息', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['test-alias']
|
||||
})
|
||||
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(Array.isArray(allConfigs)).toBe(true)
|
||||
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
|
||||
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['test-alias']).toBe('test-provider')
|
||||
expect(isProviderConfigAlias('test-alias')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 创建和注册', () => {
|
||||
it('能够创建 provider 实例', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-provider',
|
||||
name: 'Test Create Provider',
|
||||
creator: vi.fn(() => ({ name: 'test-created' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 创建 provider 实例
|
||||
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
|
||||
expect(provider).toBeDefined()
|
||||
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
})
|
||||
|
||||
it('creates the Azure provider with chat as the default language model', async () => {
|
||||
const provider = await createProvider('azure', { apiKey: 'test' })
|
||||
|
||||
expect(provider.languageModel('gpt-5.4')).toEqual({ mode: 'chat', modelId: 'gpt-5.4' })
|
||||
})
|
||||
|
||||
it('creates the Azure responses provider with responses as the default language model', async () => {
|
||||
const provider = await createProvider('azure-responses', { apiKey: 'test' })
|
||||
|
||||
expect(provider.languageModel('gpt-5.4')).toEqual({ mode: 'responses', modelId: 'gpt-5.4' })
|
||||
})
|
||||
|
||||
it('能够注册 provider 到全局管理器', () => {
|
||||
const mockProvider = { name: 'mock-provider' }
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-register-provider',
|
||||
name: 'Test Register Provider',
|
||||
creator: vi.fn(() => mockProvider),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 注册 provider 到全局管理器
|
||||
const success = registerProvider('test-register-provider', mockProvider)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-register-provider')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('registers Azure chat lookups to the chat language model', async () => {
|
||||
const success = await createAndRegisterProvider('azure', { apiKey: 'test' })
|
||||
|
||||
expect(success).toBe(true)
|
||||
expect(getInitializedProviders()).toEqual(expect.arrayContaining(['azure', 'azure-chat']))
|
||||
expect(getLanguageModel('azure|gpt-5.4')).toEqual({ mode: 'chat', modelId: 'gpt-5.4' })
|
||||
expect(getLanguageModel('azure-chat|gpt-5.4')).toEqual({ mode: 'chat', modelId: 'gpt-5.4' })
|
||||
})
|
||||
|
||||
it('能够一步完成创建和注册', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-and-register',
|
||||
name: 'Test Create and Register',
|
||||
creator: vi.fn(() => ({ name: 'test-both' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 一步完成创建和注册
|
||||
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-create-and-register')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry 管理', () => {
|
||||
it('能够清理所有配置和注册的 providers', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'temp-provider',
|
||||
name: 'Temp Provider',
|
||||
creator: vi.fn(() => ({ name: 'temp' })),
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(false)
|
||||
// 但基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
|
||||
})
|
||||
|
||||
it('能够单独清理已注册的 providers', () => {
|
||||
// 清理所有 providers
|
||||
clearAllProviders()
|
||||
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('ProviderInitializationError 错误类工作正常', () => {
|
||||
const error = new ProviderInitializationError('Test error', 'test-provider')
|
||||
expect(error.message).toBe('Test error')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.name).toBe('ProviderInitializationError')
|
||||
})
|
||||
})
|
||||
|
||||
describe('错误处理', () => {
|
||||
it('优雅处理空配置', () => {
|
||||
const success = registerProviderConfig(null as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('优雅处理未定义配置', () => {
|
||||
const success = registerProviderConfig(undefined as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理空字符串 ID', () => {
|
||||
const config = {
|
||||
id: '',
|
||||
name: 'Empty ID Provider',
|
||||
creator: vi.fn(() => ({ name: 'empty' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理创建不存在配置的 provider', async () => {
|
||||
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
|
||||
'ProviderConfig not found for id: non-existent-provider'
|
||||
)
|
||||
})
|
||||
|
||||
it('处理注册不存在配置的 provider', () => {
|
||||
const mockProvider = { name: 'mock' }
|
||||
const success = registerProvider('non-existent-provider', mockProvider)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理获取不存在配置的情况', () => {
|
||||
expect(getProviderConfig('non-existent')).toBeUndefined()
|
||||
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
|
||||
expect(hasProviderConfig('non-existent')).toBe(false)
|
||||
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理批量注册时的部分失败', () => {
|
||||
const mixedConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'valid-provider-1',
|
||||
name: 'Valid Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'valid-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any,
|
||||
{
|
||||
id: 'valid-provider-2',
|
||||
name: 'Valid Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'valid-2' })),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(mixedConfigs)
|
||||
expect(successCount).toBe(2) // 只有两个有效配置成功
|
||||
|
||||
expect(hasProviderConfig('valid-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('valid-provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理动态导入失败的情况', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'import-test-provider',
|
||||
name: 'Import Test Provider',
|
||||
import: vi.fn().mockRejectedValue(new Error('Import failed')),
|
||||
creatorFunctionName: 'createTest',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
registerProviderConfig(config)
|
||||
|
||||
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('集成测试', () => {
|
||||
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
|
||||
// 初始状态验证
|
||||
const initialConfigs = getAllProviderConfigs()
|
||||
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
|
||||
// 注册多个带别名的 provider 配置
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'integration-provider-1',
|
||||
name: 'Integration Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'integration-1' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'short-name-1']
|
||||
},
|
||||
{
|
||||
id: 'integration-provider-2',
|
||||
name: 'Integration Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'integration-2' })),
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['alias-2', 'short-name-2']
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2)
|
||||
|
||||
// 验证配置注册成功
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
|
||||
// 验证别名映射
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['alias-1']).toBe('integration-provider-1')
|
||||
expect(aliases['alias-2']).toBe('integration-provider-2')
|
||||
|
||||
// 创建和注册 providers
|
||||
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
|
||||
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
|
||||
expect(success1).toBe(true)
|
||||
expect(success2).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('integration-provider-1')
|
||||
expect(registeredProviders).toContain('integration-provider-2')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
// 验证清理后的状态
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(false)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(false)
|
||||
expect(getAllProviderConfigAliases()).toEqual({})
|
||||
|
||||
// 基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true)
|
||||
})
|
||||
|
||||
it('正确处理动态导入配置的 provider', async () => {
|
||||
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
|
||||
const dynamicImportConfig: ProviderConfig = {
|
||||
id: 'dynamic-import-provider',
|
||||
name: 'Dynamic Import Provider',
|
||||
import: vi.fn().mockResolvedValue(mockModule),
|
||||
creatorFunctionName: 'createCustomProvider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 注册配置
|
||||
const configSuccess = registerProviderConfig(dynamicImportConfig)
|
||||
expect(configSuccess).toBe(true)
|
||||
|
||||
// 创建和注册 provider
|
||||
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
|
||||
expect(registerSuccess).toBe(true)
|
||||
|
||||
// 验证导入函数被调用
|
||||
expect(dynamicImportConfig.import).toHaveBeenCalled()
|
||||
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
|
||||
// 验证注册成功
|
||||
expect(getInitializedProviders()).toContain('dynamic-import-provider')
|
||||
})
|
||||
|
||||
it('正确处理大量配置的注册和管理', () => {
|
||||
const largeConfigList: ProviderConfig[] = []
|
||||
|
||||
// 生成50个配置
|
||||
for (let i = 0; i < 50; i++) {
|
||||
largeConfigList.push({
|
||||
id: `bulk-provider-${i}`,
|
||||
name: `Bulk Provider ${i}`,
|
||||
creator: vi.fn(() => ({ name: `bulk-${i}` })),
|
||||
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
|
||||
aliases: [`alias-${i}`, `short-${i}`]
|
||||
})
|
||||
}
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(largeConfigList)
|
||||
expect(successCount).toBe(50)
|
||||
|
||||
// 验证所有配置都被正确注册
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
|
||||
|
||||
// 验证别名数量
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
const bulkAliases = Object.keys(aliases).filter(
|
||||
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
|
||||
)
|
||||
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
|
||||
|
||||
// 随机验证几个配置
|
||||
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
|
||||
|
||||
// 验证别名工作正常
|
||||
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
|
||||
expect(isProviderConfigAlias('short-30')).toBe(true)
|
||||
|
||||
// 清理能正确处理大量数据
|
||||
cleanup()
|
||||
const cleanupAliases = getAllProviderConfigAliases()
|
||||
expect(
|
||||
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
|
||||
).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('边界测试', () => {
|
||||
it('处理包含特殊字符的 provider IDs', () => {
|
||||
const specialCharsConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-with-dashes',
|
||||
name: 'Provider With Dashes',
|
||||
creator: vi.fn(() => ({ name: 'dashes' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider_with_underscores',
|
||||
name: 'Provider With Underscores',
|
||||
creator: vi.fn(() => ({ name: 'underscores' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider.with.dots',
|
||||
name: 'Provider With Dots',
|
||||
creator: vi.fn(() => ({ name: 'dots' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
|
||||
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
|
||||
|
||||
// 验证支持的特殊字符格式
|
||||
if (hasProviderConfig('provider-with-dashes')) {
|
||||
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
|
||||
}
|
||||
if (hasProviderConfig('provider_with_underscores')) {
|
||||
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
|
||||
}
|
||||
})
|
||||
|
||||
it('处理空的批量注册', () => {
|
||||
const successCount = registerMultipleProviderConfigs([])
|
||||
expect(successCount).toBe(0)
|
||||
|
||||
// 确保没有额外的配置被添加
|
||||
const configsBefore = getAllProviderConfigs().length
|
||||
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
|
||||
})
|
||||
|
||||
it('处理重复的配置注册', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'duplicate-test-provider',
|
||||
name: 'Duplicate Test Provider',
|
||||
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 第一次注册成功
|
||||
expect(registerProviderConfig(config)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 重复注册相同的配置(允许覆盖)
|
||||
const updatedConfig: ProviderConfig = {
|
||||
...config,
|
||||
name: 'Updated Duplicate Test Provider'
|
||||
}
|
||||
expect(registerProviderConfig(updatedConfig)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 验证配置被更新
|
||||
const retrievedConfig = getProviderConfig('duplicate-test-provider')
|
||||
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
|
||||
})
|
||||
|
||||
it('处理极长的 ID 和名称', () => {
|
||||
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
|
||||
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: longId,
|
||||
name: longName,
|
||||
creator: vi.fn(() => ({ name: 'long-test' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
expect(hasProviderConfig(longId)).toBe(true)
|
||||
|
||||
const retrievedConfig = getProviderConfig(longId)
|
||||
expect(retrievedConfig?.name).toBe(longName)
|
||||
})
|
||||
|
||||
it('处理大量别名的配置', () => {
|
||||
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: 'provider-with-many-aliases',
|
||||
name: 'Provider With Many Aliases',
|
||||
creator: vi.fn(() => ({ name: 'many-aliases' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: manyAliases
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证所有别名都能正确解析
|
||||
manyAliases.forEach((alias) => {
|
||||
expect(hasProviderConfigByAlias(alias)).toBe(true)
|
||||
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
|
||||
expect(isProviderConfigAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,269 +0,0 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
type BaseProviderId,
|
||||
baseProviderIds,
|
||||
baseProviderIdSchema,
|
||||
baseProviders,
|
||||
type CustomProviderId,
|
||||
customProviderIdSchema,
|
||||
providerConfigSchema,
|
||||
type ProviderId,
|
||||
providerIdSchema
|
||||
} from '../schemas'
|
||||
|
||||
describe('Provider Schemas', () => {
|
||||
describe('baseProviders', () => {
|
||||
it('包含所有预期的基础 providers', () => {
|
||||
expect(baseProviders).toBeDefined()
|
||||
expect(Array.isArray(baseProviders)).toBe(true)
|
||||
expect(baseProviders.length).toBeGreaterThan(0)
|
||||
|
||||
// These are the actual base providers defined in schemas.ts
|
||||
const expectedIds = [
|
||||
'openai',
|
||||
'openai-chat',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'azure-responses',
|
||||
'deepseek',
|
||||
'openrouter',
|
||||
'cherryin',
|
||||
'cherryin-chat'
|
||||
]
|
||||
const actualIds = baseProviders.map((p) => p.id)
|
||||
expectedIds.forEach((id) => {
|
||||
expect(actualIds).toContain(id)
|
||||
})
|
||||
})
|
||||
|
||||
it('每个基础 provider 有必要的属性', () => {
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(provider).toHaveProperty('creator')
|
||||
expect(provider).toHaveProperty('supportsImageGeneration')
|
||||
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
expect(typeof provider.creator).toBe('function')
|
||||
expect(typeof provider.supportsImageGeneration).toBe('boolean')
|
||||
})
|
||||
})
|
||||
|
||||
it('provider ID 是唯一的', () => {
|
||||
const ids = baseProviders.map((p) => p.id)
|
||||
const uniqueIds = [...new Set(ids)]
|
||||
expect(ids).toEqual(uniqueIds)
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIds', () => {
|
||||
it('正确提取所有基础 provider IDs', () => {
|
||||
expect(baseProviderIds).toBeDefined()
|
||||
expect(Array.isArray(baseProviderIds)).toBe(true)
|
||||
expect(baseProviderIds.length).toBe(baseProviders.length)
|
||||
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(baseProviderIds).toContain(provider.id)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIdSchema', () => {
|
||||
it('验证有效的基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的基础 provider IDs', () => {
|
||||
const invalidIds = ['invalid', 'not-exists', '']
|
||||
invalidIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('customProviderIdSchema', () => {
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||
validIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝空字符串', () => {
|
||||
expect(customProviderIdSchema.safeParse('').success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerIdSchema', () => {
|
||||
it('接受基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validCustomIds = ['custom-provider', 'my-ai-service']
|
||||
validCustomIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = ['', undefined, null, 123]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerConfigSchema', () => {
|
||||
it('验证带有 creator 的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('验证带有 import 配置的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
import: vi.fn(),
|
||||
creatorFunctionName: 'createCustom',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'invalid',
|
||||
name: 'Invalid Provider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('为 supportsImageGeneration 设置默认值', () => {
|
||||
const config = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
creator: vi.fn()
|
||||
}
|
||||
const result = providerConfigSchema.safeParse(config)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.data.supportsImageGeneration).toBe(false)
|
||||
}
|
||||
})
|
||||
|
||||
it('拒绝使用基础 provider ID 的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 基础 provider ID
|
||||
name: 'Should Fail',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('拒绝缺少必需字段的配置', () => {
|
||||
const invalidConfigs = [
|
||||
{ name: 'Missing ID', creator: vi.fn() },
|
||||
{ id: 'missing-name', creator: vi.fn() },
|
||||
{ id: '', name: 'Empty ID', creator: vi.fn() },
|
||||
{ id: 'valid-custom', name: '', creator: vi.fn() }
|
||||
]
|
||||
|
||||
invalidConfigs.forEach((config) => {
|
||||
expect(providerConfigSchema.safeParse(config).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Schema 验证功能', () => {
|
||||
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
|
||||
})
|
||||
|
||||
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
|
||||
const customIds = ['custom-provider', 'my-service', 'company-llm']
|
||||
customIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 拒绝基础 provider IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
|
||||
// 基础 IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 自定义 IDs
|
||||
const customIds = ['custom-provider', 'my-service']
|
||||
customIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerConfigSchema 验证完整的 provider 配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 不允许基础 provider ID
|
||||
name: 'OpenAI',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('类型推导', () => {
|
||||
it('BaseProviderId 类型正确', () => {
|
||||
const id: BaseProviderId = 'openai'
|
||||
expect(baseProviderIds).toContain(id)
|
||||
})
|
||||
|
||||
it('CustomProviderId 类型是字符串', () => {
|
||||
const id: CustomProviderId = 'custom-provider'
|
||||
expect(typeof id).toBe('string')
|
||||
})
|
||||
|
||||
it('ProviderId 类型支持基础和自定义 IDs', () => {
|
||||
const baseId: ProviderId = 'openai'
|
||||
const customId: ProviderId = 'custom-provider'
|
||||
expect(typeof baseId).toBe('string')
|
||||
expect(typeof customId).toBe('string')
|
||||
})
|
||||
})
|
||||
})
|
||||
269
packages/aiCore/src/core/providers/__tests__/types.test.ts
Normal file
269
packages/aiCore/src/core/providers/__tests__/types.test.ts
Normal file
@@ -0,0 +1,269 @@
|
||||
/**
|
||||
* Provider Types - Type-level Tests
|
||||
* Tests type utilities and type inference for provider extensions
|
||||
*/
|
||||
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { describe, expectTypeOf, it } from 'vitest'
|
||||
|
||||
import type { ProviderExtensionConfig } from '../core/ProviderExtension'
|
||||
import type {
|
||||
CoreProviderSettingsMap,
|
||||
ExtensionConfigToIdResolutionMap,
|
||||
ExtractExtensionIds,
|
||||
ExtractProviderIds,
|
||||
StringKeys,
|
||||
UnionToIntersection
|
||||
} from '../types'
|
||||
|
||||
describe('Type Utilities', () => {
|
||||
describe('StringKeys<T>', () => {
|
||||
it('should extract only string keys from object type', () => {
|
||||
type TestObj = { foo: 1; bar: 2; 0: 3; 1: 4 }
|
||||
type Result = StringKeys<TestObj>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'foo' | 'bar'>()
|
||||
})
|
||||
|
||||
it('should return never for object with no string keys', () => {
|
||||
type TestObj = { 0: 'a'; 1: 'b' }
|
||||
type Result = StringKeys<TestObj>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<never>()
|
||||
})
|
||||
|
||||
it('should handle empty object', () => {
|
||||
type Result = StringKeys<{}>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<never>()
|
||||
})
|
||||
|
||||
it('should preserve literal string keys', () => {
|
||||
type TestObj = { openai: 1; anthropic: 2; google: 3 }
|
||||
type Result = StringKeys<TestObj>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'openai' | 'anthropic' | 'google'>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('UnionToIntersection<U>', () => {
|
||||
it('should convert union to intersection', () => {
|
||||
type Union = { a: 1 } | { b: 2 }
|
||||
type Result = UnionToIntersection<Union>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{ a: 1 } & { b: 2 }>()
|
||||
})
|
||||
|
||||
it('should handle single type', () => {
|
||||
type Single = { a: 1 }
|
||||
type Result = UnionToIntersection<Single>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{ a: 1 }>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExtractProviderIds<TConfig>', () => {
|
||||
it('should extract base name', () => {
|
||||
type Config = { name: 'openai' }
|
||||
type Result = ExtractProviderIds<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'openai'>()
|
||||
})
|
||||
|
||||
it('should extract name and aliases', () => {
|
||||
type Config = { name: 'anthropic'; aliases: readonly ['claude'] }
|
||||
type Result = ExtractProviderIds<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'anthropic' | 'claude'>()
|
||||
})
|
||||
|
||||
it('should extract name and variants', () => {
|
||||
type Config = { name: 'openai'; variants: readonly [{ suffix: 'chat' }] }
|
||||
type Result = ExtractProviderIds<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'openai' | 'openai-chat'>()
|
||||
})
|
||||
|
||||
it('should extract name, aliases, and variants', () => {
|
||||
type Config = {
|
||||
name: 'azure'
|
||||
aliases: readonly ['azure-openai']
|
||||
variants: readonly [{ suffix: 'responses' }]
|
||||
}
|
||||
type Result = ExtractProviderIds<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'azure' | 'azure-openai' | 'azure-responses'>()
|
||||
})
|
||||
|
||||
it('should handle multiple variants', () => {
|
||||
type Config = {
|
||||
name: 'openai'
|
||||
variants: readonly [{ suffix: 'chat' }, { suffix: 'responses' }]
|
||||
}
|
||||
type Result = ExtractProviderIds<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'openai' | 'openai-chat' | 'openai-responses'>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExtensionConfigToIdResolutionMap<TConfig>', () => {
|
||||
it('should map base name to itself', () => {
|
||||
type Config = { name: 'openai' }
|
||||
type Result = ExtensionConfigToIdResolutionMap<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{ readonly openai: 'openai' }>()
|
||||
})
|
||||
|
||||
it('should map aliases to base name', () => {
|
||||
type Config = { name: 'anthropic'; aliases: readonly ['claude'] }
|
||||
type Result = ExtensionConfigToIdResolutionMap<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{
|
||||
readonly anthropic: 'anthropic'
|
||||
readonly claude: 'anthropic'
|
||||
}>()
|
||||
})
|
||||
|
||||
it('should map variants to themselves (self-referential)', () => {
|
||||
type Config = { name: 'azure'; variants: readonly [{ suffix: 'responses' }] }
|
||||
type Result = ExtensionConfigToIdResolutionMap<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{
|
||||
readonly azure: 'azure'
|
||||
readonly 'azure-responses': 'azure-responses'
|
||||
}>()
|
||||
})
|
||||
|
||||
it('should handle combined aliases and variants correctly', () => {
|
||||
type Config = {
|
||||
name: 'azure'
|
||||
aliases: readonly ['azure-openai']
|
||||
variants: readonly [{ suffix: 'responses' }]
|
||||
}
|
||||
type Result = ExtensionConfigToIdResolutionMap<Config>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{
|
||||
readonly azure: 'azure'
|
||||
readonly 'azure-openai': 'azure'
|
||||
readonly 'azure-responses': 'azure-responses'
|
||||
}>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExtractExtensionIds<T>', () => {
|
||||
it('should extract IDs from extension with config property', () => {
|
||||
type MockExtension = {
|
||||
config: { name: 'test'; aliases: readonly ['test-alias'] }
|
||||
}
|
||||
type Result = ExtractExtensionIds<MockExtension>
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<'test' | 'test-alias'>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ExtensionToSettingsMap<T>', () => {
|
||||
it('should map provider IDs to settings type', () => {
|
||||
type MockSettings = { apiKey: string }
|
||||
type MockConfig = { name: 'mock' }
|
||||
|
||||
// This tests the concept - actual implementation depends on ProviderExtension structure
|
||||
type Result = { [K in ExtractProviderIds<MockConfig>]: MockSettings }
|
||||
|
||||
expectTypeOf<Result>().toEqualTypeOf<{ mock: MockSettings }>()
|
||||
})
|
||||
})
|
||||
|
||||
describe('CoreProviderSettingsMap', () => {
|
||||
it('should include openai provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('openai')
|
||||
})
|
||||
|
||||
it('should include anthropic provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('anthropic')
|
||||
})
|
||||
|
||||
it('should include google provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('google')
|
||||
})
|
||||
|
||||
it('should include azure provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('azure')
|
||||
})
|
||||
|
||||
it('should include xai provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('xai')
|
||||
})
|
||||
|
||||
it('should include deepseek provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('deepseek')
|
||||
})
|
||||
|
||||
it('should include openrouter provider', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('openrouter')
|
||||
})
|
||||
|
||||
it('should include aliases like claude', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('claude')
|
||||
})
|
||||
|
||||
it('should include variants like openai-chat', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('openai-chat')
|
||||
})
|
||||
|
||||
it('should include variants like azure-responses', () => {
|
||||
expectTypeOf<CoreProviderSettingsMap>().toHaveProperty('azure-responses')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ProviderExtensionConfig Type Constraints', () => {
|
||||
it('should accept valid minimal config', () => {
|
||||
type ValidConfig = ProviderExtensionConfig<{ apiKey: string }, ProviderV3, 'test'>
|
||||
|
||||
// Should compile without errors
|
||||
const config: ValidConfig = {
|
||||
name: 'test',
|
||||
create: () => ({}) as ProviderV3
|
||||
}
|
||||
|
||||
expectTypeOf(config.name).toEqualTypeOf<'test'>()
|
||||
})
|
||||
|
||||
it('should accept config with aliases', () => {
|
||||
type ConfigWithAliases = {
|
||||
name: 'anthropic'
|
||||
aliases: readonly ['claude']
|
||||
create: () => ProviderV3
|
||||
}
|
||||
|
||||
const config: ConfigWithAliases = {
|
||||
name: 'anthropic',
|
||||
aliases: ['claude'] as const,
|
||||
create: () => ({}) as ProviderV3
|
||||
}
|
||||
|
||||
expectTypeOf(config.aliases).toEqualTypeOf<readonly ['claude']>()
|
||||
})
|
||||
|
||||
it('should accept config with variants', () => {
|
||||
type ConfigWithVariants = {
|
||||
name: 'openai'
|
||||
variants: readonly [{ suffix: 'chat'; name: string; transform: (p: ProviderV3) => ProviderV3 }]
|
||||
create: () => ProviderV3
|
||||
}
|
||||
|
||||
const config: ConfigWithVariants = {
|
||||
name: 'openai',
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'OpenAI Chat',
|
||||
transform: (p) => p
|
||||
}
|
||||
] as const,
|
||||
create: () => ({}) as ProviderV3
|
||||
}
|
||||
|
||||
expectTypeOf(config.variants[0].suffix).toEqualTypeOf<'chat'>()
|
||||
})
|
||||
})
|
||||
513
packages/aiCore/src/core/providers/core/ExtensionRegistry.ts
Normal file
513
packages/aiCore/src/core/providers/core/ExtensionRegistry.ts
Normal file
@@ -0,0 +1,513 @@
|
||||
/**
|
||||
* Extension Registry
|
||||
* 管理所有 Provider Extensions 的注册、查询和实例化
|
||||
*/
|
||||
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
|
||||
import type { CoreProviderSettingsMap, RegisteredProviderId, ToolCapability, ToolFactory } from '../index'
|
||||
import { type ProviderExtension } from './ProviderExtension'
|
||||
import { ProviderCreationError } from './utils'
|
||||
|
||||
/**
|
||||
* Provider Extension 注册表
|
||||
*
|
||||
* 职责:
|
||||
* - 注册和管理 Provider Extensions
|
||||
* - 根据 ID 查找对应的 Extension
|
||||
* - 创建并注册 provider 实例(包括变体)
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* import { extensionRegistry } from '@cherrystudio/ai-core/provider'
|
||||
* import { OpenAIExtension } from './extensions/openai'
|
||||
*
|
||||
* // 注册 extension
|
||||
* extensionRegistry.register(OpenAIExtension)
|
||||
*
|
||||
* // 批量注册
|
||||
* extensionRegistry.registerAll([
|
||||
* OpenAIExtension,
|
||||
* AzureExtension,
|
||||
* AnthropicExtension
|
||||
* ])
|
||||
*
|
||||
* // 创建并注册 provider 实例
|
||||
* await extensionRegistry.createAndRegisterProvider('openai', {
|
||||
* apiKey: 'sk-xxx'
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export class ExtensionRegistry {
|
||||
/** Extension 存储: name -> Extension */
|
||||
private extensions: Map<string, ProviderExtension<any, any, any>> = new Map()
|
||||
|
||||
/** 别名映射: alias -> name */
|
||||
private aliasMap: Map<string, string> = new Map()
|
||||
|
||||
/**
|
||||
* 注册单个 Extension
|
||||
* 支持链式调用
|
||||
*/
|
||||
register(extension: ProviderExtension<any, any, any>): this {
|
||||
const { name, aliases, variants } = extension.config
|
||||
|
||||
// Idempotent: skip if already registered (supports HMR / re-import)
|
||||
if (this.extensions.has(name)) {
|
||||
return this
|
||||
}
|
||||
|
||||
this.extensions.set(name, extension)
|
||||
|
||||
if (aliases) {
|
||||
for (const alias of aliases) {
|
||||
if (this.aliasMap.has(alias)) {
|
||||
throw new Error(`Provider alias "${alias}" is already registered for "${this.aliasMap.get(alias)}"`)
|
||||
}
|
||||
this.aliasMap.set(alias, name)
|
||||
}
|
||||
}
|
||||
|
||||
if (variants) {
|
||||
for (const variant of variants) {
|
||||
const variantId = `${name}-${variant.suffix}`
|
||||
if (this.aliasMap.has(variantId)) {
|
||||
throw new Error(
|
||||
`Provider variant ID "${variantId}" is already registered for "${this.aliasMap.get(variantId)}"`
|
||||
)
|
||||
}
|
||||
this.aliasMap.set(variantId, name)
|
||||
}
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册 Extensions
|
||||
* 支持 readonly 数组(用于 as const 数组)
|
||||
*/
|
||||
registerAll(extensions: readonly ProviderExtension<any, any, any>[]): this {
|
||||
for (const ext of extensions) {
|
||||
this.register(ext)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 取消注册 Extension
|
||||
*/
|
||||
unregister(name: string): boolean {
|
||||
const extension = this.extensions.get(name)
|
||||
if (!extension) {
|
||||
return false
|
||||
}
|
||||
|
||||
extension.clearCache()
|
||||
this.extensions.delete(name)
|
||||
|
||||
if (extension.config.aliases) {
|
||||
for (const alias of extension.config.aliases) {
|
||||
this.aliasMap.delete(alias)
|
||||
}
|
||||
}
|
||||
|
||||
if (extension.config.variants) {
|
||||
for (const variant of extension.config.variants) {
|
||||
this.aliasMap.delete(`${name}-${variant.suffix}`)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Extension(支持别名)
|
||||
*/
|
||||
get(id: string): ProviderExtension<any, any, any> | undefined {
|
||||
if (this.extensions.has(id)) {
|
||||
return this.extensions.get(id)
|
||||
}
|
||||
|
||||
const realName = this.aliasMap.get(id)
|
||||
if (realName) {
|
||||
return this.extensions.get(realName)
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Extension
|
||||
*
|
||||
* @param id - Provider ID(必须是 RegisteredProviderId)
|
||||
* @returns Extension 或 undefined
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const ext = extensionRegistry.getTyped('openai')
|
||||
* if (ext) {
|
||||
* const provider = await ext.createProvider({
|
||||
* apiKey: 'sk-...'
|
||||
* })
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
getTyped<T extends RegisteredProviderId>(id: T): ProviderExtension<any, any, any> | undefined {
|
||||
return this.get(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 Extension 是否已注册
|
||||
*/
|
||||
has(id: string): boolean {
|
||||
return this.extensions.has(id) || this.aliasMap.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已注册的 Extension
|
||||
*/
|
||||
getAll(): ProviderExtension<any, any, any>[] {
|
||||
return Array.from(this.extensions.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已注册的 provider IDs(包含变体)
|
||||
* 返回类型安全的 RegisteredProviderId 数组,自动去重
|
||||
*/
|
||||
getAllProviderIds(): RegisteredProviderId[] {
|
||||
const ids = new Set<string>()
|
||||
|
||||
for (const extension of this.extensions.values()) {
|
||||
for (const id of extension.getProviderIds()) {
|
||||
ids.add(id)
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(ids) as RegisteredProviderId[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据 base ID + mode 解析到完整的 provider ID
|
||||
*
|
||||
* 支持别名:如果 baseId 是别名,会先解析到规范 ID
|
||||
*
|
||||
* @param baseId - 基础 provider ID(可以是别名)
|
||||
* @param mode - 模式(如 'chat', 'responses')
|
||||
* @returns 完整的 provider ID,如果无法解析则返回 null
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* resolveProviderIdWithMode('openai', 'chat') // → 'openai-chat'
|
||||
* resolveProviderIdWithMode('azure', 'responses') // → 'azure-responses'
|
||||
* resolveProviderIdWithMode('gemini', 'chat') // → null (google 没有 chat 变体)
|
||||
* resolveProviderIdWithMode('openai') // → 'openai' (没有 mode)
|
||||
* ```
|
||||
*/
|
||||
resolveProviderIdWithMode(baseId: string, mode?: string): string | null {
|
||||
// 如果没有 mode,直接返回解析后的 ID
|
||||
if (!mode) {
|
||||
const extension = this.get(baseId)
|
||||
return extension ? extension.config.name : null
|
||||
}
|
||||
|
||||
// 获取 extension(支持别名)
|
||||
const extension = this.get(baseId)
|
||||
if (!extension) {
|
||||
return null
|
||||
}
|
||||
|
||||
// 检查是否有对应的变体
|
||||
if (!extension.config.variants) {
|
||||
return null
|
||||
}
|
||||
|
||||
// 查找匹配的变体
|
||||
const variant = extension.config.variants.find((v: { suffix: string }) => v.suffix === mode)
|
||||
if (!variant) {
|
||||
return null
|
||||
}
|
||||
|
||||
// 返回变体 ID: ${name}-${suffix}
|
||||
return `${extension.config.name}-${variant.suffix}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 反向解析:从完整 ID 提取 base ID 和 mode
|
||||
*
|
||||
* 遍历所有 extensions 的变体,匹配 `${name}-${suffix}` 模式
|
||||
*
|
||||
* @param providerId - 完整的 provider ID
|
||||
* @returns 解析结果,如果无法解析返回 null
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* parseProviderId('openai-chat') // → { baseId: 'openai', mode: 'chat', isVariant: true }
|
||||
* parseProviderId('azure-responses') // → { baseId: 'azure', mode: 'responses', isVariant: true }
|
||||
* parseProviderId('openai') // → { baseId: 'openai', isVariant: false }
|
||||
* parseProviderId('oai') // → { baseId: 'openai', isVariant: false } (别名)
|
||||
* parseProviderId('unknown') // → null
|
||||
* ```
|
||||
*/
|
||||
parseProviderId(providerId: string): { baseId: RegisteredProviderId; mode?: string; isVariant: boolean } | null {
|
||||
// 先遍历所有 extensions,查找匹配的变体(优先于别名检查)
|
||||
for (const ext of this.extensions.values()) {
|
||||
if (!ext.config.variants) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查每个变体
|
||||
for (const variant of ext.config.variants) {
|
||||
const variantId = `${ext.config.name}-${variant.suffix}`
|
||||
if (variantId === providerId) {
|
||||
return {
|
||||
baseId: ext.config.name as RegisteredProviderId,
|
||||
mode: variant.suffix,
|
||||
isVariant: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 再检查是否是已注册的 extension(直接或通过别名)
|
||||
const extension = this.get(providerId)
|
||||
if (extension) {
|
||||
// 是基础 ID 或别名,不是变体
|
||||
return {
|
||||
baseId: extension.config.name as RegisteredProviderId,
|
||||
isVariant: false
|
||||
}
|
||||
}
|
||||
|
||||
// 无法解析
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为变体 ID
|
||||
*
|
||||
* @param id - Provider ID
|
||||
* @returns 如果是变体 ID 返回 true
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* isVariant('openai-chat') // → true
|
||||
* isVariant('azure-responses') // → true
|
||||
* isVariant('openai') // → false
|
||||
* isVariant('unknown') // → false
|
||||
* ```
|
||||
*/
|
||||
isVariant(id: string): boolean {
|
||||
const parsed = this.parseProviderId(id)
|
||||
return parsed?.isVariant ?? false
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取基础 provider ID
|
||||
*
|
||||
* 对于变体ID,返回其基础provider ID;
|
||||
* 对于基础ID或别名,返回规范的provider ID;
|
||||
* 对于未知ID,返回null
|
||||
*
|
||||
* @param id - Provider ID(可以是基础ID、变体ID或别名)
|
||||
* @returns 基础 provider ID,如果无法解析则返回 null
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* getBaseProviderId('openai-chat') // → 'openai' (变体)
|
||||
* getBaseProviderId('azure-responses') // → 'azure' (变体)
|
||||
* getBaseProviderId('openai') // → 'openai' (基础ID)
|
||||
* getBaseProviderId('oai') // → 'openai' (别名)
|
||||
* getBaseProviderId('unknown') // → null
|
||||
* ```
|
||||
*/
|
||||
getBaseProviderId(id: string): RegisteredProviderId | null {
|
||||
const parsed = this.parseProviderId(id)
|
||||
return parsed?.baseId ?? null
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取变体的模式/后缀
|
||||
*
|
||||
* @param variantId - 变体 ID
|
||||
* @returns 模式/后缀,如果不是变体则返回 null
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* getVariantMode('openai-chat') // → 'chat'
|
||||
* getVariantMode('azure-responses') // → 'responses'
|
||||
* getVariantMode('openai') // → null (不是变体)
|
||||
* getVariantMode('unknown') // → null
|
||||
* ```
|
||||
*/
|
||||
getVariantMode(variantId: string): string | null {
|
||||
const parsed = this.parseProviderId(variantId)
|
||||
return parsed?.mode ?? null
|
||||
}
|
||||
|
||||
/** 获取 variant 的 resolveModel 函数(类型安全在 extension 声明处保证) */
|
||||
getModelResolver(providerId: string): ((provider: ProviderV3, modelId: string) => any) | undefined {
|
||||
const parsed = this.parseProviderId(providerId)
|
||||
if (!parsed) return undefined
|
||||
|
||||
const extension = this.get(parsed.baseId)
|
||||
if (!extension) return undefined
|
||||
|
||||
// Variant resolveModel(类型安全,在 extension 声明处校验)
|
||||
if (parsed.isVariant && parsed.mode) {
|
||||
const variant = extension.getVariant(parsed.mode)
|
||||
if (variant?.resolveModel) return variant.resolveModel
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取某个基础 provider 的所有变体 IDs
|
||||
*
|
||||
* @param baseId - 基础 provider ID(可以是别名)
|
||||
* @returns 变体 ID 数组,如果没有变体则返回空数组
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* getVariants('openai') // → ['openai-chat']
|
||||
* getVariants('azure') // → ['azure-responses']
|
||||
* getVariants('google') // → ['google-chat']
|
||||
* getVariants('xai') // → [] (没有变体)
|
||||
* getVariants('unknown') // → [] (未注册)
|
||||
* ```
|
||||
*/
|
||||
getVariants(baseId: string): string[] {
|
||||
const extension = this.get(baseId)
|
||||
if (!extension?.config.variants) {
|
||||
return []
|
||||
}
|
||||
|
||||
return extension.config.variants.map((v: { suffix: string }) => `${extension.config.name}-${v.suffix}`)
|
||||
}
|
||||
|
||||
/** 获取指定 provider 的工具工厂(变体优先,回退到 base) */
|
||||
getToolFactory(providerId: string, capability: ToolCapability): ToolFactory | undefined {
|
||||
const parsed = this.parseProviderId(providerId)
|
||||
if (!parsed) return undefined
|
||||
|
||||
const { baseId, mode, isVariant } = parsed
|
||||
const extension = this.get(baseId)
|
||||
if (!extension) return undefined
|
||||
|
||||
// For variants, check variant-level toolFactories first
|
||||
if (isVariant && mode) {
|
||||
const variant = extension.getVariant(mode)
|
||||
if (variant?.toolFactories?.[capability]) {
|
||||
return variant.toolFactories[capability]
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to base extension's toolFactories
|
||||
return extension.config.toolFactories?.[capability]
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析工具能力:返回 factory + provider 实例
|
||||
*
|
||||
* 1. Direct — provider 自己有 toolFactories
|
||||
* 2. Aggregator fallback — 从 model.provider 段解析(如 "aihubmix.google" → google extension)
|
||||
*/
|
||||
async resolveToolCapability(
|
||||
providerId: string,
|
||||
capability: ToolCapability,
|
||||
modelProvider?: string
|
||||
): Promise<{ factory: ToolFactory; provider: ProviderV3 } | undefined> {
|
||||
// 1. Direct: provider 自己有 toolFactories
|
||||
const directFactory = this.getToolFactory(providerId, capability)
|
||||
if (directFactory) {
|
||||
const provider = await this.getToolProvider(providerId)
|
||||
if (provider) return { factory: directFactory, provider }
|
||||
}
|
||||
|
||||
// 2. Aggregator fallback: 从 model.provider 段解析真实 provider
|
||||
// e.g., "aihubmix.google" → try "google" → found via google extension
|
||||
// e.g., "cherryin.gemini" → try "gemini" → found via alias → google extension
|
||||
if (typeof modelProvider === 'string') {
|
||||
const segments = modelProvider.split('.')
|
||||
for (let i = segments.length - 1; i >= 0; i--) {
|
||||
const factory = this.getToolFactory(segments[i], capability)
|
||||
if (factory) {
|
||||
const provider = await this.getToolProvider(segments[i])
|
||||
if (provider) return { factory, provider }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/** Get base provider for .tools extraction (cached or dummy instance) */
|
||||
private async getToolProvider(providerId: string): Promise<ProviderV3 | undefined> {
|
||||
const parsed = this.parseProviderId(providerId)
|
||||
if (!parsed) return undefined
|
||||
|
||||
const extension = this.get(parsed.baseId)
|
||||
if (!extension) return undefined
|
||||
|
||||
const cached = extension.getCachedProvider()
|
||||
if (cached) return cached
|
||||
|
||||
try {
|
||||
return await extension.createProvider({ apiKey: '_tool_descriptor' })
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空所有注册
|
||||
*/
|
||||
clear(): void {
|
||||
this.extensions.clear()
|
||||
this.aliasMap.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 provider 实例
|
||||
*
|
||||
* 支持两种调用方式:
|
||||
* 1. 类型安全版本 - 使用已注册的 provider ID,获得完整的类型推导
|
||||
* 2. 动态版本 - 使用任意字符串 ID,用于测试或动态注册的 provider
|
||||
*
|
||||
* @param id - Provider ID
|
||||
* @param settings - Provider 配置
|
||||
* @returns Provider 实例
|
||||
*/
|
||||
async createProvider<T extends RegisteredProviderId>(id: T, settings: CoreProviderSettingsMap[T]): Promise<ProviderV3>
|
||||
async createProvider(id: string, settings?: unknown): Promise<ProviderV3>
|
||||
async createProvider(id: string, settings?: unknown): Promise<ProviderV3> {
|
||||
const parsed = this.parseProviderId(id)
|
||||
if (!parsed) {
|
||||
throw new Error(`Provider extension "${id}" not found. Did you forget to register it?`)
|
||||
}
|
||||
|
||||
const { baseId, mode: variantSuffix } = parsed
|
||||
|
||||
const extension = this.get(baseId)
|
||||
if (!extension) {
|
||||
throw new Error(`Provider extension "${baseId}" not found. Did you forget to register it?`)
|
||||
}
|
||||
|
||||
try {
|
||||
return await extension.createProvider(settings, variantSuffix)
|
||||
} catch (error) {
|
||||
throw new ProviderCreationError(
|
||||
`Failed to create provider "${id}"`,
|
||||
id,
|
||||
error instanceof Error ? error : new Error(String(error))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局 Extension Registry 实例
|
||||
* 单例模式,确保整个应用只有一个注册表
|
||||
*/
|
||||
export const extensionRegistry = new ExtensionRegistry()
|
||||
344
packages/aiCore/src/core/providers/core/ProviderExtension.ts
Normal file
344
packages/aiCore/src/core/providers/core/ProviderExtension.ts
Normal file
@@ -0,0 +1,344 @@
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { LRUCache } from 'lru-cache'
|
||||
|
||||
import { deepMergeObjects } from '../../utils'
|
||||
import type { ProviderVariant, ToolFactoryMap } from '../types'
|
||||
|
||||
export type ProviderCreatorFunction<TSettings = any> = (settings?: TSettings) => ProviderV3 | Promise<ProviderV3>
|
||||
|
||||
/**
|
||||
* Provider 模块类型
|
||||
* 动态导入的模块应该包含至少一个创建函数
|
||||
* 允许 default 导出和其他属性
|
||||
*/
|
||||
export type ProviderModule<TSettings = any> = Record<string, any> & {
|
||||
[K: string]: ProviderCreatorFunction<TSettings> | any
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider Extension 配置基础接口
|
||||
*
|
||||
* @typeParam TSettings - Provider 配置类型
|
||||
* @typeParam TProvider - 实际 provider 类型(用于 variants)
|
||||
* @typeParam TName - Provider 名称类型(用于字面量推导)
|
||||
*/
|
||||
interface ProviderExtensionConfigBase<
|
||||
TSettings = any,
|
||||
TProvider extends ProviderV3 = ProviderV3,
|
||||
TName extends string = string
|
||||
> {
|
||||
/** Provider 唯一标识 */
|
||||
name: TName
|
||||
|
||||
/** 别名列表(可选) */
|
||||
aliases?: readonly string[]
|
||||
|
||||
/** 默认配置选项 */
|
||||
defaultOptions?: Partial<TSettings>
|
||||
|
||||
/** 是否支持图像生成 */
|
||||
supportsImageGeneration?: boolean
|
||||
|
||||
/**
|
||||
* Provider 变体配置
|
||||
* 用于注册同一 provider 的不同模式
|
||||
*/
|
||||
variants?: readonly ProviderVariant<TSettings, TProvider>[]
|
||||
|
||||
/**
|
||||
* Tool factory 映射
|
||||
* 声明该 provider 支持的工具能力(如 webSearch)
|
||||
* 工具工厂从 provider 实例的 .tools 属性提取
|
||||
*/
|
||||
toolFactories?: ToolFactoryMap<TProvider>
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider Extension 配置接口 - 使用 create 函数
|
||||
*/
|
||||
interface ProviderExtensionConfigWithCreate<
|
||||
TSettings = any,
|
||||
TProvider extends ProviderV3 = ProviderV3,
|
||||
TName extends string = string
|
||||
> extends ProviderExtensionConfigBase<TSettings, TProvider, TName> {
|
||||
create: ProviderCreatorFunction<TSettings>
|
||||
|
||||
import?: never
|
||||
|
||||
creatorFunctionName?: never
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider Extension 配置接口 - 使用动态导入
|
||||
*/
|
||||
interface ProviderExtensionConfigWithImport<
|
||||
TSettings = any,
|
||||
TProvider extends ProviderV3 = ProviderV3,
|
||||
TName extends string = string
|
||||
> extends ProviderExtensionConfigBase<TSettings, TProvider, TName> {
|
||||
create?: never
|
||||
|
||||
import: () => Promise<ProviderModule<TSettings>>
|
||||
|
||||
creatorFunctionName: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider Extension 配置接口
|
||||
* 使用联合类型确保 create 和 import 互斥
|
||||
*
|
||||
* @typeParam TSettings - Provider 配置类型
|
||||
* @typeParam TProvider - 实际 provider 类型(用于 variants)
|
||||
* @typeParam TName - Provider 名称类型(用于字面量推导)
|
||||
*/
|
||||
export type ProviderExtensionConfig<
|
||||
TSettings = any,
|
||||
TProvider extends ProviderV3 = ProviderV3,
|
||||
TName extends string = string
|
||||
> =
|
||||
| ProviderExtensionConfigWithCreate<TSettings, TProvider, TName>
|
||||
| ProviderExtensionConfigWithImport<TSettings, TProvider, TName>
|
||||
|
||||
/**
|
||||
* Provider Extension 类
|
||||
*
|
||||
* @typeParam TSettings - Provider 配置类型
|
||||
* @typeParam TProvider - 实际 provider 类型(用于 variants)
|
||||
* @typeParam TConfig - 配置对象类型(幻影类型参数,用于自动推导 Provider IDs)
|
||||
*/
|
||||
export class ProviderExtension<
|
||||
TSettings = any,
|
||||
TProvider extends ProviderV3 = ProviderV3,
|
||||
TConfig extends ProviderExtensionConfig<TSettings, TProvider, string> = ProviderExtensionConfig<
|
||||
TSettings,
|
||||
TProvider,
|
||||
string
|
||||
>
|
||||
> {
|
||||
/** Provider 实例缓存 - 按 settings hash 存储,LRU 自动清理 */
|
||||
private instances: LRUCache<string, TProvider>
|
||||
|
||||
/** In-flight promise map - 防止并发创建相同 settings 的 provider */
|
||||
private pendingCreations: Map<string, Promise<TProvider>> = new Map()
|
||||
|
||||
constructor(public readonly config: TConfig) {
|
||||
if (!config.name) {
|
||||
throw new Error('ProviderExtension: name is required')
|
||||
}
|
||||
|
||||
this.instances = new LRUCache<string, TProvider>({
|
||||
max: 10,
|
||||
updateAgeOnGet: true
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 静态工厂方法 - 创建 Provider Extension
|
||||
*/
|
||||
static create<
|
||||
const TConfig extends ProviderExtensionConfig<any, any, string>,
|
||||
TSettings = TConfig extends ProviderExtensionConfig<infer S, any, any> ? S : any,
|
||||
TProvider extends ProviderV3 = TConfig extends ProviderExtensionConfig<any, infer P, any> ? P : ProviderV3
|
||||
>(config: TConfig | (() => TConfig)): ProviderExtension<TSettings, TProvider, TConfig>
|
||||
static create(config: any): ProviderExtension<any, any, any> {
|
||||
const resolvedConfig = typeof config === 'function' ? config() : config
|
||||
return new ProviderExtension(resolvedConfig)
|
||||
}
|
||||
|
||||
/**
|
||||
* Options getter - 只读配置
|
||||
*/
|
||||
get options(): Readonly<Partial<TSettings>> {
|
||||
return Object.freeze({ ...this.config.defaultOptions })
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算 settings 的稳定 hash
|
||||
*/
|
||||
private computeHash(settings?: TSettings, variantSuffix?: string): string {
|
||||
const baseKey = (() => {
|
||||
if (settings === undefined || settings === null) {
|
||||
return 'default'
|
||||
}
|
||||
|
||||
const stableStringify = (obj: any): string => {
|
||||
if (obj === null || obj === undefined) return 'null'
|
||||
if (typeof obj === 'function') return '"[function]"'
|
||||
if (typeof obj !== 'object') return JSON.stringify(obj)
|
||||
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
|
||||
|
||||
const keys = Object.keys(obj).sort()
|
||||
const pairs = keys.map((key) => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
|
||||
return `{${pairs.join(',')}}`
|
||||
}
|
||||
|
||||
return stableStringify(settings)
|
||||
})()
|
||||
|
||||
return variantSuffix ? `${baseKey}:${variantSuffix}` : baseKey
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 Provider 实例
|
||||
* 相同 settings 会复用实例,不同 settings 会创建新实例
|
||||
*/
|
||||
async createProvider(settings?: TSettings, variantSuffix?: string): Promise<TProvider> {
|
||||
if (variantSuffix) {
|
||||
const variant = this.getVariant(variantSuffix)
|
||||
if (!variant) {
|
||||
throw new Error(
|
||||
`ProviderExtension "${this.config.name}": variant "${variantSuffix}" not found. ` +
|
||||
`Available variants: ${this.config.variants?.map((v) => v.suffix).join(', ') || 'none'}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const mergedSettings = deepMergeObjects(this.config.defaultOptions || {}, settings || {}) as TSettings
|
||||
|
||||
const hash = this.computeHash(mergedSettings, variantSuffix)
|
||||
|
||||
const cachedInstance = this.instances.get(hash)
|
||||
if (cachedInstance) {
|
||||
return cachedInstance
|
||||
}
|
||||
|
||||
const pending = this.pendingCreations.get(hash)
|
||||
if (pending) {
|
||||
return pending
|
||||
}
|
||||
|
||||
const creationPromise = this._doCreateProvider(mergedSettings, variantSuffix, hash)
|
||||
this.pendingCreations.set(hash, creationPromise)
|
||||
|
||||
try {
|
||||
return await creationPromise
|
||||
} finally {
|
||||
this.pendingCreations.delete(hash)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取基础 provider 实例(无变体转换)
|
||||
* 用于访问 provider 的 .tools 属性
|
||||
*/
|
||||
async getBaseProvider(settings?: TSettings): Promise<TProvider> {
|
||||
return this.createProvider(settings)
|
||||
}
|
||||
|
||||
private async _doCreateProvider(
|
||||
mergedSettings: TSettings,
|
||||
variantSuffix: string | undefined,
|
||||
hash: string
|
||||
): Promise<TProvider> {
|
||||
let baseProvider: ProviderV3
|
||||
|
||||
if (this.config.create) {
|
||||
baseProvider = await Promise.resolve(this.config.create(mergedSettings))
|
||||
} else if (this.config.import && this.config.creatorFunctionName) {
|
||||
const module = await this.config.import()
|
||||
const creatorFn = module[this.config.creatorFunctionName]
|
||||
|
||||
if (!creatorFn || typeof creatorFn !== 'function') {
|
||||
throw new Error(
|
||||
`ProviderExtension "${this.config.name}": creatorFunctionName "${this.config.creatorFunctionName}" not found in imported module`
|
||||
)
|
||||
}
|
||||
|
||||
baseProvider = await Promise.resolve(creatorFn(mergedSettings))
|
||||
} else {
|
||||
throw new Error(`ProviderExtension "${this.config.name}": cannot create provider, invalid configuration`)
|
||||
}
|
||||
|
||||
let finalProvider: TProvider
|
||||
if (variantSuffix) {
|
||||
const variant = this.getVariant(variantSuffix)!
|
||||
if (variant.transform) {
|
||||
const baseHash = this.computeHash(mergedSettings)
|
||||
if (!this.instances.has(baseHash)) {
|
||||
this.instances.set(baseHash, baseProvider as TProvider)
|
||||
}
|
||||
finalProvider = (await Promise.resolve(
|
||||
variant.transform(baseProvider as TProvider, mergedSettings)
|
||||
)) as TProvider
|
||||
} else {
|
||||
finalProvider = baseProvider as TProvider
|
||||
}
|
||||
} else {
|
||||
finalProvider = baseProvider as TProvider
|
||||
}
|
||||
|
||||
this.instances.set(hash, finalProvider)
|
||||
|
||||
return finalProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 配置 provider(链式调用)
|
||||
* 返回一个新的 Extension 实例,不修改原实例
|
||||
*/
|
||||
configure(settings: Partial<TSettings>): ProviderExtension<TSettings, TProvider> {
|
||||
return new ProviderExtension({
|
||||
...this.config,
|
||||
defaultOptions: deepMergeObjects(this.config.defaultOptions || ({} as any), settings)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有 provider IDs(包含变体和别名)
|
||||
*/
|
||||
getProviderIds(): string[] {
|
||||
const ids = [this.config.name, ...(this.config.aliases || [])]
|
||||
|
||||
if (this.config.variants) {
|
||||
for (const variant of this.config.variants) {
|
||||
ids.push(`${this.config.name}-${variant.suffix}`)
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查给定 ID 是否属于此 Extension
|
||||
*/
|
||||
hasProviderId(id: string): boolean {
|
||||
return this.getProviderIds().includes(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取变体配置
|
||||
*/
|
||||
getVariant(suffix: string): ProviderVariant<TSettings, TProvider> | undefined {
|
||||
return this.config.variants?.find((v) => v.suffix === suffix)
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除所有缓存的 Provider 实例
|
||||
*/
|
||||
clearCache(): void {
|
||||
this.instances.clear()
|
||||
this.pendingCreations.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已缓存的 provider 实例(如果存在)
|
||||
*/
|
||||
getCachedProvider(): TProvider | undefined {
|
||||
for (const [key, value] of this.instances.entries()) {
|
||||
if (!key.includes(':')) return value
|
||||
}
|
||||
for (const [, value] of this.instances.entries()) {
|
||||
return value
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取缓存统计信息
|
||||
*/
|
||||
getCacheStats(): { cachedInstances: number } {
|
||||
return {
|
||||
cachedInstances: this.instances.size
|
||||
}
|
||||
}
|
||||
}
|
||||
312
packages/aiCore/src/core/providers/core/initialization.ts
Normal file
312
packages/aiCore/src/core/providers/core/initialization.ts
Normal file
@@ -0,0 +1,312 @@
|
||||
/**
|
||||
* Provider 初始化器
|
||||
* 负责根据配置创建 providers 并注册到全局管理器
|
||||
* 使用新的 Extension 系统
|
||||
*/
|
||||
|
||||
import type { AnthropicProvider, AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
import type { AzureOpenAIProvider, AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { createAzure } from '@ai-sdk/azure'
|
||||
import type { DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||
import type { GoogleGenerativeAIProvider, GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import type { OpenAIProvider, OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import type { OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { XaiProvider, XaiProviderSettings } from '@ai-sdk/xai'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import type { CherryInProvider, CherryInProviderSettings } from '@cherrystudio/ai-sdk-provider'
|
||||
import { createCherryIn } from '@cherrystudio/ai-sdk-provider'
|
||||
import type { OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
|
||||
import { createOpenRouter } from '@openrouter/ai-sdk-provider'
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import type { OpenRouterSearchConfig } from '../../plugins/built-in/webSearchPlugin'
|
||||
import type { ExtensionConfigToIdResolutionMap, ExtractExtensionIds, UnionToIntersection } from '../types'
|
||||
import { extensionRegistry } from './ExtensionRegistry'
|
||||
import type { ProviderExtensionConfig } from './ProviderExtension'
|
||||
import { ProviderExtension } from './ProviderExtension'
|
||||
|
||||
// ==================== Core Extensions ====================
|
||||
|
||||
const AnthropicExtension = ProviderExtension.create({
|
||||
name: 'anthropic',
|
||||
aliases: ['claude'] as const,
|
||||
supportsImageGeneration: false,
|
||||
create: createAnthropic,
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider) => (config: NonNullable<Parameters<AnthropicProvider['tools']['webSearch_20250305']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webSearch_20250305(config) }
|
||||
}),
|
||||
urlContext:
|
||||
(provider) => (config: NonNullable<Parameters<AnthropicProvider['tools']['webFetch_20260209']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webFetch_20260209(config) }
|
||||
})
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<AnthropicProviderSettings, AnthropicProvider, 'anthropic'>)
|
||||
|
||||
/**
|
||||
* Azure Extension
|
||||
*/
|
||||
const AzureExtension = ProviderExtension.create({
|
||||
name: 'azure',
|
||||
aliases: ['azure-openai'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: (settings) => {
|
||||
const provider = createAzure(settings)
|
||||
// Default to chat mode (AI SDK defaults to responses API)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: AzureOpenAIProvider) =>
|
||||
(config: NonNullable<Parameters<AzureOpenAIProvider['tools']['webSearchPreview']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webSearchPreview(config) }
|
||||
})
|
||||
},
|
||||
variants: [
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'Azure OpenAI Responses',
|
||||
transform: (provider) =>
|
||||
customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
// Azure 上的 Claude 模型走 Anthropic SDK
|
||||
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
|
||||
{
|
||||
suffix: 'anthropic',
|
||||
name: 'Azure Anthropic',
|
||||
transform: (_provider, settings) =>
|
||||
createAnthropic({
|
||||
baseURL: (settings?.baseURL ?? '') + '/anthropic/v1',
|
||||
apiKey: settings?.apiKey ?? '',
|
||||
headers: settings?.headers
|
||||
})
|
||||
}
|
||||
] as const
|
||||
} as const satisfies ProviderExtensionConfig<AzureOpenAIProviderSettings, AzureOpenAIProvider, 'azure'>)
|
||||
|
||||
const CherryInExtension = ProviderExtension.create({
|
||||
name: 'cherryin',
|
||||
supportsImageGeneration: true,
|
||||
create: createCherryIn,
|
||||
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'CherryIN Chat',
|
||||
transform: (provider) =>
|
||||
customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
}
|
||||
] as const
|
||||
} as const satisfies ProviderExtensionConfig<CherryInProviderSettings, CherryInProvider, 'cherryin'>)
|
||||
|
||||
const DeepSeekExtension = ProviderExtension.create({
|
||||
name: 'deepseek',
|
||||
supportsImageGeneration: false,
|
||||
create: createDeepSeek
|
||||
} as const satisfies ProviderExtensionConfig<DeepSeekProviderSettings, ProviderV3, 'deepseek'>)
|
||||
|
||||
const GoogleExtension = ProviderExtension.create({
|
||||
name: 'google',
|
||||
aliases: ['google-ai', 'gemini', 'google-gemini'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createGoogleGenerativeAI,
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: GoogleGenerativeAIProvider) =>
|
||||
(config: NonNullable<Parameters<GoogleGenerativeAIProvider['tools']['googleSearch']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.googleSearch(config) }
|
||||
}),
|
||||
urlContext: (provider) => (config) => ({
|
||||
tools: {
|
||||
urlContext: provider.tools.urlContext(config)
|
||||
}
|
||||
})
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<GoogleGenerativeAIProviderSettings, GoogleGenerativeAIProvider, 'google'>)
|
||||
|
||||
const OpenAICompatibleExtension = ProviderExtension.create({
|
||||
name: 'openai-compatible',
|
||||
supportsImageGeneration: true,
|
||||
create: (settings) => {
|
||||
if (!settings) {
|
||||
throw new Error('OpenAI Compatible provider requires settings')
|
||||
}
|
||||
return createOpenAICompatible(settings)
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<OpenAICompatibleProviderSettings, ProviderV3, 'openai-compatible'>)
|
||||
|
||||
const OpenAIExtension = ProviderExtension.create({
|
||||
name: 'openai',
|
||||
aliases: ['openai-response'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createOpenAI,
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: OpenAIProvider) => (config: NonNullable<Parameters<OpenAIProvider['tools']['webSearch']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webSearch(config) }
|
||||
})
|
||||
},
|
||||
|
||||
variants: [
|
||||
{
|
||||
suffix: 'chat',
|
||||
name: 'OpenAI Chat',
|
||||
resolveModel: (provider: OpenAIProvider, modelId: string) => provider.chat(modelId),
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: OpenAIProvider) =>
|
||||
(config: NonNullable<Parameters<OpenAIProvider['tools']['webSearchPreview']>[0]>) => ({
|
||||
tools: { webSearch: provider.tools.webSearchPreview(config) }
|
||||
})
|
||||
}
|
||||
}
|
||||
] as const
|
||||
} as const satisfies ProviderExtensionConfig<OpenAIProviderSettings, OpenAIProvider, 'openai'>)
|
||||
|
||||
const OpenRouterExtension = ProviderExtension.create({
|
||||
name: 'openrouter',
|
||||
// TODO: 实现注册后修改拓展配置
|
||||
aliases: ['tokenflux'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createOpenRouter,
|
||||
toolFactories: {
|
||||
webSearch: () => (config: OpenRouterSearchConfig) => ({
|
||||
providerOptions: { openrouter: config }
|
||||
})
|
||||
}
|
||||
} as const satisfies ProviderExtensionConfig<OpenRouterProviderSettings, ProviderV3, 'openrouter'>)
|
||||
|
||||
const XaiExtension = ProviderExtension.create({
|
||||
name: 'xai',
|
||||
aliases: ['grok'] as const,
|
||||
supportsImageGeneration: true,
|
||||
create: createXai,
|
||||
variants: [
|
||||
{
|
||||
suffix: 'responses',
|
||||
name: 'xAI Responses',
|
||||
resolveModel: (provider: XaiProvider, modelId: string) => provider.responses(modelId),
|
||||
toolFactories: {
|
||||
webSearch:
|
||||
(provider: XaiProvider) =>
|
||||
(config: {
|
||||
webSearch?: Parameters<XaiProvider['tools']['webSearch']>[0]
|
||||
xSearch?: Parameters<XaiProvider['tools']['xSearch']>[0]
|
||||
}) => ({
|
||||
tools: {
|
||||
webSearch: provider.tools.webSearch(config?.webSearch ?? {}),
|
||||
xSearch: provider.tools.xSearch(config?.xSearch ?? {})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
] as const
|
||||
} as const satisfies ProviderExtensionConfig<XaiProviderSettings, XaiProvider, 'xai'>)
|
||||
|
||||
/**
|
||||
* 核心 provider extensions 列表
|
||||
*/
|
||||
export const coreExtensions = [
|
||||
OpenAIExtension,
|
||||
AnthropicExtension,
|
||||
AzureExtension,
|
||||
GoogleExtension,
|
||||
XaiExtension,
|
||||
DeepSeekExtension,
|
||||
OpenRouterExtension,
|
||||
OpenAICompatibleExtension,
|
||||
CherryInExtension
|
||||
] as const
|
||||
|
||||
/**
|
||||
* 核心 Provider IDs 类型
|
||||
* 从 coreExtensions 数组自动提取所有 provider IDs(包括 aliases 和 variants)
|
||||
*
|
||||
*/
|
||||
export type CoreProviderId = ExtractExtensionIds<(typeof coreExtensions)[number]>
|
||||
|
||||
type ExtensionConfigs = (typeof coreExtensions)[number]['config']
|
||||
|
||||
type ProviderIdsMap = UnionToIntersection<ExtensionConfigToIdResolutionMap<ExtensionConfigs>>
|
||||
|
||||
export const registeredProviderIds: ProviderIdsMap = (() => {
|
||||
const map = {} as ProviderIdsMap
|
||||
coreExtensions.forEach((ext) => {
|
||||
const config = ext.config as ProviderExtensionConfig<any, any, CoreProviderId>
|
||||
const name = config.name
|
||||
;(map as Record<string, CoreProviderId>)[name] = name
|
||||
|
||||
if (config.aliases) {
|
||||
config.aliases.forEach((alias) => {
|
||||
;(map as Record<string, CoreProviderId>)[alias] = name
|
||||
})
|
||||
}
|
||||
|
||||
if (config.variants) {
|
||||
config.variants.forEach((variant) => {
|
||||
;(map as Record<string, CoreProviderId>)[`${name}-${variant.suffix}`] = name
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return map
|
||||
})()
|
||||
|
||||
// ==================== 初始化 Extension Registry ====================
|
||||
|
||||
/**
|
||||
* 注册所有通用 extensions 到全局 registry
|
||||
* 在模块加载时自动执行
|
||||
*
|
||||
* 注意:只注册通用的 provider extensions(OpenAI, Anthropic, Google 等)
|
||||
* 项目特定的 extensions 应该在应用层单独注册
|
||||
*/
|
||||
// register() is idempotent — safe to call on HMR / re-import
|
||||
extensionRegistry.registerAll(coreExtensions)
|
||||
|
||||
/**
|
||||
* Provider 初始化错误类型
|
||||
*/
|
||||
class ProviderInitializationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderInitializationError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有对应的 Provider Extension
|
||||
*/
|
||||
export function hasProviderConfig(providerId: string): boolean {
|
||||
return extensionRegistry.has(providerId)
|
||||
}
|
||||
|
||||
// ==================== 导出错误类型 ====================
|
||||
|
||||
export { ProviderInitializationError }
|
||||
@@ -1,3 +1,10 @@
|
||||
/**
|
||||
* Provider 工具函数和错误类
|
||||
* 合并自 utils.ts 和 errors.ts
|
||||
*/
|
||||
|
||||
// ==================== 私钥格式化工具 ====================
|
||||
|
||||
/**
|
||||
* 格式化私钥,确保它包含正确的PEM头部和尾部
|
||||
*/
|
||||
@@ -84,3 +91,20 @@ function reconstructPemKey(key: string): string {
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
|
||||
// ==================== 错误类 ====================
|
||||
|
||||
/**
|
||||
* Provider 创建错误
|
||||
* 当创建 provider 实例失败时抛出
|
||||
*/
|
||||
export class ProviderCreationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId: string,
|
||||
public cause: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderCreationError'
|
||||
}
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
/**
|
||||
* AI Provider 配置工厂
|
||||
* 提供类型安全的 Provider 配置构建器
|
||||
*/
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from './types'
|
||||
|
||||
/**
|
||||
* 通用配置基础类型,包含所有 Provider 共有的属性
|
||||
*/
|
||||
export interface BaseProviderConfig {
|
||||
apiKey?: string
|
||||
baseURL?: string
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
fetch?: typeof globalThis.fetch
|
||||
}
|
||||
|
||||
/**
|
||||
* 完整的配置类型,结合基础配置、AI SDK 配置和特定 Provider 配置
|
||||
*/
|
||||
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
|
||||
|
||||
type ConfigHandler<T extends ProviderId> = (
|
||||
builder: ProviderConfigBuilder<T>,
|
||||
provider: CompleteProviderConfig<T>
|
||||
) => void
|
||||
|
||||
const configHandlers: {
|
||||
[K in ProviderId]?: ConfigHandler<K>
|
||||
} = {
|
||||
azure: (builder, provider) => {
|
||||
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
|
||||
const azureProvider = provider as CompleteProviderConfig<'azure'>
|
||||
azureBuilder.withAzureConfig({
|
||||
apiVersion: azureProvider.apiVersion,
|
||||
resourceName: azureProvider.resourceName
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
|
||||
|
||||
constructor(private providerId: T) {}
|
||||
|
||||
/**
|
||||
* 设置 API Key
|
||||
*/
|
||||
withApiKey(apiKey: string): this
|
||||
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
|
||||
withApiKey(apiKey: string, options?: any): this {
|
||||
this.config.apiKey = apiKey
|
||||
|
||||
// 类型安全的 OpenAI 特定配置
|
||||
if (this.providerId === 'openai' && options) {
|
||||
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
|
||||
if (options.organization) {
|
||||
openaiConfig.organization = options.organization
|
||||
}
|
||||
if (options.project) {
|
||||
openaiConfig.project = options.project
|
||||
}
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置基础 URL
|
||||
*/
|
||||
withBaseURL(baseURL: string) {
|
||||
this.config.baseURL = baseURL
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置请求配置
|
||||
*/
|
||||
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
|
||||
if (options.headers) {
|
||||
this.config.headers = { ...this.config.headers, ...options.headers }
|
||||
}
|
||||
if (options.fetch) {
|
||||
this.config.fetch = options.fetch
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Azure OpenAI 特定配置
|
||||
*/
|
||||
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
|
||||
withAzureConfig(options: any): any {
|
||||
if (this.providerId === 'azure') {
|
||||
const azureConfig = this.config as CompleteProviderConfig<'azure'>
|
||||
if (options.apiVersion) {
|
||||
azureConfig.apiVersion = options.apiVersion
|
||||
}
|
||||
if (options.resourceName) {
|
||||
azureConfig.resourceName = options.resourceName
|
||||
}
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置自定义参数
|
||||
*/
|
||||
withCustomParams(params: Record<string, any>) {
|
||||
Object.assign(this.config, params)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终配置
|
||||
*/
|
||||
build(): ProviderSettingsMap[T] {
|
||||
return this.config as ProviderSettingsMap[T]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider 配置工厂
|
||||
* 提供便捷的配置创建方法
|
||||
*/
|
||||
export class ProviderConfigFactory {
|
||||
/**
|
||||
* 创建配置构建器
|
||||
*/
|
||||
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
|
||||
return new ProviderConfigBuilder(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
|
||||
*/
|
||||
static fromProvider<T extends ProviderId>(
|
||||
providerId: T,
|
||||
provider: CompleteProviderConfig<T>,
|
||||
options?: {
|
||||
headers?: Record<string, string>
|
||||
[key: string]: any
|
||||
}
|
||||
): ProviderSettingsMap[T] {
|
||||
const builder = new ProviderConfigBuilder<T>(providerId)
|
||||
|
||||
// 设置基本配置
|
||||
if (provider.apiKey) {
|
||||
builder.withApiKey(provider.apiKey)
|
||||
}
|
||||
|
||||
if (provider.baseURL) {
|
||||
builder.withBaseURL(provider.baseURL)
|
||||
}
|
||||
|
||||
// 设置请求配置
|
||||
if (options?.headers) {
|
||||
builder.withRequestConfig({
|
||||
headers: options.headers
|
||||
})
|
||||
}
|
||||
|
||||
// 使用配置处理器模式 - 更加优雅和可扩展
|
||||
const handler = configHandlers[providerId]
|
||||
if (handler) {
|
||||
handler(builder, provider)
|
||||
}
|
||||
|
||||
// 添加其他自定义参数
|
||||
if (options) {
|
||||
const customOptions = { ...options }
|
||||
delete customOptions.headers // 已经处理过了
|
||||
if (Object.keys(customOptions).length > 0) {
|
||||
builder.withCustomParams(customOptions)
|
||||
}
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 OpenAI 配置
|
||||
*/
|
||||
static createOpenAI(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
organization?: string
|
||||
project?: string
|
||||
}
|
||||
) {
|
||||
const builder = this.builder('openai')
|
||||
|
||||
// 使用类型安全的重载
|
||||
if (options?.organization || options?.project) {
|
||||
builder.withApiKey(apiKey, {
|
||||
organization: options.organization,
|
||||
project: options.project
|
||||
})
|
||||
} else {
|
||||
builder.withApiKey(apiKey)
|
||||
}
|
||||
|
||||
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Anthropic 配置
|
||||
*/
|
||||
static createAnthropic(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('anthropic')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Azure OpenAI 配置
|
||||
*/
|
||||
static createAzureOpenAI(
|
||||
apiKey: string,
|
||||
options: {
|
||||
baseURL: string
|
||||
apiVersion?: string
|
||||
resourceName?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('azure')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options.baseURL)
|
||||
.withAzureConfig({
|
||||
apiVersion: options.apiVersion,
|
||||
resourceName: options.resourceName
|
||||
})
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Google 配置
|
||||
*/
|
||||
static createGoogle(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
projectId?: string
|
||||
location?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('google')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Vertex AI 配置
|
||||
*/
|
||||
static createVertexAI() {
|
||||
// credentials: {
|
||||
// clientEmail: string
|
||||
// privateKey: string
|
||||
// },
|
||||
// options?: {
|
||||
// project?: string
|
||||
// location?: string
|
||||
// }
|
||||
// return this.builder('google-vertex')
|
||||
// .withGoogleCredentials(credentials)
|
||||
// .withGoogleVertexConfig({
|
||||
// project: options?.project,
|
||||
// location: options?.location
|
||||
// })
|
||||
// .build()
|
||||
}
|
||||
|
||||
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
||||
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷的配置创建函数
|
||||
*/
|
||||
export const createProviderConfig = ProviderConfigFactory.fromProvider
|
||||
export const providerConfigBuilder = ProviderConfigFactory.builder
|
||||
@@ -4,81 +4,50 @@
|
||||
|
||||
// ==================== 核心管理器 ====================
|
||||
|
||||
// Provider 注册表管理器
|
||||
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
||||
|
||||
// Provider 核心功能
|
||||
export {
|
||||
// 状态管理
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getImageModel,
|
||||
// 工具函数
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
getTextEmbeddingModel,
|
||||
hasInitializedProviders,
|
||||
// 工具函数
|
||||
hasProviderConfig,
|
||||
// 别名支持
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
// 错误类型
|
||||
ProviderInitializationError,
|
||||
// 全局访问
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
// 统一Provider系统
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from './registry'
|
||||
export { coreExtensions, hasProviderConfig } from './core/initialization'
|
||||
|
||||
// ==================== 基础数据和类型 ====================
|
||||
|
||||
// 基础Provider数据源
|
||||
export { baseProviderIds, baseProviders, isBaseProvider } from './schemas'
|
||||
// 类型定义
|
||||
export type { AiSdkModel, ProviderError } from './types'
|
||||
|
||||
// 类型定义和Schema
|
||||
// 类型提取工具
|
||||
export type {
|
||||
BaseProviderId,
|
||||
CustomProviderId,
|
||||
DynamicProviderRegistration,
|
||||
ProviderConfig,
|
||||
ProviderId
|
||||
} from './schemas' // 从 schemas 导出的类型
|
||||
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||
export type {
|
||||
AiSdkModel,
|
||||
DynamicProviderRegistry,
|
||||
ExtensibleProviderSettingsMap,
|
||||
ProviderError,
|
||||
ProviderSettingsMap,
|
||||
ProviderTypeRegistrar
|
||||
CoreProviderSettingsMap,
|
||||
ExtensionConfigToIdResolutionMap,
|
||||
ExtensionToSettingsMap,
|
||||
ExtractProviderIds,
|
||||
StringKeys,
|
||||
UnionToIntersection
|
||||
} from './types'
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
// Provider配置工厂
|
||||
// 工具函数和错误类
|
||||
export { formatPrivateKey, ProviderCreationError } from './core/utils'
|
||||
|
||||
// ==================== Provider Extension 系统 ====================
|
||||
|
||||
// Extension 核心类和类型
|
||||
export {
|
||||
type BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './factory'
|
||||
type ProviderCreatorFunction,
|
||||
ProviderExtension,
|
||||
type ProviderExtensionConfig,
|
||||
type ProviderModule
|
||||
} from './core/ProviderExtension'
|
||||
|
||||
// 工具函数
|
||||
export { formatPrivateKey } from './utils'
|
||||
|
||||
// ==================== 扩展功能 ====================
|
||||
|
||||
// Hub Provider 功能
|
||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||
// Extension Registry
|
||||
export { ExtensionRegistry, extensionRegistry } from './core/ExtensionRegistry'
|
||||
export type { ProviderVariant } from './types'
|
||||
export type {
|
||||
ExtractToolConfig,
|
||||
ExtractToolConfigMap,
|
||||
ProviderId,
|
||||
RegisteredProviderId,
|
||||
ToolCapability,
|
||||
ToolFactory,
|
||||
ToolFactoryMap,
|
||||
ToolFactoryPatch,
|
||||
WebSearchToolConfigMap
|
||||
} from './types'
|
||||
|
||||
@@ -1,314 +0,0 @@
|
||||
/**
|
||||
* Provider 初始化器
|
||||
* 负责根据配置创建 providers 并注册到全局管理器
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import { baseProviders, type ProviderConfig } from './schemas'
|
||||
|
||||
/**
|
||||
* Provider 初始化错误类型
|
||||
*/
|
||||
class ProviderInitializationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderInitializationError'
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 全局管理器导出 ====================
|
||||
|
||||
export { globalRegistryManagement as providerRegistry }
|
||||
|
||||
// ==================== 便捷访问方法 ====================
|
||||
|
||||
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
|
||||
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.embeddingModel(id as any)
|
||||
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
/**
|
||||
* 获取支持的 Providers 列表
|
||||
*/
|
||||
export function getSupportedProviders(): Array<{
|
||||
id: string
|
||||
name: string
|
||||
}> {
|
||||
return baseProviders.map((provider) => ({
|
||||
id: provider.id,
|
||||
name: provider.name
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已初始化的 providers
|
||||
*/
|
||||
export function getInitializedProviders(): string[] {
|
||||
return globalRegistryManagement.getRegisteredProviders()
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有任何已初始化的 providers
|
||||
*/
|
||||
export function hasInitializedProviders(): boolean {
|
||||
return globalRegistryManagement.hasProviders()
|
||||
}
|
||||
|
||||
// ==================== 统一Provider配置系统 ====================
|
||||
|
||||
// 全局Provider配置存储
|
||||
const providerConfigs = new Map<string, ProviderConfig>()
|
||||
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
|
||||
const providerConfigAliases = new Map<string, string>() // alias -> realId
|
||||
|
||||
/**
|
||||
* 初始化内置配置 - 将baseProviders转换为统一格式
|
||||
*/
|
||||
function initializeBuiltInConfigs(): void {
|
||||
baseProviders.forEach((provider) => {
|
||||
const config: ProviderConfig = {
|
||||
id: provider.id,
|
||||
name: provider.name,
|
||||
creator: provider.creator as any, // 类型转换以兼容多种creator签名
|
||||
supportsImageGeneration: provider.supportsImageGeneration || false
|
||||
}
|
||||
providerConfigs.set(provider.id, config)
|
||||
})
|
||||
}
|
||||
|
||||
// 启动时自动注册内置配置
|
||||
initializeBuiltInConfigs()
|
||||
|
||||
/**
|
||||
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||
*/
|
||||
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config || !config.id || !config.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否与已有配置冲突(包括内置配置)
|
||||
if (providerConfigs.has(config.id)) {
|
||||
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||
}
|
||||
|
||||
// 存储配置(内置和用户配置统一处理)
|
||||
providerConfigs.set(config.id, config)
|
||||
|
||||
// 处理别名
|
||||
if (config.aliases && config.aliases.length > 0) {
|
||||
config.aliases.forEach((alias) => {
|
||||
if (providerConfigAliases.has(alias)) {
|
||||
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||
}
|
||||
providerConfigAliases.set(alias, config.id)
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register ProviderConfig:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||
*/
|
||||
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||
// 支持通过别名查找配置
|
||||
const config = getProviderConfigByAlias(providerId)
|
||||
|
||||
if (!config) {
|
||||
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
let creator: (options: any) => any
|
||||
|
||||
if (config.creator) {
|
||||
// 方式1: 直接执行 creator
|
||||
creator = config.creator
|
||||
} else if (config.import && config.creatorFunctionName) {
|
||||
// 方式2: 动态导入并执行
|
||||
const module = await config.import()
|
||||
creator = (module as any)[config.creatorFunctionName]
|
||||
|
||||
if (!creator || typeof creator !== 'function') {
|
||||
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('No valid creator method provided in ProviderConfig')
|
||||
}
|
||||
|
||||
// 使用真实配置创建provider实例
|
||||
return creator(options)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create provider "${providerId}":`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤3: 注册Provider到全局管理器
|
||||
*/
|
||||
export function registerProvider(providerId: string, provider: any): boolean {
|
||||
try {
|
||||
const config = providerConfigs.get(providerId)
|
||||
if (!config) {
|
||||
console.error(`ProviderConfig not found for id: ${providerId}`)
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取aliases配置
|
||||
const aliases = config.aliases
|
||||
|
||||
// 处理特殊provider逻辑
|
||||
if (providerId === 'openai') {
|
||||
// 注册默认 openai
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
|
||||
// 创建并注册 openai-chat 变体
|
||||
const openaiChatProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
|
||||
} else if (providerId === 'azure') {
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
// Keep the resolver's azure-chat fallback aligned with the local provider creator.
|
||||
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider)
|
||||
} else {
|
||||
// 其他provider直接注册
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷函数: 一次性完成创建+注册
|
||||
*/
|
||||
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||
try {
|
||||
// 步骤2: 创建provider
|
||||
const provider = await createProvider(providerId, options)
|
||||
|
||||
// 步骤3: 注册到全局管理器
|
||||
return registerProvider(providerId, provider)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册Provider配置
|
||||
*/
|
||||
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (registerProviderConfig(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
return successCount
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfig(providerId: string): boolean {
|
||||
return providerConfigs.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||
const realId = resolveProviderConfigId(aliasOrId)
|
||||
return providerConfigs.has(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有Provider配置
|
||||
*/
|
||||
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||
return Array.from(providerConfigs.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||
return providerConfigs.get(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||
// 先检查是否为别名,如果是则解析为真实ID
|
||||
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
return providerConfigs.get(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的ProviderConfig ID(去别名化)
|
||||
*/
|
||||
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为ProviderConfig别名
|
||||
*/
|
||||
export function isProviderConfigAlias(id: string): boolean {
|
||||
return providerConfigAliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有ProviderConfig别名映射关系
|
||||
*/
|
||||
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
providerConfigAliases.forEach((realId, alias) => {
|
||||
result[alias] = realId
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有Provider配置和已注册的providers
|
||||
*/
|
||||
export function cleanup(): void {
|
||||
providerConfigs.clear()
|
||||
providerConfigAliases.clear() // 清理别名映射
|
||||
globalRegistryManagement.clear()
|
||||
// 重新初始化内置配置
|
||||
initializeBuiltInConfigs()
|
||||
}
|
||||
|
||||
export function clearAllProviders(): void {
|
||||
globalRegistryManagement.clear()
|
||||
}
|
||||
|
||||
// ==================== 导出错误类型 ====================
|
||||
|
||||
export { ProviderInitializationError }
|
||||
@@ -1,237 +0,0 @@
|
||||
/**
|
||||
* Provider Config 定义
|
||||
*/
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
import { createAzure } from '@ai-sdk/azure'
|
||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import { type CherryInProviderSettings, createCherryIn } from '@cherrystudio/ai-sdk-provider'
|
||||
import { createOpenRouter, type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
|
||||
import { customProvider, wrapProvider } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* 基础 Provider IDs
|
||||
*/
|
||||
export const baseProviderIds = [
|
||||
'openai',
|
||||
'openai-chat',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'azure-responses',
|
||||
'deepseek',
|
||||
'openrouter',
|
||||
'cherryin',
|
||||
'cherryin-chat'
|
||||
] as const
|
||||
|
||||
/**
|
||||
* 基础 Provider ID Schema
|
||||
*/
|
||||
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||
|
||||
/**
|
||||
* 基础 Provider ID
|
||||
*/
|
||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
|
||||
export const isBaseProvider = (id: ProviderId): id is BaseProviderId => {
|
||||
return baseProviderIdSchema.safeParse(id).success
|
||||
}
|
||||
|
||||
type BaseProvider = {
|
||||
id: BaseProviderId
|
||||
name: string
|
||||
creator: (options: any) => ProviderV3
|
||||
supportsImageGeneration: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 基础 Providers 定义
|
||||
* 作为唯一数据源,避免重复维护
|
||||
*/
|
||||
export const baseProviders = [
|
||||
{
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
creator: createOpenAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-chat',
|
||||
name: 'OpenAI Chat',
|
||||
creator: (options: OpenAIProviderSettings) => {
|
||||
const provider = createOpenAI(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-compatible',
|
||||
name: 'OpenAI Compatible',
|
||||
creator: createOpenAICompatible,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
creator: createAnthropic,
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'google',
|
||||
name: 'Google Generative AI',
|
||||
creator: createGoogleGenerativeAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'xai',
|
||||
name: 'xAI (Grok)',
|
||||
creator: (options: Parameters<typeof createXai>[0]) => {
|
||||
const provider = createXai(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure',
|
||||
name: 'Azure OpenAI',
|
||||
creator: (options: AzureOpenAIProviderSettings) => {
|
||||
const provider = createAzure(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
// Cherry's "azure" path is the chat/deployment-based variant.
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure-responses',
|
||||
name: 'Azure OpenAI Responses',
|
||||
creator: (options: AzureOpenAIProviderSettings) => {
|
||||
const provider = createAzure(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
creator: createDeepSeek,
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
creator: (options?: OpenRouterProviderSettings) => {
|
||||
const provider = createOpenRouter(options)
|
||||
return wrapProvider({ provider, languageModelMiddleware: [] })
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'cherryin',
|
||||
name: 'CherryIN',
|
||||
creator: createCherryIn,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'cherryin-chat',
|
||||
name: 'CherryIN Chat',
|
||||
creator: (options: CherryInProviderSettings) => {
|
||||
const provider = createCherryIn(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
] as const satisfies BaseProvider[]
|
||||
|
||||
/**
|
||||
* 用户自定义 Provider ID Schema
|
||||
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
||||
*/
|
||||
export const customProviderIdSchema = z
|
||||
.string()
|
||||
.min(1)
|
||||
.refine((id) => !baseProviderIds.includes(id as any), {
|
||||
message: 'Custom provider ID cannot conflict with base provider IDs'
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider ID Schema - 支持基础和自定义
|
||||
*/
|
||||
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
|
||||
|
||||
/**
|
||||
* Provider 配置 Schema
|
||||
* 用于Provider的配置验证
|
||||
*/
|
||||
export const providerConfigSchema = z
|
||||
.object({
|
||||
id: customProviderIdSchema, // 只允许自定义ID
|
||||
name: z.string().min(1),
|
||||
creator: z
|
||||
.function({
|
||||
input: z.any(),
|
||||
output: z.any()
|
||||
})
|
||||
.optional(),
|
||||
import: z.function().optional(),
|
||||
creatorFunctionName: z.string().optional(),
|
||||
supportsImageGeneration: z.boolean().default(false),
|
||||
imageCreator: z.function().optional(),
|
||||
validateOptions: z.function().optional(),
|
||||
aliases: z.array(z.string()).optional()
|
||||
})
|
||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||
message: 'Must provide either creator function or import configuration'
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider ID 类型 - 基于 zod schema 推导
|
||||
*/
|
||||
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
|
||||
|
||||
/**
|
||||
* Provider 配置类型
|
||||
*/
|
||||
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
||||
|
||||
/**
|
||||
* 兼容性类型别名
|
||||
* @deprecated 使用 ProviderConfig 替代
|
||||
*/
|
||||
export type DynamicProviderRegistration = ProviderConfig
|
||||
@@ -1,101 +0,0 @@
|
||||
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
|
||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||
import type {
|
||||
EmbeddingModel,
|
||||
EmbeddingModelUsage,
|
||||
ImageModel,
|
||||
ImageModelUsage,
|
||||
LanguageModel,
|
||||
LanguageModelUsage,
|
||||
SpeechModel,
|
||||
TranscriptionModel
|
||||
} from 'ai'
|
||||
|
||||
// 导入基于 Zod 的 ProviderId 类型
|
||||
import { type ProviderId as ZodProviderId } from './schemas'
|
||||
|
||||
export interface ExtensibleProviderSettingsMap {
|
||||
// 基础的静态providers
|
||||
openai: OpenAIProviderSettings
|
||||
'openai-responses': OpenAIProviderSettings
|
||||
'openai-compatible': OpenAICompatibleProviderSettings
|
||||
anthropic: AnthropicProviderSettings
|
||||
google: GoogleGenerativeAIProviderSettings
|
||||
xai: XaiProviderSettings
|
||||
azure: AzureOpenAIProviderSettings
|
||||
deepseek: DeepSeekProviderSettings
|
||||
}
|
||||
|
||||
// 动态扩展的provider类型注册表
|
||||
export interface DynamicProviderRegistry {
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
// 合并基础和动态provider类型
|
||||
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||
|
||||
// 错误类型
|
||||
export class ProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId: string,
|
||||
public code?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||
export type ProviderId = ZodProviderId
|
||||
|
||||
export interface ProviderTypeRegistrar {
|
||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||
getProviderSettings<T extends string>(providerId: T): any
|
||||
}
|
||||
|
||||
// 重新导出所有类型供外部使用
|
||||
export type {
|
||||
AnthropicProviderSettings,
|
||||
AzureOpenAIProviderSettings,
|
||||
DeepSeekProviderSettings,
|
||||
GoogleGenerativeAIProviderSettings,
|
||||
OpenAICompatibleProviderSettings,
|
||||
OpenAIProviderSettings,
|
||||
XaiProviderSettings
|
||||
}
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel
|
||||
export type AiSdkProvider = ProviderV2 | ProviderV3
|
||||
export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage
|
||||
|
||||
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
|
||||
|
||||
export const METHOD_MAP = {
|
||||
text: 'languageModel',
|
||||
image: 'imageModel',
|
||||
embedding: 'embeddingModel',
|
||||
transcription: 'transcriptionModel',
|
||||
speech: 'speechModel'
|
||||
} as const satisfies Record<AiSdkModelType, keyof ProviderV3>
|
||||
|
||||
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV3>
|
||||
|
||||
export type AiSdkModelReturnMap = {
|
||||
text: LanguageModel
|
||||
image: ImageModel
|
||||
embedding: EmbeddingModel
|
||||
transcription: TranscriptionModel
|
||||
speech: SpeechModel
|
||||
}
|
||||
|
||||
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
|
||||
|
||||
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
|
||||
222
packages/aiCore/src/core/providers/types/index.ts
Normal file
222
packages/aiCore/src/core/providers/types/index.ts
Normal file
@@ -0,0 +1,222 @@
|
||||
import type { ProviderV2, ProviderV3 } from '@ai-sdk/provider'
|
||||
import type {
|
||||
EmbeddingModel,
|
||||
EmbeddingModelUsage,
|
||||
ImageModel,
|
||||
ImageModelUsage,
|
||||
LanguageModel,
|
||||
LanguageModelUsage,
|
||||
SpeechModel,
|
||||
TranscriptionModel
|
||||
} from 'ai'
|
||||
|
||||
import type { coreExtensions } from '../core/initialization'
|
||||
import type { ProviderExtension } from '../core/ProviderExtension'
|
||||
import type { ToolFactoryMap } from './toolFactory'
|
||||
|
||||
// ============================================================================
|
||||
// Type Utilities
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* 提取对象类型中的字符串键
|
||||
* @example StringKeys<{ foo: 1, 0: 2 }> = 'foo'
|
||||
*/
|
||||
export type StringKeys<T> = Extract<keyof T, string>
|
||||
|
||||
/** 从 coreExtensions 自动提取的 Provider ID literal union */
|
||||
export type RegisteredProviderId = StringKeys<CoreProviderSettingsMap>
|
||||
|
||||
/** 允许已注册 ID(有自动补全)和任意字符串(动态 provider) */
|
||||
export type ProviderId = RegisteredProviderId | (string & {})
|
||||
|
||||
// 错误类型
|
||||
export class ProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId: string,
|
||||
public code?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel | TranscriptionModel | SpeechModel
|
||||
export type AiSdkProvider = ProviderV2 | ProviderV3
|
||||
export type AiSdkUsage = LanguageModelUsage | ImageModelUsage | EmbeddingModelUsage
|
||||
|
||||
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
|
||||
|
||||
const METHOD_MAP = {
|
||||
text: 'languageModel',
|
||||
image: 'imageModel',
|
||||
embedding: 'embeddingModel',
|
||||
transcription: 'transcriptionModel',
|
||||
speech: 'speechModel'
|
||||
} as const satisfies Record<AiSdkModelType, keyof ProviderV3>
|
||||
|
||||
type AiSdkModelReturnMap = {
|
||||
text: LanguageModel
|
||||
image: ImageModel
|
||||
embedding: EmbeddingModel
|
||||
transcription: TranscriptionModel
|
||||
speech: SpeechModel
|
||||
}
|
||||
|
||||
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
|
||||
|
||||
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
|
||||
|
||||
// ============================================================================
|
||||
// Provider Extension 类型定义
|
||||
// ============================================================================
|
||||
|
||||
/** Provider 变体配置 */
|
||||
export interface ProviderVariant<TSettings = any, TProvider extends ProviderV3 = ProviderV3> {
|
||||
suffix: string
|
||||
name: string
|
||||
|
||||
/** 类型安全的模型解析:provider.responses(modelId) / provider.chat(modelId) */
|
||||
resolveModel?: (provider: TProvider, modelId: string) => LanguageModel
|
||||
|
||||
/** 替换整个 provider(如 azure-anthropic),简单方法切换用 resolveModel */
|
||||
transform?: (baseProvider: TProvider, settings?: TSettings) => ProviderV3
|
||||
|
||||
toolFactories?: ToolFactoryMap<TProvider>
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider ID Type Extraction Utilities
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract all Provider IDs from an extension config
|
||||
* 保留字面量类型,避免被推断为 string
|
||||
*/
|
||||
export type ExtractProviderIds<TConfig> = TConfig extends { name: infer TName }
|
||||
? TName extends string
|
||||
?
|
||||
| TName
|
||||
| (TConfig extends { aliases: infer TAliases }
|
||||
? TAliases extends readonly string[]
|
||||
? TAliases[number]
|
||||
: never
|
||||
: never)
|
||||
| (TConfig extends { variants: infer TVariants }
|
||||
? TVariants extends readonly any[]
|
||||
? TVariants[number] extends { suffix: infer TSuffix }
|
||||
? TSuffix extends string
|
||||
? `${TName}-${TSuffix}`
|
||||
: never
|
||||
: never
|
||||
: never
|
||||
: never)
|
||||
: never
|
||||
: never
|
||||
|
||||
/**
|
||||
* Extract Provider IDs from a ProviderExtension instance
|
||||
*/
|
||||
export type ExtractExtensionIds<T> = T extends { config: infer TConfig } ? ExtractProviderIds<TConfig> : never
|
||||
|
||||
/**
|
||||
* Extract Settings type from a ProviderExtension instance
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* type Settings = ExtractExtensionSettings<typeof OpenAIExtension>
|
||||
* // => OpenAIProviderSettings
|
||||
* ```
|
||||
*/
|
||||
export type ExtractExtensionSettings<T> = T extends ProviderExtension<infer TSettings, any, any> ? TSettings : never
|
||||
|
||||
/**
|
||||
* Map all Provider IDs from an Extension to its Settings type
|
||||
*/
|
||||
export type ExtensionToSettingsMap<T> = T extends ProviderExtension<infer TSettings, any, infer TConfig>
|
||||
? { [K in ExtractProviderIds<TConfig>]: TSettings }
|
||||
: never
|
||||
|
||||
// ============================================================================
|
||||
// Provider Settings Map - Auto-extracted from Extensions
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Core Provider Settings Map
|
||||
*/
|
||||
export type CoreProviderSettingsMap = UnionToIntersection<ExtensionToSettingsMap<(typeof coreExtensions)[number]>>
|
||||
|
||||
// 辅助类型:提取所有变体 ID
|
||||
type ExtractVariantIds<TConfig, TName extends string> = TConfig extends {
|
||||
variants: readonly { suffix: infer TSuffix extends string }[]
|
||||
}
|
||||
? `${TName}-${TSuffix}`
|
||||
: never
|
||||
|
||||
export type ExtensionConfigToIdResolutionMap<TConfig> = TConfig extends { name: infer TName extends string }
|
||||
? {
|
||||
readonly [K in
|
||||
| TName
|
||||
| (TConfig extends { aliases: readonly (infer TAlias extends string)[] } ? TAlias : never)
|
||||
| ExtractVariantIds<TConfig, TName>]: K extends ExtractVariantIds<TConfig, TName>
|
||||
? K // 变体 → 自身
|
||||
: TName // 基础名和别名 → TName
|
||||
}
|
||||
: never
|
||||
|
||||
/**
|
||||
* Provider IDs Map Type with Literal Type Inference
|
||||
*/
|
||||
export type UnionToIntersection<U> = (U extends any ? (x: U) => void : never) extends (x: infer I) => void ? I : never
|
||||
|
||||
export type { ToolCapability, ToolFactory, ToolFactoryMap, ToolFactoryPatch } from './toolFactory'
|
||||
|
||||
// ============================================================================
|
||||
// Tool Config Type Extraction (from extension declarations via as const)
|
||||
// ============================================================================
|
||||
|
||||
/** Extract a capability's config type from an extension's toolFactories */
|
||||
export type ExtractToolConfig<TExt, K extends string> = TExt extends {
|
||||
config: { toolFactories?: { [P in K]?: (provider: any) => (config: infer C) => any } }
|
||||
}
|
||||
? C
|
||||
: never
|
||||
|
||||
/** Extract config from variant-level toolFactories (e.g., openai-chat) */
|
||||
type ExtractVariantToolConfig<TExt, K extends string> = TExt extends {
|
||||
config: {
|
||||
name: infer TName extends string
|
||||
variants?: readonly (infer V)[]
|
||||
}
|
||||
}
|
||||
? V extends {
|
||||
suffix: infer TSuffix extends string
|
||||
toolFactories?: { [P in K]?: (provider: any) => (config: infer C) => any }
|
||||
}
|
||||
? { id: `${TName}-${TSuffix}`; config: C }
|
||||
: never
|
||||
: never
|
||||
|
||||
/** Extract { [providerId]: ConfigType } map from all extensions for a capability */
|
||||
export type ExtractToolConfigMap<TExtUnion, K extends string> = UnionToIntersection<
|
||||
| (TExtUnion extends any
|
||||
? ExtractToolConfig<TExtUnion, K> extends never
|
||||
? never
|
||||
: TExtUnion extends { config: { name: infer TName extends string } }
|
||||
? { [P in TName]?: ExtractToolConfig<TExtUnion, K> }
|
||||
: never
|
||||
: never)
|
||||
// Variant configs: name-suffix → config
|
||||
| (TExtUnion extends any
|
||||
? ExtractVariantToolConfig<TExtUnion, K> extends never
|
||||
? never
|
||||
: ExtractVariantToolConfig<TExtUnion, K> extends { id: infer TId extends string; config: infer C }
|
||||
? { [P in TId]?: C }
|
||||
: never
|
||||
: never)
|
||||
>
|
||||
|
||||
/** Auto-extracted from coreExtensions' toolFactories.webSearch declarations */
|
||||
export type WebSearchToolConfigMap = ExtractToolConfigMap<(typeof coreExtensions)[number], 'webSearch'>
|
||||
32
packages/aiCore/src/core/providers/types/toolFactory.ts
Normal file
32
packages/aiCore/src/core/providers/types/toolFactory.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { ToolSet } from 'ai'
|
||||
|
||||
/**
|
||||
* 跨 provider 的工具能力标识
|
||||
*
|
||||
* 各 SDK 的工具键名不同(OpenAI: webSearch, Anthropic: webSearch_20250305, Google: googleSearch),
|
||||
* 但表达的是同一种能力。Plugin 通过 ToolCapability 进行跨 provider 统一查找。
|
||||
*/
|
||||
export type ToolCapability = 'webSearch' | 'fileSearch' | 'codeExecution' | 'urlContext'
|
||||
|
||||
/** 工具工厂返回的 patch,描述要合并到 params 的修改 */
|
||||
export interface ToolFactoryPatch {
|
||||
tools?: ToolSet
|
||||
providerOptions?: Record<string, any>
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具工厂函数 — 形状约束
|
||||
*
|
||||
* 使用 `...args: any[]` 而非 `config: Record<string, any>`,
|
||||
* 这样 `as const satisfies` 不会擦除声明时的具体 config 类型。
|
||||
* `ExtractToolConfig` 可从声明中提取具体 config 类型。
|
||||
*/
|
||||
export type ToolFactory<TProvider extends ProviderV3 = ProviderV3> = (
|
||||
provider: TProvider
|
||||
) => (...args: any[]) => ToolFactoryPatch
|
||||
|
||||
/** Map of ToolCapability keys to their factory functions. */
|
||||
export type ToolFactoryMap<TProvider extends ProviderV3 = ProviderV3> = {
|
||||
[K in ToolCapability]?: ToolFactory<TProvider>
|
||||
}
|
||||
@@ -5,11 +5,10 @@
|
||||
*/
|
||||
|
||||
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { createMockImageModel, createMockLanguageModel, createMockProviderV3, mockProviderConfigs } from '@test-utils'
|
||||
import { generateImage, generateText, streamText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../../../__tests__'
|
||||
import { globalModelResolver } from '../../models'
|
||||
import { ImageModelResolutionError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
@@ -29,31 +28,15 @@ vi.mock('ai', async (importOriginal) => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn(),
|
||||
imageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
vi.mock('../../models', () => ({
|
||||
globalModelResolver: {
|
||||
resolveLanguageModel: vi.fn(),
|
||||
resolveImageModel: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor - Model Resolution', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let executor: RuntimeExecutor
|
||||
let mockLanguageModel: LanguageModelV3
|
||||
let mockImageModel: ImageModelV3
|
||||
let mockProvider: any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
specificationVersion: 'v3',
|
||||
provider: 'openai',
|
||||
@@ -66,8 +49,14 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
|
||||
vi.mocked(globalModelResolver.resolveImageModel).mockResolvedValue(mockImageModel)
|
||||
mockProvider = createMockProviderV3({
|
||||
provider: 'openai',
|
||||
languageModel: vi.fn(() => mockLanguageModel),
|
||||
imageModel: vi.fn(() => mockImageModel)
|
||||
})
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
|
||||
|
||||
vi.mocked(generateText).mockResolvedValue({
|
||||
text: 'Test response',
|
||||
finishReason: 'stop',
|
||||
@@ -89,61 +78,16 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
})
|
||||
|
||||
describe('Language Model Resolution (String modelId)', () => {
|
||||
it('should resolve string modelId using globalModelResolver', async () => {
|
||||
it('should resolve string modelId through provider', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
mockProviderConfigs.openai
|
||||
)
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
})
|
||||
|
||||
it('should pass provider settings to model resolver', async () => {
|
||||
const customExecutor = RuntimeExecutor.create('anthropic', {
|
||||
apiKey: 'sk-test',
|
||||
baseURL: 'https://api.anthropic.com'
|
||||
})
|
||||
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(mockLanguageModel)
|
||||
|
||||
await customExecutor.generateText({
|
||||
model: 'claude-3-5-sonnet',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('claude-3-5-sonnet', 'anthropic', {
|
||||
apiKey: 'sk-test',
|
||||
baseURL: 'https://api.anthropic.com'
|
||||
})
|
||||
})
|
||||
|
||||
it('should resolve traditional format modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4-turbo',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4-turbo', 'openai', expect.any(Object))
|
||||
})
|
||||
|
||||
it('should resolve namespaced format modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: 'aihubmix|anthropic|claude-3',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'aihubmix|anthropic|claude-3',
|
||||
'openai',
|
||||
expect.any(Object)
|
||||
)
|
||||
})
|
||||
|
||||
it('should use resolved model for generation', async () => {
|
||||
it('should pass resolved model to generateText', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Hello' }]
|
||||
@@ -156,13 +100,31 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('should resolve traditional format modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4-turbo',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4-turbo')
|
||||
})
|
||||
|
||||
it('should resolve namespaced format modelId', async () => {
|
||||
await executor.generateText({
|
||||
model: 'aihubmix|anthropic|claude-3',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('aihubmix|anthropic|claude-3')
|
||||
})
|
||||
|
||||
it('should work with streamText', async () => {
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Stream test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalled()
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: mockLanguageModel
|
||||
@@ -184,8 +146,8 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
// Should NOT call resolver for direct model
|
||||
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
|
||||
// Should NOT call provider for direct model
|
||||
expect(mockProvider.languageModel).not.toHaveBeenCalled()
|
||||
|
||||
// Should use the model directly
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
@@ -195,42 +157,6 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('should accept V2 model object without validation (plugin engine handles it)', async () => {
|
||||
const v2Model = {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4',
|
||||
doGenerate: vi.fn()
|
||||
} as any
|
||||
|
||||
// The plugin engine accepts any model object directly without validation
|
||||
// V3 validation only happens when resolving string modelIds
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: v2Model,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).resolves.toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept any model object without checking specification version', async () => {
|
||||
const v2Model = {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'custom-provider',
|
||||
modelId: 'custom-model',
|
||||
doGenerate: vi.fn()
|
||||
} as any
|
||||
|
||||
// Direct model objects bypass validation
|
||||
// The executor trusts that plugins/users provide valid models
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: v2Model,
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
).resolves.toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept model object with streamText', async () => {
|
||||
const directModel = createMockLanguageModel({
|
||||
specificationVersion: 'v3'
|
||||
@@ -241,7 +167,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
messages: [{ role: 'user', content: 'Stream' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).not.toHaveBeenCalled()
|
||||
expect(mockProvider.languageModel).not.toHaveBeenCalled()
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: directModel
|
||||
@@ -251,13 +177,13 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
})
|
||||
|
||||
describe('Image Model Resolution', () => {
|
||||
it('should resolve string image modelId using globalModelResolver', async () => {
|
||||
it('should resolve string image modelId through provider', async () => {
|
||||
await executor.generateImage({
|
||||
model: 'dall-e-3',
|
||||
prompt: 'A beautiful sunset'
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('dall-e-3', 'openai')
|
||||
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
})
|
||||
|
||||
it('should accept direct ImageModelV3 object', async () => {
|
||||
@@ -272,7 +198,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveImageModel).not.toHaveBeenCalled()
|
||||
expect(mockProvider.imageModel).not.toHaveBeenCalled()
|
||||
expect(generateImage).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: directImageModel
|
||||
@@ -286,12 +212,13 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
prompt: 'Namespaced image'
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveImageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3', 'openai')
|
||||
expect(mockProvider.imageModel).toHaveBeenCalledWith('aihubmix|openai|dall-e-3')
|
||||
})
|
||||
|
||||
it('should throw ImageModelResolutionError on resolution failure', async () => {
|
||||
const resolutionError = new Error('Model not found')
|
||||
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(resolutionError)
|
||||
mockProvider.imageModel.mockImplementation(() => {
|
||||
throw new Error('Model not found')
|
||||
})
|
||||
|
||||
await expect(
|
||||
executor.generateImage({
|
||||
@@ -302,7 +229,9 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
})
|
||||
|
||||
it('should include modelId and providerId in ImageModelResolutionError', async () => {
|
||||
vi.mocked(globalModelResolver.resolveImageModel).mockRejectedValue(new Error('Not found'))
|
||||
mockProvider.imageModel.mockImplementation(() => {
|
||||
throw new Error('Not found')
|
||||
})
|
||||
|
||||
try {
|
||||
await executor.generateImage({
|
||||
@@ -337,101 +266,70 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
|
||||
describe('Provider-Specific Model Resolution', () => {
|
||||
it('should resolve models for OpenAI provider', async () => {
|
||||
const openaiExecutor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
const openaiModel = createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' })
|
||||
const openaiProvider = createMockProviderV3({
|
||||
provider: 'openai',
|
||||
languageModel: vi.fn(() => openaiModel)
|
||||
})
|
||||
const openaiExecutor = RuntimeExecutor.create('openai', openaiProvider, mockProviderConfigs.openai)
|
||||
|
||||
await openaiExecutor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('gpt-4', 'openai', expect.any(Object))
|
||||
expect(openaiProvider.languageModel).toHaveBeenCalledWith('gpt-4')
|
||||
})
|
||||
|
||||
it('should resolve models for Anthropic provider', async () => {
|
||||
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
|
||||
const anthropicModel = createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet' })
|
||||
const anthropicProvider = createMockProviderV3({
|
||||
provider: 'anthropic',
|
||||
languageModel: vi.fn(() => anthropicModel)
|
||||
})
|
||||
const anthropicExecutor = RuntimeExecutor.create('anthropic', anthropicProvider, mockProviderConfigs.anthropic)
|
||||
|
||||
await anthropicExecutor.generateText({
|
||||
model: 'claude-3-5-sonnet',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'claude-3-5-sonnet',
|
||||
'anthropic',
|
||||
expect.any(Object)
|
||||
)
|
||||
expect(anthropicProvider.languageModel).toHaveBeenCalledWith('claude-3-5-sonnet')
|
||||
})
|
||||
|
||||
it('should resolve models for Google provider', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
|
||||
const googleModel = createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash' })
|
||||
const googleProvider = createMockProviderV3({
|
||||
provider: 'google',
|
||||
languageModel: vi.fn(() => googleModel)
|
||||
})
|
||||
const googleExecutor = RuntimeExecutor.create('google', googleProvider, mockProviderConfigs.google)
|
||||
|
||||
await googleExecutor.generateText({
|
||||
model: 'gemini-2.0-flash',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gemini-2.0-flash',
|
||||
'google',
|
||||
expect.any(Object)
|
||||
)
|
||||
expect(googleProvider.languageModel).toHaveBeenCalledWith('gemini-2.0-flash')
|
||||
})
|
||||
|
||||
it('should resolve models for OpenAI-compatible provider', async () => {
|
||||
const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(mockProviderConfigs['openai-compatible'])
|
||||
const compatModel = createMockLanguageModel({ provider: 'openai-compatible', modelId: 'custom-model' })
|
||||
const compatProvider = createMockProviderV3({
|
||||
provider: 'openai-compatible',
|
||||
languageModel: vi.fn(() => compatModel)
|
||||
})
|
||||
const compatibleExecutor = RuntimeExecutor.createOpenAICompatible(
|
||||
compatProvider,
|
||||
mockProviderConfigs['openai-compatible']
|
||||
)
|
||||
|
||||
await compatibleExecutor.generateText({
|
||||
model: 'custom-model',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'custom-model',
|
||||
'openai-compatible',
|
||||
expect.any(Object)
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI Mode Handling', () => {
|
||||
it('should pass mode setting to model resolver', async () => {
|
||||
const executorWithMode = RuntimeExecutor.create('openai', {
|
||||
...mockProviderConfigs.openai,
|
||||
mode: 'chat'
|
||||
})
|
||||
|
||||
await executorWithMode.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
expect.objectContaining({
|
||||
mode: 'chat'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle responses mode', async () => {
|
||||
const executorWithMode = RuntimeExecutor.create('openai', {
|
||||
...mockProviderConfigs.openai,
|
||||
mode: 'responses'
|
||||
})
|
||||
|
||||
await executorWithMode.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith(
|
||||
'gpt-4',
|
||||
'openai',
|
||||
expect.objectContaining({
|
||||
mode: 'responses'
|
||||
})
|
||||
)
|
||||
expect(compatProvider.languageModel).toHaveBeenCalledWith('custom-model')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -442,11 +340,13 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
messages: [{ role: 'user', content: 'Test' }]
|
||||
})
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledWith('', 'openai', expect.any(Object))
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledWith('')
|
||||
})
|
||||
|
||||
it('should handle model resolution errors gracefully', async () => {
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockRejectedValue(new Error('Model not found'))
|
||||
mockProvider.languageModel.mockImplementation(() => {
|
||||
throw new Error('Model not found')
|
||||
})
|
||||
|
||||
await expect(
|
||||
executor.generateText({
|
||||
@@ -465,7 +365,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
|
||||
await Promise.all(promises)
|
||||
|
||||
expect(globalModelResolver.resolveLanguageModel).toHaveBeenCalledTimes(3)
|
||||
expect(mockProvider.languageModel).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should accept model object even without specificationVersion', async () => {
|
||||
@@ -476,7 +376,6 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
} as any
|
||||
|
||||
// Plugin engine doesn't validate direct model objects
|
||||
// It's the user's responsibility to provide valid models
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: invalidModel,
|
||||
@@ -492,7 +391,7 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
specificationVersion: 'v3'
|
||||
})
|
||||
|
||||
vi.mocked(globalModelResolver.resolveLanguageModel).mockResolvedValue(v3Model)
|
||||
mockProvider.languageModel.mockReturnValue(v3Model)
|
||||
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
@@ -516,7 +415,6 @@ describe('RuntimeExecutor - Model Resolution', () => {
|
||||
} as any
|
||||
|
||||
// Direct models bypass validation in the plugin engine
|
||||
// Only resolved models (from string IDs) are validated
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: v1Model,
|
||||
|
||||
@@ -1,52 +1,56 @@
|
||||
import type { ImageModelV3 } from '@ai-sdk/provider'
|
||||
import { createMockImageModel, createMockProviderV3 } from '@test-utils'
|
||||
import { generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('ai', () => ({
|
||||
experimental_generateImage: vi.fn(),
|
||||
generateImage: vi.fn(),
|
||||
jsonSchema: vi.fn((schema) => schema),
|
||||
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||
static isInstance = vi.fn()
|
||||
constructor() {
|
||||
super('No image generated')
|
||||
this.name = 'NoImageGeneratedError'
|
||||
vi.mock('ai', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
experimental_generateImage: vi.fn(),
|
||||
generateImage: vi.fn(),
|
||||
jsonSchema: vi.fn((schema) => schema),
|
||||
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||
static isInstance = vi.fn()
|
||||
constructor() {
|
||||
super('No image generated')
|
||||
this.name = 'NoImageGeneratedError'
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
imageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
})
|
||||
|
||||
describe('RuntimeExecutor.generateImage', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let executor: RuntimeExecutor
|
||||
let mockImageModel: ImageModelV3
|
||||
let mockProvider: any
|
||||
let mockGenerateImageResult: any
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create executor instance
|
||||
executor = RuntimeExecutor.create('openai', {
|
||||
apiKey: 'test-key'
|
||||
})
|
||||
|
||||
// Mock image model
|
||||
mockImageModel = {
|
||||
mockImageModel = createMockImageModel({
|
||||
modelId: 'dall-e-3',
|
||||
provider: 'openai'
|
||||
} as ImageModelV3
|
||||
})
|
||||
|
||||
// Create mock provider with imageModel as a spy
|
||||
mockProvider = createMockProviderV3({
|
||||
provider: 'openai',
|
||||
imageModel: vi.fn(() => mockImageModel)
|
||||
})
|
||||
|
||||
// Create executor instance
|
||||
executor = RuntimeExecutor.create('openai', mockProvider, {
|
||||
apiKey: 'test-key'
|
||||
})
|
||||
|
||||
// Mock generateImage result
|
||||
mockGenerateImageResult = {
|
||||
@@ -71,8 +75,6 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
responses: []
|
||||
}
|
||||
|
||||
// Setup mocks to avoid "No providers registered" error
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
@@ -80,7 +82,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
it('should generate a single image with minimal parameters', async () => {
|
||||
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
|
||||
expect(mockProvider.imageModel).toHaveBeenCalledWith('dall-e-3')
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@@ -96,7 +98,8 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
|
||||
// Pre-created model is used directly, provider.imageModel is not called
|
||||
expect(mockProvider.imageModel).not.toHaveBeenCalled()
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
@@ -224,6 +227,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
mockProvider,
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
@@ -269,6 +273,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
mockProvider,
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
@@ -309,6 +314,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
mockProvider,
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
@@ -325,7 +331,8 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
describe('Error handling', () => {
|
||||
it('should handle model creation errors', async () => {
|
||||
const modelError = new Error('Failed to get image model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
// Since mockProvider.imageModel is already a vi.fn() spy, we can mock it directly
|
||||
mockProvider.imageModel.mockImplementation(() => {
|
||||
throw modelError
|
||||
})
|
||||
|
||||
@@ -336,7 +343,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
it('should handle ImageModelResolutionError correctly', async () => {
|
||||
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
mockProvider.imageModel.mockImplementation(() => {
|
||||
throw resolutionError
|
||||
})
|
||||
|
||||
@@ -353,7 +360,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
it('should handle ImageModelResolutionError without provider', async () => {
|
||||
const resolutionError = new ImageModelResolutionError('unknown-model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
mockProvider.imageModel.mockImplementation(() => {
|
||||
throw resolutionError
|
||||
})
|
||||
|
||||
@@ -398,6 +405,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
mockProvider,
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
@@ -436,23 +444,43 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
describe('Multiple providers support', () => {
|
||||
it('should work with different providers', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', {
|
||||
const googleImageModel = createMockImageModel({
|
||||
provider: 'google',
|
||||
modelId: 'imagen-3.0-generate-002'
|
||||
})
|
||||
|
||||
const googleProvider = createMockProviderV3({
|
||||
provider: 'google',
|
||||
imageModel: vi.fn(() => googleImageModel)
|
||||
})
|
||||
|
||||
const googleExecutor = RuntimeExecutor.create('google', googleProvider, {
|
||||
apiKey: 'google-key'
|
||||
})
|
||||
|
||||
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
|
||||
expect(googleProvider.imageModel).toHaveBeenCalledWith('imagen-3.0-generate-002')
|
||||
})
|
||||
|
||||
it('should support xAI Grok image models', async () => {
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', {
|
||||
const xaiImageModel = createMockImageModel({
|
||||
provider: 'xai',
|
||||
modelId: 'grok-2-image'
|
||||
})
|
||||
|
||||
const xaiProvider = createMockProviderV3({
|
||||
provider: 'xai',
|
||||
imageModel: vi.fn(() => xaiImageModel)
|
||||
})
|
||||
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, {
|
||||
apiKey: 'xai-key'
|
||||
})
|
||||
|
||||
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
|
||||
expect(xaiProvider.imageModel).toHaveBeenCalledWith('grok-2-image')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -3,18 +3,18 @@
|
||||
* Tests non-streaming text generation across all providers with various parameters
|
||||
*/
|
||||
|
||||
import { generateText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
createMockLanguageModel,
|
||||
createMockProviderV3,
|
||||
mockCompleteResponses,
|
||||
mockProviderConfigs,
|
||||
testMessages,
|
||||
testTools
|
||||
} from '../../../__tests__'
|
||||
} from '@test-utils'
|
||||
import { generateText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
|
||||
@@ -26,28 +26,26 @@ vi.mock('ai', async (importOriginal) => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.generateText', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let executor: RuntimeExecutor
|
||||
let mockLanguageModel: any
|
||||
let mockProvider: any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||
mockProvider = createMockProviderV3({
|
||||
provider: 'openai',
|
||||
languageModel: vi.fn(() => mockLanguageModel)
|
||||
})
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
|
||||
|
||||
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
|
||||
})
|
||||
|
||||
@@ -231,75 +229,87 @@ describe('RuntimeExecutor.generateText', () => {
|
||||
|
||||
describe('Multiple Providers', () => {
|
||||
it('should work with Anthropic provider', async () => {
|
||||
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
|
||||
|
||||
const anthropicModel = createMockLanguageModel({
|
||||
provider: 'anthropic',
|
||||
modelId: 'claude-3-5-sonnet-20241022'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
|
||||
const anthropicProvider = createMockProviderV3({
|
||||
provider: 'anthropic',
|
||||
languageModel: vi.fn(() => anthropicModel)
|
||||
})
|
||||
|
||||
const anthropicExecutor = RuntimeExecutor.create('anthropic', anthropicProvider, mockProviderConfigs.anthropic)
|
||||
|
||||
await anthropicExecutor.generateText({
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
|
||||
expect(anthropicProvider.languageModel).toHaveBeenCalledWith('claude-3-5-sonnet-20241022')
|
||||
})
|
||||
|
||||
it('should work with Google provider', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
|
||||
|
||||
const googleModel = createMockLanguageModel({
|
||||
provider: 'google',
|
||||
modelId: 'gemini-2.0-flash-exp'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
|
||||
const googleProvider = createMockProviderV3({
|
||||
provider: 'google',
|
||||
languageModel: vi.fn(() => googleModel)
|
||||
})
|
||||
|
||||
const googleExecutor = RuntimeExecutor.create('google', googleProvider, mockProviderConfigs.google)
|
||||
|
||||
await googleExecutor.generateText({
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
|
||||
expect(googleProvider.languageModel).toHaveBeenCalledWith('gemini-2.0-flash-exp')
|
||||
})
|
||||
|
||||
it('should work with xAI provider', async () => {
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
|
||||
|
||||
const xaiModel = createMockLanguageModel({
|
||||
provider: 'xai',
|
||||
modelId: 'grok-2-latest'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
|
||||
const xaiProvider = createMockProviderV3({
|
||||
provider: 'xai',
|
||||
languageModel: vi.fn(() => xaiModel)
|
||||
})
|
||||
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', xaiProvider, mockProviderConfigs.xai)
|
||||
|
||||
await xaiExecutor.generateText({
|
||||
model: 'grok-2-latest',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
|
||||
expect(xaiProvider.languageModel).toHaveBeenCalledWith('grok-2-latest')
|
||||
})
|
||||
|
||||
it('should work with DeepSeek provider', async () => {
|
||||
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
|
||||
|
||||
const deepseekModel = createMockLanguageModel({
|
||||
provider: 'deepseek',
|
||||
modelId: 'deepseek-chat'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
|
||||
const deepseekProvider = createMockProviderV3({
|
||||
provider: 'deepseek',
|
||||
languageModel: vi.fn(() => deepseekModel)
|
||||
})
|
||||
|
||||
const deepseekExecutor = RuntimeExecutor.create('deepseek', deepseekProvider, mockProviderConfigs.deepseek)
|
||||
|
||||
await deepseekExecutor.generateText({
|
||||
model: 'deepseek-chat',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
|
||||
expect(deepseekProvider.languageModel).toHaveBeenCalledWith('deepseek-chat')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -325,7 +335,9 @@ describe('RuntimeExecutor.generateText', () => {
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
|
||||
testPlugin
|
||||
])
|
||||
|
||||
const result = await executorWithPlugin.generateText({
|
||||
model: 'gpt-4',
|
||||
@@ -364,7 +376,10 @@ describe('RuntimeExecutor.generateText', () => {
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
|
||||
const executorWithPlugins = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
|
||||
plugin1,
|
||||
plugin2
|
||||
])
|
||||
|
||||
await executorWithPlugins.generateText({
|
||||
model: 'gpt-4',
|
||||
@@ -404,7 +419,9 @@ describe('RuntimeExecutor.generateText', () => {
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
|
||||
errorPlugin
|
||||
])
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.generateText({
|
||||
@@ -425,7 +442,7 @@ describe('RuntimeExecutor.generateText', () => {
|
||||
|
||||
it('should handle model not found error', async () => {
|
||||
const error = new Error('Model not found: invalid-model')
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
|
||||
mockProvider.languageModel.mockImplementationOnce(() => {
|
||||
throw error
|
||||
})
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
*/
|
||||
|
||||
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import { createMockImageModel, createMockLanguageModel, createMockMiddleware } from '@test-utils'
|
||||
import { wrapLanguageModel } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createMockImageModel, createMockLanguageModel, createMockMiddleware } from '../../../__tests__'
|
||||
import { ModelResolutionError, RecursiveDepthError } from '../../errors'
|
||||
import type { AiPlugin, GenerateTextParams, GenerateTextResult } from '../../plugins'
|
||||
import { PluginEngine } from '../pluginEngine'
|
||||
|
||||
@@ -3,12 +3,17 @@
|
||||
* Tests streaming text generation across all providers with various parameters
|
||||
*/
|
||||
|
||||
import {
|
||||
collectStreamChunks,
|
||||
createMockLanguageModel,
|
||||
createMockProviderV3,
|
||||
mockProviderConfigs,
|
||||
testMessages
|
||||
} from '@test-utils'
|
||||
import { streamText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
|
||||
import type { AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
|
||||
@@ -20,28 +25,25 @@ vi.mock('ai', async (importOriginal) => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.streamText', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let executor: RuntimeExecutor
|
||||
let mockLanguageModel: any
|
||||
let mockProvider: any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||
mockProvider = createMockProviderV3({
|
||||
provider: 'openai',
|
||||
languageModel: () => mockLanguageModel
|
||||
})
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai)
|
||||
})
|
||||
|
||||
describe('Basic Functionality', () => {
|
||||
@@ -416,7 +418,9 @@ describe('RuntimeExecutor.streamText', () => {
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
|
||||
testPlugin
|
||||
])
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
@@ -509,7 +513,9 @@ describe('RuntimeExecutor.streamText', () => {
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProvider, mockProviderConfigs.openai, [
|
||||
errorPlugin
|
||||
])
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.streamText({
|
||||
@@ -519,11 +525,12 @@ describe('RuntimeExecutor.streamText', () => {
|
||||
).rejects.toThrow('Stream error')
|
||||
|
||||
// onError receives the original error and context with core fields
|
||||
// context.model is the resolved LanguageModel (updated after resolveModel hook)
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
model: 'gpt-4'
|
||||
model: expect.objectContaining({ modelId: 'gpt-4' })
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
@@ -2,34 +2,54 @@
|
||||
* 运行时执行器
|
||||
* 专注于插件化的AI调用处理
|
||||
*/
|
||||
import type { ImageModelV3, LanguageModelV3 } from '@ai-sdk/provider'
|
||||
import type { ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel } from 'ai'
|
||||
import { generateImage as _generateImage, generateText as _generateText, streamText as _streamText } from 'ai'
|
||||
import {
|
||||
createProviderRegistry,
|
||||
embedMany as _embedMany,
|
||||
generateImage as _generateImage,
|
||||
generateText as _generateText,
|
||||
streamText as _streamText
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { isV3Model } from '../models/utils'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { type AiPlugin, definePlugin } from '../plugins'
|
||||
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import type { generateImageParams, generateTextParams, RuntimeConfig, streamTextParams } from './types'
|
||||
import type {
|
||||
EmbedManyParams,
|
||||
EmbedManyResult,
|
||||
generateImageParams,
|
||||
generateImageResult,
|
||||
generateTextParams,
|
||||
RuntimeConfig,
|
||||
streamTextParams
|
||||
} from './types'
|
||||
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
export class RuntimeExecutor<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
> {
|
||||
public pluginEngine: PluginEngine<T>
|
||||
// private options: ProviderSettingsMap[T]
|
||||
private config: RuntimeConfig<T>
|
||||
private config: RuntimeConfig<TSettingsMap, T>
|
||||
private registry: ReturnType<typeof createProviderRegistry>
|
||||
|
||||
constructor(config: RuntimeConfig<T>) {
|
||||
// if (!isProviderSupported(config.providerId)) {
|
||||
// throw new Error(`Unsupported provider: ${config.providerId}`)
|
||||
// }
|
||||
|
||||
// 存储options供后续使用
|
||||
// this.options = config.options
|
||||
constructor(config: RuntimeConfig<TSettingsMap, T>) {
|
||||
this.config = config
|
||||
// 创建插件客户端
|
||||
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||
|
||||
// Some v3 providers (e.g., @openrouter/ai-sdk-provider) expose textEmbeddingModel
|
||||
// but not embeddingModel. Patch for AI SDK registry compatibility.
|
||||
const provider = config.provider
|
||||
if (!provider.embeddingModel && provider.textEmbeddingModel) {
|
||||
provider.embeddingModel = (modelId: string) => provider.textEmbeddingModel!(modelId)
|
||||
}
|
||||
|
||||
this.registry = createProviderRegistry({
|
||||
[config.providerId]: provider
|
||||
})
|
||||
}
|
||||
|
||||
private createResolveModelPlugin() {
|
||||
@@ -58,8 +78,9 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
private createConfigureContextPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_configureContext',
|
||||
configureContext: async (context: AiRequestContext) => {
|
||||
context.executor = this
|
||||
configureContext: async () => {
|
||||
// Placeholder for future context configuration
|
||||
// Previously set executor and baseProvider, now handled by registry
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -120,7 +141,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
/**
|
||||
* 生成图像
|
||||
*/
|
||||
generateImage(params: generateImageParams): Promise<ReturnType<typeof _generateImage>> {
|
||||
async generateImage(params: generateImageParams): Promise<generateImageResult> {
|
||||
try {
|
||||
const { model } = params
|
||||
|
||||
@@ -148,19 +169,39 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量嵌入文本
|
||||
*/
|
||||
async embedMany(params: EmbedManyParams): Promise<EmbedManyResult> {
|
||||
const { model: modelOrId, ...options } = params
|
||||
|
||||
// 解析 embedding 模型
|
||||
const embeddingModel =
|
||||
typeof modelOrId === 'string'
|
||||
? this.registry.embeddingModel(`${this.config.providerId}:${modelOrId}` as `${string}:${string}`)
|
||||
: modelOrId
|
||||
|
||||
return _embedMany({
|
||||
model: embeddingModel,
|
||||
...options
|
||||
})
|
||||
}
|
||||
|
||||
// === 辅助方法 ===
|
||||
|
||||
/**
|
||||
* 解析模型:将字符串 modelId 解析为 model 对象
|
||||
* middleware 的应用由 pluginEngine 统一处理
|
||||
*
|
||||
* 对于有 modelResolver 的配置(如 xAI responses, OpenAI chat),
|
||||
* 使用 resolver 函数解析模型,而不是通过 registry.languageModel()。
|
||||
* resolver 在 extension 声明处类型安全地捕获了具体 provider 方法。
|
||||
*/
|
||||
private async resolveModel(modelOrId: LanguageModel): Promise<LanguageModelV3> {
|
||||
if (typeof modelOrId === 'string') {
|
||||
return await globalModelResolver.resolveLanguageModel(
|
||||
modelOrId,
|
||||
this.config.providerId,
|
||||
this.config.providerSettings
|
||||
)
|
||||
if (this.config.modelResolver) {
|
||||
return this.config.modelResolver(modelOrId)
|
||||
}
|
||||
return this.registry.languageModel(`${this.config.providerId}:${modelOrId}` as `${string}:${string}`)
|
||||
} else {
|
||||
if (!isV3Model(modelOrId)) {
|
||||
throw new Error(
|
||||
@@ -178,13 +219,8 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
private async resolveImageModel(modelOrId: ImageModelV3 | string): Promise<ImageModelV3> {
|
||||
try {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 字符串modelId,使用新的ModelResolver解析
|
||||
return await globalModelResolver.resolveImageModel(
|
||||
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
|
||||
this.config.providerId // fallback provider
|
||||
)
|
||||
return this.registry.imageModel(`${this.config.providerId}:${modelOrId}` as `${string}:${string}`)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -201,27 +237,37 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
/**
|
||||
* 创建执行器 - 支持已知provider的类型安全
|
||||
*/
|
||||
static create<T extends ProviderId>(
|
||||
static create<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
>(
|
||||
providerId: T,
|
||||
options: ModelConfig<T>['providerSettings'],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return new RuntimeExecutor({
|
||||
provider: ProviderV3,
|
||||
options: TSettingsMap[T],
|
||||
plugins?: AiPlugin[],
|
||||
modelResolver?: (modelId: string) => any
|
||||
): RuntimeExecutor<TSettingsMap, T> {
|
||||
return new RuntimeExecutor<TSettingsMap, T>({
|
||||
providerId,
|
||||
provider,
|
||||
providerSettings: options,
|
||||
plugins
|
||||
plugins,
|
||||
modelResolver
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI Compatible执行器
|
||||
* ✅ Now accepts provider instance directly
|
||||
*/
|
||||
static createOpenAICompatible(
|
||||
options: ModelConfig<'openai-compatible'>['providerSettings'],
|
||||
provider: ProviderV3, // ✅ Accept provider instance
|
||||
options: CoreProviderSettingsMap['openai-compatible'],
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return new RuntimeExecutor({
|
||||
): RuntimeExecutor<CoreProviderSettingsMap, 'openai-compatible'> {
|
||||
return new RuntimeExecutor<CoreProviderSettingsMap, 'openai-compatible'>({
|
||||
providerId: 'openai-compatible',
|
||||
provider, // ✅ Pass provider to config
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
|
||||
@@ -7,76 +7,113 @@
|
||||
export { RuntimeExecutor } from './executor'
|
||||
|
||||
// 导出类型
|
||||
export type { RuntimeConfig } from './types'
|
||||
export type { EmbedManyParams, EmbedManyResult, RuntimeConfig } from './types'
|
||||
|
||||
// === 便捷工厂函数 ===
|
||||
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||
import { extensionRegistry } from '../providers'
|
||||
import { type CoreProviderSettingsMap, type StringKeys } from '../providers/types'
|
||||
import { RuntimeExecutor } from './executor'
|
||||
|
||||
/**
|
||||
* 创建运行时执行器 - 支持类型安全的已知provider
|
||||
* 自动确保 provider 已初始化
|
||||
*/
|
||||
export function createExecutor<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return RuntimeExecutor.create(providerId, options, plugins)
|
||||
}
|
||||
export async function createExecutor<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
>(providerId: T, options: TSettingsMap[T], plugins?: AiPlugin[]): Promise<RuntimeExecutor<TSettingsMap, T>> {
|
||||
if (!extensionRegistry.has(providerId)) {
|
||||
throw new Error(`Provider extension "${providerId}" not registered`)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI Compatible执行器
|
||||
*/
|
||||
export function createOpenAICompatibleExecutor(
|
||||
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return RuntimeExecutor.createOpenAICompatible(options, plugins)
|
||||
}
|
||||
const provider = await extensionRegistry.createProvider(providerId, options || {})
|
||||
|
||||
// === 直接调用API(无需创建executor实例)===
|
||||
// Extract model resolver from variant's resolveModel declaration (type-safe at extension level)
|
||||
const resolver = extensionRegistry.getModelResolver(providerId as string)
|
||||
const modelResolver = resolver ? (modelId: string) => resolver(provider, modelId) : undefined
|
||||
|
||||
return RuntimeExecutor.create<TSettingsMap, T>(providerId, provider, options, plugins, modelResolver)
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接流式文本生成
|
||||
*/
|
||||
export async function streamText<T extends ProviderId>(
|
||||
export async function streamText<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
|
||||
options: TSettingsMap[T],
|
||||
params: Parameters<RuntimeExecutor<TSettingsMap, T>['streamText']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['streamText']>> {
|
||||
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
|
||||
return executor.streamText(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成文本
|
||||
*/
|
||||
export async function generateText<T extends ProviderId>(
|
||||
export async function generateText<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
|
||||
options: TSettingsMap[T],
|
||||
params: Parameters<RuntimeExecutor<TSettingsMap, T>['generateText']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['generateText']>> {
|
||||
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
|
||||
return executor.generateText(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成图像 - 支持middlewares
|
||||
*/
|
||||
export async function generateImage<T extends ProviderId>(
|
||||
export async function generateImage<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
|
||||
options: TSettingsMap[T],
|
||||
params: Parameters<RuntimeExecutor<TSettingsMap, T>['generateImage']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['generateImage']>> {
|
||||
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
|
||||
return executor.generateImage(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接批量嵌入文本
|
||||
* AI SDK v6 只有 embedMany,没有 embed
|
||||
*/
|
||||
export async function embedMany<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
>(
|
||||
providerId: T,
|
||||
options: TSettingsMap[T],
|
||||
params: Parameters<RuntimeExecutor<TSettingsMap, T>['embedMany']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<TSettingsMap, T>['embedMany']>> {
|
||||
const executor = await createExecutor<TSettingsMap, T>(providerId, options, plugins)
|
||||
return executor.embedMany(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 OpenAI Compatible 执行器
|
||||
*/
|
||||
export async function createOpenAICompatibleExecutor(
|
||||
options: CoreProviderSettingsMap['openai-compatible'],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<RuntimeExecutor<CoreProviderSettingsMap, 'openai-compatible'>> {
|
||||
const provider = await extensionRegistry.createProvider('openai-compatible', options)
|
||||
|
||||
return RuntimeExecutor.createOpenAICompatible(provider, options, plugins)
|
||||
}
|
||||
|
||||
// === Agent 功能预留 ===
|
||||
// 未来将在 ../agents/ 文件夹中添加:
|
||||
// - AgentExecutor.ts
|
||||
|
||||
@@ -14,13 +14,13 @@ import {
|
||||
type StreamTextParams,
|
||||
type StreamTextResult
|
||||
} from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
import type { RegisteredProviderId } from '../providers'
|
||||
|
||||
/**
|
||||
* 插件增强的 AI 客户端
|
||||
* 专注于插件处理,不暴露用户API
|
||||
*/
|
||||
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
export class PluginEngine<T extends string = RegisteredProviderId> {
|
||||
/**
|
||||
* Plugin storage with explicit any/any generics
|
||||
*
|
||||
@@ -36,7 +36,6 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
|
||||
constructor(
|
||||
private readonly providerId: T,
|
||||
// private readonly options: ProviderSettingsMap[T],
|
||||
plugins: AiPlugin[] = []
|
||||
) {
|
||||
this.basePlugins = plugins
|
||||
@@ -352,6 +351,9 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
throw new ModelResolutionError(modelId, this.providerId)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
// 更新 context.model 为已解析的 LanguageModel 实例
|
||||
// 后续 plugin(如 providerToolPlugin)需要 model.provider 来识别聚合供应商的协议
|
||||
context.model = resolvedModel
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
@@ -359,7 +361,10 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 2.5 应用 context.middlewares 到模型
|
||||
if (typeof model !== 'string' && context.middlewares && context.middlewares.length > 0) {
|
||||
if (context.middlewares && context.middlewares.length > 0) {
|
||||
if (typeof resolvedModel === 'string') {
|
||||
throw new Error(`Model must be resolved before applying middlewares, got string: ${resolvedModel}`)
|
||||
}
|
||||
resolvedModel = wrapLanguageModel({
|
||||
model: resolvedModel as LanguageModelV3,
|
||||
middleware: context.middlewares
|
||||
|
||||
@@ -1,24 +1,43 @@
|
||||
/**
|
||||
* Runtime 层类型定义
|
||||
*/
|
||||
import type { ImageModelV3 } from '@ai-sdk/provider'
|
||||
import type { generateImage, generateText, streamText } from 'ai'
|
||||
import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { embedMany, generateImage, generateText, streamText } from 'ai'
|
||||
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 运行时执行器配置
|
||||
*
|
||||
* @typeParam TSettingsMap - Provider Settings Map(默认 CoreProviderSettingsMap)
|
||||
* @typeParam T - Provider ID 类型(从 TSettingsMap 的键推断)
|
||||
*/
|
||||
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
export interface RuntimeConfig<
|
||||
TSettingsMap extends Record<string, any> = CoreProviderSettingsMap,
|
||||
T extends StringKeys<TSettingsMap> = StringKeys<TSettingsMap>
|
||||
> {
|
||||
providerId: T
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
provider: ProviderV3
|
||||
providerSettings: TSettingsMap[T]
|
||||
plugins?: AiPlugin[]
|
||||
/**
|
||||
* 模型解析函数
|
||||
* 从 variant 的 resolveModel 声明中提取(类型安全在 extension 声明处保证)。
|
||||
* 不提供时使用 AI SDK 默认的 provider.languageModel()。
|
||||
*/
|
||||
modelResolver?: (modelId: string) => any
|
||||
}
|
||||
|
||||
export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'model'> & {
|
||||
model: string | ImageModelV3
|
||||
}
|
||||
export type generateImageResult = Awaited<ReturnType<typeof generateImage>>
|
||||
export type generateTextParams = Parameters<typeof generateText>[0]
|
||||
export type streamTextParams = Parameters<typeof streamText>[0]
|
||||
|
||||
// Embedding types (AI SDK v6 only has embedMany, no embed)
|
||||
export type EmbedManyParams = Omit<Parameters<typeof embedMany>[0], 'model'> & {
|
||||
model: string | EmbeddingModelV3
|
||||
}
|
||||
export type EmbedManyResult = Awaited<ReturnType<typeof embedMany>>
|
||||
|
||||
1
packages/aiCore/src/core/types/index.ts
Normal file
1
packages/aiCore/src/core/types/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export type PlainObject = Record<string, any>
|
||||
17
packages/aiCore/src/core/utils/index.ts
Normal file
17
packages/aiCore/src/core/utils/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import type { PlainObject } from '../types'
|
||||
|
||||
export const isPlainObject = (value: unknown): value is PlainObject => {
|
||||
return typeof value === 'object' && value !== null && !Array.isArray(value)
|
||||
}
|
||||
|
||||
export function deepMergeObjects<T extends PlainObject>(target: T, source: PlainObject): T {
|
||||
const result: PlainObject = { ...target }
|
||||
Object.entries(source).forEach(([key, value]) => {
|
||||
if (isPlainObject(value) && isPlainObject(result[key])) {
|
||||
result[key] = deepMergeObjects(result[key], value)
|
||||
} else {
|
||||
result[key] = value
|
||||
}
|
||||
})
|
||||
return result as T
|
||||
}
|
||||
@@ -6,46 +6,38 @@
|
||||
// 导入内部使用的类和函数
|
||||
|
||||
// ==================== 主要用户接口 ====================
|
||||
export {
|
||||
createExecutor,
|
||||
createOpenAICompatibleExecutor,
|
||||
generateImage,
|
||||
generateText,
|
||||
streamText
|
||||
} from './core/runtime'
|
||||
export { createExecutor, embedMany, generateImage, generateText, streamText } from './core/runtime'
|
||||
|
||||
// ==================== Embedding 类型 ====================
|
||||
export type { EmbedManyParams, EmbedManyResult } from './core/runtime'
|
||||
|
||||
// ==================== 高级API ====================
|
||||
export { isV2Model, isV3Model, globalModelResolver as modelResolver } from './core/models'
|
||||
export { isV2Model, isV3Model } from './core/models'
|
||||
|
||||
// ==================== 插件系统 ====================
|
||||
export type {
|
||||
AiPlugin,
|
||||
AiRequestContext,
|
||||
AiRequestMetadata,
|
||||
GenerateTextParams,
|
||||
GenerateTextResult,
|
||||
HookResult,
|
||||
PluginManagerConfig,
|
||||
RecursiveCallFn,
|
||||
StreamTextParams,
|
||||
StreamTextResult
|
||||
} from './core/plugins'
|
||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
export { definePlugin } from './core/plugins'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== 类型工具 ====================
|
||||
export type { AiSdkModel } from './core/providers'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
createAnthropicOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
type ExtractProviderOptions,
|
||||
mergeProviderOptions,
|
||||
type ProviderOptionsMap,
|
||||
type TypedProviderOptions
|
||||
} from './core/options'
|
||||
export type {
|
||||
AiSdkModel,
|
||||
ExtractToolConfig,
|
||||
ExtractToolConfigMap,
|
||||
ProviderId,
|
||||
ToolCapability,
|
||||
ToolFactory,
|
||||
ToolFactoryMap,
|
||||
ToolFactoryPatch,
|
||||
WebSearchToolConfigMap
|
||||
} from './core/providers'
|
||||
|
||||
// ==================== 错误处理 ====================
|
||||
export {
|
||||
@@ -53,11 +45,6 @@ export {
|
||||
ModelResolutionError,
|
||||
ParameterValidationError,
|
||||
PluginExecutionError,
|
||||
ProviderConfigError,
|
||||
RecursiveDepthError,
|
||||
TemplateLoadError
|
||||
} from './core/errors'
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
// 重新导出插件类型
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||
export type { AiPlugin, AiRequestContext } from './core/plugins/types'
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
/**
|
||||
* Test Utilities
|
||||
* Helper functions for testing AI Core functionality
|
||||
* Common Test Utilities
|
||||
* General-purpose helper functions for testing
|
||||
*/
|
||||
|
||||
import { expect, vi } from 'vitest'
|
||||
|
||||
import type { ProviderId } from '../fixtures/mock-providers'
|
||||
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
|
||||
import type { ProviderId } from '../mocks/providers'
|
||||
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../mocks/providers'
|
||||
|
||||
/**
|
||||
* Creates a test provider with streaming support
|
||||
@@ -10,15 +10,14 @@ import type {
|
||||
LanguageModelV3Middleware,
|
||||
ProviderV3
|
||||
} from '@ai-sdk/provider'
|
||||
import type { ProviderId } from '@test-utils'
|
||||
import type { Tool, ToolSet } from 'ai'
|
||||
import { tool } from 'ai'
|
||||
import { MockLanguageModelV3 } from 'ai/test'
|
||||
import { vi } from 'vitest'
|
||||
import * as z from 'zod'
|
||||
|
||||
import type { StreamTextParams, StreamTextResult } from '../../core/plugins'
|
||||
import type { ProviderId } from '../../core/providers/types'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import type { AiRequestContext, StreamTextParams, StreamTextResult } from '../../src/core/plugins/types'
|
||||
|
||||
/**
|
||||
* Type for partial overrides that allows omitting the model field
|
||||
@@ -60,7 +59,8 @@ export function createMockContext(overrides?: ContextOverrides): AiRequestContex
|
||||
isRecursiveCall: false,
|
||||
recursiveDepth: 0,
|
||||
maxRecursiveDepth: 10,
|
||||
extensions: new Map()
|
||||
extensions: new Map(),
|
||||
pluginState: {}
|
||||
}
|
||||
|
||||
if (overrides) {
|
||||
351
packages/aiCore/test_utils/helpers/model.ts
Normal file
351
packages/aiCore/test_utils/helpers/model.ts
Normal file
@@ -0,0 +1,351 @@
|
||||
/**
|
||||
* Model Test Utilities
|
||||
* Provides comprehensive mock creators for AI SDK v3 models and related test utilities
|
||||
*/
|
||||
|
||||
import type {
|
||||
EmbeddingModelV3,
|
||||
ImageModelV3,
|
||||
LanguageModelV3,
|
||||
LanguageModelV3Middleware,
|
||||
ProviderV3
|
||||
} from '@ai-sdk/provider'
|
||||
import type { Tool, ToolSet } from 'ai'
|
||||
import { tool } from 'ai'
|
||||
import { MockLanguageModelV3 } from 'ai/test'
|
||||
import { vi } from 'vitest'
|
||||
import * as z from 'zod'
|
||||
|
||||
import type { StreamTextParams, StreamTextResult } from '../../src/core/plugins'
|
||||
import type { RegisteredProviderId } from '../../src/core/providers/types'
|
||||
import type { AiRequestContext } from '../../src/types'
|
||||
|
||||
/**
|
||||
* Type for partial overrides that allows omitting the model field
|
||||
* The model will be automatically added by createMockContext
|
||||
*/
|
||||
type ContextOverrides = Partial<Omit<AiRequestContext<StreamTextParams, StreamTextResult>, 'originalParams'>> & {
|
||||
originalParams?: Partial<Omit<StreamTextParams, 'model'>> & { model?: StreamTextParams['model'] }
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock AiRequestContext with type safety
|
||||
* The model field is automatically added to originalParams if not provided
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const context = createMockContext({
|
||||
* providerId: 'openai',
|
||||
* metadata: { requestId: 'test-123' }
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockContext(overrides?: ContextOverrides): AiRequestContext<StreamTextParams, StreamTextResult> {
|
||||
const mockModel = new MockLanguageModelV3({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-model'
|
||||
})
|
||||
|
||||
const base: AiRequestContext<StreamTextParams, StreamTextResult> = {
|
||||
providerId: 'openai' as RegisteredProviderId,
|
||||
model: mockModel,
|
||||
originalParams: {
|
||||
model: mockModel,
|
||||
messages: [{ role: 'user', content: 'Test message' }]
|
||||
} as StreamTextParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: 'test-request-id',
|
||||
recursiveCall: vi.fn(),
|
||||
isRecursiveCall: false,
|
||||
recursiveDepth: 0,
|
||||
maxRecursiveDepth: 10,
|
||||
extensions: new Map(),
|
||||
pluginState: {}
|
||||
}
|
||||
|
||||
if (overrides) {
|
||||
// Ensure model is always present in originalParams
|
||||
const mergedOriginalParams = {
|
||||
...base.originalParams,
|
||||
...overrides.originalParams,
|
||||
model: overrides.originalParams?.model ?? mockModel
|
||||
}
|
||||
|
||||
return {
|
||||
...base,
|
||||
...overrides,
|
||||
originalParams: mergedOriginalParams as StreamTextParams
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock embedding model with customizable behavior
|
||||
* Compliant with AI SDK v3 specification
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const embeddingModel = createMockEmbeddingModel({
|
||||
* provider: 'openai',
|
||||
* modelId: 'text-embedding-3-small',
|
||||
* maxEmbeddingsPerCall: 2048
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockEmbeddingModel(overrides?: Partial<EmbeddingModelV3>): EmbeddingModelV3 {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
provider: 'mock-provider',
|
||||
modelId: 'mock-embedding-model',
|
||||
maxEmbeddingsPerCall: 100,
|
||||
supportsParallelCalls: true,
|
||||
|
||||
doEmbed: vi.fn().mockResolvedValue({
|
||||
embeddings: [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
],
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
totalTokens: 10
|
||||
},
|
||||
rawResponse: { headers: {} }
|
||||
}),
|
||||
|
||||
...overrides
|
||||
} as EmbeddingModelV3
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a complete mock ProviderV3 with all model types
|
||||
* Useful for testing provider registration and management
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const provider = createMockProviderV3({
|
||||
* provider: 'openai',
|
||||
* languageModel: customLanguageModel,
|
||||
* imageModel: customImageModel
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockProviderV3(overrides?: {
|
||||
provider?: string
|
||||
languageModel?: (modelId: string) => LanguageModelV3
|
||||
imageModel?: (modelId: string) => ImageModelV3
|
||||
embeddingModel?: (modelId: string) => EmbeddingModelV3
|
||||
}): ProviderV3 {
|
||||
const defaultLanguageModel = (modelId: string) =>
|
||||
({
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
modelId,
|
||||
defaultObjectGenerationMode: 'tool',
|
||||
supportedUrls: {},
|
||||
doGenerate: vi.fn().mockResolvedValue({
|
||||
text: 'Mock response text',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30,
|
||||
inputTokenDetails: {},
|
||||
outputTokenDetails: {}
|
||||
},
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
warnings: []
|
||||
}),
|
||||
doStream: vi.fn().mockReturnValue({
|
||||
stream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Mock ' }
|
||||
yield { type: 'text-delta', textDelta: 'streaming ' }
|
||||
yield { type: 'text-delta', textDelta: 'response' }
|
||||
yield {
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 15,
|
||||
totalTokens: 25,
|
||||
inputTokenDetails: {},
|
||||
outputTokenDetails: {}
|
||||
}
|
||||
}
|
||||
})(),
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
warnings: []
|
||||
})
|
||||
}) as LanguageModelV3
|
||||
|
||||
const defaultImageModel = (modelId: string) =>
|
||||
({
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
modelId,
|
||||
maxImagesPerCall: undefined,
|
||||
doGenerate: vi.fn().mockResolvedValue({
|
||||
images: [
|
||||
{
|
||||
base64: 'mock-base64-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
|
||||
mimeType: 'image/png'
|
||||
}
|
||||
],
|
||||
warnings: []
|
||||
})
|
||||
}) as ImageModelV3
|
||||
|
||||
const defaultEmbeddingModel = (modelId: string) =>
|
||||
({
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
modelId,
|
||||
maxEmbeddingsPerCall: 100,
|
||||
supportsParallelCalls: true,
|
||||
doEmbed: vi.fn().mockResolvedValue({
|
||||
embeddings: [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
],
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
totalTokens: 10
|
||||
},
|
||||
rawResponse: { headers: {} }
|
||||
})
|
||||
}) as EmbeddingModelV3
|
||||
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
provider: overrides?.provider ?? 'mock-provider',
|
||||
|
||||
languageModel: vi.fn(overrides?.languageModel ?? defaultLanguageModel),
|
||||
imageModel: vi.fn(overrides?.imageModel ?? defaultImageModel),
|
||||
embeddingModel: vi.fn(overrides?.embeddingModel ?? defaultEmbeddingModel)
|
||||
} as ProviderV3
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock middleware for testing middleware chains
|
||||
* Supports both generate and stream wrapping
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const middleware = createMockMiddleware({
|
||||
* name: 'test-middleware'
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockMiddleware(): LanguageModelV3Middleware {
|
||||
return {
|
||||
specificationVersion: 'v3',
|
||||
wrapGenerate: vi.fn((doGenerate) => doGenerate),
|
||||
wrapStream: vi.fn((doStream) => doStream)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a type-safe function tool for testing using AI SDK's tool() function
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const weatherTool = createMockTool('getWeather', 'Get current weather')
|
||||
* ```
|
||||
*/
|
||||
export function createMockTool(name: string, description?: string): Tool<{ value?: string }, string> {
|
||||
return tool({
|
||||
description: description || `Mock tool: ${name}`,
|
||||
inputSchema: z.object({
|
||||
value: z.string().optional()
|
||||
}),
|
||||
execute: vi.fn(async () => 'mock result')
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a provider-defined tool for testing
|
||||
*/
|
||||
export function createMockProviderTool(name: string, description?: string): { type: 'provider'; description: string } {
|
||||
return {
|
||||
type: 'provider' as const,
|
||||
description: description || `Mock provider tool: ${name}`
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a ToolSet with multiple tools
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const tools = createMockToolSet({
|
||||
* getWeather: 'function',
|
||||
* searchDatabase: 'function',
|
||||
* nativeSearch: 'provider'
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockToolSet(tools: Record<string, 'function' | 'provider'>): ToolSet {
|
||||
const toolSet: ToolSet = {}
|
||||
|
||||
for (const [name, type] of Object.entries(tools)) {
|
||||
if (type === 'function') {
|
||||
toolSet[name] = createMockTool(name)
|
||||
} else {
|
||||
toolSet[name] = createMockProviderTool(name) as Tool
|
||||
}
|
||||
}
|
||||
|
||||
return toolSet
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates mock stream params for testing
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const params = createMockStreamParams({
|
||||
* messages: [{ role: 'user', content: 'Custom message' }],
|
||||
* temperature: 0.7
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export function createMockStreamParams(overrides?: Partial<StreamTextParams>): StreamTextParams {
|
||||
return {
|
||||
messages: [{ role: 'user', content: 'Test message' }],
|
||||
...overrides
|
||||
} as StreamTextParams
|
||||
}
|
||||
|
||||
/**
|
||||
* Common mock model instances for quick testing
|
||||
*/
|
||||
export const mockModels = {
|
||||
/** Standard language model for general testing */
|
||||
language: new MockLanguageModelV3({
|
||||
provider: 'test-provider',
|
||||
modelId: 'test-model'
|
||||
}),
|
||||
|
||||
/** Mock OpenAI GPT-4 model */
|
||||
gpt4: new MockLanguageModelV3({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
}),
|
||||
|
||||
/** Mock Anthropic Claude model */
|
||||
claude: new MockLanguageModelV3({
|
||||
provider: 'anthropic',
|
||||
modelId: 'claude-3-5-sonnet-20241022'
|
||||
}),
|
||||
|
||||
/** Mock Google Gemini model */
|
||||
gemini: new MockLanguageModelV3({
|
||||
provider: 'google',
|
||||
modelId: 'gemini-2.0-flash-exp'
|
||||
})
|
||||
} as const
|
||||
13
packages/aiCore/test_utils/index.ts
Normal file
13
packages/aiCore/test_utils/index.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
/**
|
||||
* Test Infrastructure Exports
|
||||
* Central export point for all test utilities, fixtures, and helpers
|
||||
*/
|
||||
|
||||
// Mocks
|
||||
export * from './mocks/providers'
|
||||
export * from './mocks/responses'
|
||||
|
||||
// Helpers
|
||||
export * from './helpers/common'
|
||||
export * from './helpers/model'
|
||||
export * from './helpers/provider'
|
||||
@@ -11,11 +11,15 @@
|
||||
"noEmitOnError": false,
|
||||
"outDir": "./dist",
|
||||
"resolveJsonModule": true,
|
||||
"rootDir": "./src",
|
||||
"rootDir": ".",
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
"target": "ES2020"
|
||||
"target": "ES2020",
|
||||
"paths": {
|
||||
"@test-utils": ["./test_utils"],
|
||||
"@test-utils/*": ["./test_utils/*"]
|
||||
}
|
||||
},
|
||||
"exclude": ["node_modules", "dist"],
|
||||
"include": ["src/**/*"]
|
||||
"include": ["src/**/*", "test_utils/**/*"]
|
||||
}
|
||||
|
||||
@@ -8,13 +8,14 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url))
|
||||
export default defineConfig({
|
||||
test: {
|
||||
globals: true,
|
||||
setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')]
|
||||
setupFiles: [path.resolve(__dirname, './test_utils/setup.ts')]
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': path.resolve(__dirname, './src'),
|
||||
'@test-utils': path.resolve(__dirname, './test_utils'),
|
||||
// Mock external packages that may not be available in test environment
|
||||
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts')
|
||||
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './test_utils/mocks/ai-sdk-provider.ts')
|
||||
}
|
||||
},
|
||||
esbuild: {
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
diff --git a/client.js b/client.js
|
||||
index c2b9cd6e46f9f66f901af259661bc2d2f8b38936..9b6b3af1a6573e1ccaf3a1c5f41b48df198cbbe0 100644
|
||||
--- a/client.js
|
||||
+++ b/client.js
|
||||
@@ -26,7 +26,7 @@ Object.defineProperty(exports, "__esModule", { value: true });
|
||||
exports.AnthropicVertex = exports.BaseAnthropic = void 0;
|
||||
const client_1 = require("@anthropic-ai/sdk/client");
|
||||
const Resources = __importStar(require("@anthropic-ai/sdk/resources/index"));
|
||||
-const google_auth_library_1 = require("google-auth-library");
|
||||
+// const google_auth_library_1 = require("google-auth-library");
|
||||
const env_1 = require("./internal/utils/env.js");
|
||||
const values_1 = require("./internal/utils/values.js");
|
||||
const headers_1 = require("./internal/headers.js");
|
||||
@@ -56,7 +56,7 @@ class AnthropicVertex extends client_1.BaseAnthropic {
|
||||
throw new Error('No region was given. The client should be instantiated with the `region` option or the `CLOUD_ML_REGION` environment variable should be set.');
|
||||
}
|
||||
super({
|
||||
- baseURL: baseURL || `https://${region}-aiplatform.googleapis.com/v1`,
|
||||
+ baseURL: baseURL || (region === 'global' ? 'https://aiplatform.googleapis.com/v1' : `https://${region}-aiplatform.googleapis.com/v1`),
|
||||
...opts,
|
||||
});
|
||||
this.messages = makeMessagesResource(this);
|
||||
@@ -64,22 +64,22 @@ class AnthropicVertex extends client_1.BaseAnthropic {
|
||||
this.region = region;
|
||||
this.projectId = projectId;
|
||||
this.accessToken = opts.accessToken ?? null;
|
||||
- this._auth =
|
||||
- opts.googleAuth ?? new google_auth_library_1.GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
|
||||
- this._authClientPromise = this._auth.getClient();
|
||||
+ // this._auth =
|
||||
+ // opts.googleAuth ?? new google_auth_library_1.GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
|
||||
+ // this._authClientPromise = this._auth.getClient();
|
||||
}
|
||||
validateHeaders() {
|
||||
// auth validation is handled in prepareOptions since it needs to be async
|
||||
}
|
||||
- async prepareOptions(options) {
|
||||
- const authClient = await this._authClientPromise;
|
||||
- const authHeaders = await authClient.getRequestHeaders();
|
||||
- const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
|
||||
- if (!this.projectId && projectId) {
|
||||
- this.projectId = projectId;
|
||||
- }
|
||||
- options.headers = (0, headers_1.buildHeaders)([authHeaders, options.headers]);
|
||||
- }
|
||||
+ // async prepareOptions(options) {
|
||||
+ // const authClient = await this._authClientPromise;
|
||||
+ // const authHeaders = await authClient.getRequestHeaders();
|
||||
+ // const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
|
||||
+ // if (!this.projectId && projectId) {
|
||||
+ // this.projectId = projectId;
|
||||
+ // }
|
||||
+ // options.headers = (0, headers_1.buildHeaders)([authHeaders, options.headers]);
|
||||
+ // }
|
||||
buildRequest(options) {
|
||||
if ((0, values_1.isObj)(options.body)) {
|
||||
// create a shallow copy of the request body so that code that mutates it later
|
||||
diff --git a/client.mjs b/client.mjs
|
||||
index 70274cbf38f69f87cbcca9567e77e4a7b938cf90..4dea954b6f4afad565663426b7adfad5de973a7d 100644
|
||||
--- a/client.mjs
|
||||
+++ b/client.mjs
|
||||
@@ -1,6 +1,6 @@
|
||||
import { BaseAnthropic } from '@anthropic-ai/sdk/client';
|
||||
import * as Resources from '@anthropic-ai/sdk/resources/index';
|
||||
-import { GoogleAuth } from 'google-auth-library';
|
||||
+// import { GoogleAuth } from 'google-auth-library';
|
||||
import { readEnv } from "./internal/utils/env.mjs";
|
||||
import { isObj } from "./internal/utils/values.mjs";
|
||||
import { buildHeaders } from "./internal/headers.mjs";
|
||||
@@ -29,7 +29,7 @@ export class AnthropicVertex extends BaseAnthropic {
|
||||
throw new Error('No region was given. The client should be instantiated with the `region` option or the `CLOUD_ML_REGION` environment variable should be set.');
|
||||
}
|
||||
super({
|
||||
- baseURL: baseURL || `https://${region}-aiplatform.googleapis.com/v1`,
|
||||
+ baseURL: baseURL || (region === 'global' ? 'https://aiplatform.googleapis.com/v1' : `https://${region}-aiplatform.googleapis.com/v1`),
|
||||
...opts,
|
||||
});
|
||||
this.messages = makeMessagesResource(this);
|
||||
@@ -37,22 +37,22 @@ export class AnthropicVertex extends BaseAnthropic {
|
||||
this.region = region;
|
||||
this.projectId = projectId;
|
||||
this.accessToken = opts.accessToken ?? null;
|
||||
- this._auth =
|
||||
- opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
|
||||
- this._authClientPromise = this._auth.getClient();
|
||||
+ // this._auth =
|
||||
+ // opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
|
||||
+ //this._authClientPromise = this._auth.getClient();
|
||||
}
|
||||
validateHeaders() {
|
||||
// auth validation is handled in prepareOptions since it needs to be async
|
||||
}
|
||||
- async prepareOptions(options) {
|
||||
- const authClient = await this._authClientPromise;
|
||||
- const authHeaders = await authClient.getRequestHeaders();
|
||||
- const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
|
||||
- if (!this.projectId && projectId) {
|
||||
- this.projectId = projectId;
|
||||
- }
|
||||
- options.headers = buildHeaders([authHeaders, options.headers]);
|
||||
- }
|
||||
+ // async prepareOptions(options) {
|
||||
+ // const authClient = await this._authClientPromise;
|
||||
+ // const authHeaders = await authClient.getRequestHeaders();
|
||||
+ // const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
|
||||
+ // if (!this.projectId && projectId) {
|
||||
+ // this.projectId = projectId;
|
||||
+ // }
|
||||
+ // options.headers = buildHeaders([authHeaders, options.headers]);
|
||||
+ // }
|
||||
buildRequest(options) {
|
||||
if (isObj(options.body)) {
|
||||
// create a shallow copy of the request body so that code that mutates it later
|
||||
diff --git a/src/client.ts b/src/client.ts
|
||||
index a6f9c6be65e4189f4f9601fb560df3f68e7563eb..37b1ad2802e3ca0dae4ca35f9dcb5b22dcf09796 100644
|
||||
--- a/src/client.ts
|
||||
+++ b/src/client.ts
|
||||
@@ -12,22 +12,22 @@ export { BaseAnthropic } from '@anthropic-ai/sdk/client';
|
||||
const DEFAULT_VERSION = 'vertex-2023-10-16';
|
||||
const MODEL_ENDPOINTS = new Set<string>(['/v1/messages', '/v1/messages?beta=true']);
|
||||
|
||||
-export type ClientOptions = Omit<CoreClientOptions, 'apiKey' | 'authToken'> & {
|
||||
- region?: string | null | undefined;
|
||||
- projectId?: string | null | undefined;
|
||||
- accessToken?: string | null | undefined;
|
||||
-
|
||||
- /**
|
||||
- * Override the default google auth config using the
|
||||
- * [google-auth-library](https://www.npmjs.com/package/google-auth-library) package.
|
||||
- *
|
||||
- * Note that you'll likely have to set `scopes`, e.g.
|
||||
- * ```ts
|
||||
- * new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' })
|
||||
- * ```
|
||||
- */
|
||||
- googleAuth?: GoogleAuth | null | undefined;
|
||||
-};
|
||||
+// export type ClientOptions = Omit<CoreClientOptions, 'apiKey' | 'authToken'> & {
|
||||
+// region?: string | null | undefined;
|
||||
+// projectId?: string | null | undefined;
|
||||
+// accessToken?: string | null | undefined;
|
||||
+
|
||||
+// /**
|
||||
+// * Override the default google auth config using the
|
||||
+// * [google-auth-library](https://www.npmjs.com/package/google-auth-library) package.
|
||||
+// *
|
||||
+// * Note that you'll likely have to set `scopes`, e.g.
|
||||
+// * ```ts
|
||||
+// * new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' })
|
||||
+// * ```
|
||||
+// */
|
||||
+// googleAuth?: GoogleAuth | null | undefined;
|
||||
+// };
|
||||
|
||||
export class AnthropicVertex extends BaseAnthropic {
|
||||
region: string;
|
||||
@@ -74,9 +74,9 @@ export class AnthropicVertex extends BaseAnthropic {
|
||||
this.projectId = projectId;
|
||||
this.accessToken = opts.accessToken ?? null;
|
||||
|
||||
- this._auth =
|
||||
- opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
|
||||
- this._authClientPromise = this._auth.getClient();
|
||||
+ // this._auth =
|
||||
+ // opts.googleAuth ?? new GoogleAuth({ scopes: 'https://www.googleapis.com/auth/cloud-platform' });
|
||||
+ // this._authClientPromise = this._auth.getClient();
|
||||
}
|
||||
|
||||
messages: MessagesResource = makeMessagesResource(this);
|
||||
@@ -86,17 +86,17 @@ export class AnthropicVertex extends BaseAnthropic {
|
||||
// auth validation is handled in prepareOptions since it needs to be async
|
||||
}
|
||||
|
||||
- protected override async prepareOptions(options: FinalRequestOptions): Promise<void> {
|
||||
- const authClient = await this._authClientPromise;
|
||||
+ // protected override async prepareOptions(options: FinalRequestOptions): Promise<void> {
|
||||
+ // const authClient = await this._authClientPromise;
|
||||
|
||||
- const authHeaders = await authClient.getRequestHeaders();
|
||||
- const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
|
||||
- if (!this.projectId && projectId) {
|
||||
- this.projectId = projectId;
|
||||
- }
|
||||
+ // const authHeaders = await authClient.getRequestHeaders();
|
||||
+ // const projectId = authClient.projectId ?? authHeaders['x-goog-user-project'];
|
||||
+ // if (!this.projectId && projectId) {
|
||||
+ // this.projectId = projectId;
|
||||
+ // }
|
||||
|
||||
- options.headers = buildHeaders([authHeaders, options.headers]);
|
||||
- }
|
||||
+ // options.headers = buildHeaders([authHeaders, options.headers]);
|
||||
+ // }
|
||||
|
||||
override buildRequest(options: FinalRequestOptions): {
|
||||
req: FinalizedRequestInit;
|
||||
File diff suppressed because one or more lines are too long
154
patches/ai-npm-6.0.1-b73221ad63.patch
Normal file
154
patches/ai-npm-6.0.1-b73221ad63.patch
Normal file
@@ -0,0 +1,154 @@
|
||||
diff --git a/dist/index.d.mts b/dist/index.d.mts
|
||||
index 00a528a923dfd113b43bbbc99fd4218121690e2e..825db3a81a0e7c3e109c2226a8c74e58f29c3427 100644
|
||||
--- a/dist/index.d.mts
|
||||
+++ b/dist/index.d.mts
|
||||
@@ -4259,6 +4259,12 @@ declare function generateImage({ model: modelArg, prompt: promptArg, n, maxImage
|
||||
Only applicable for HTTP-based providers.
|
||||
*/
|
||||
headers?: Record<string, string>;
|
||||
+ /**
|
||||
+ Custom download function to use for URLs.
|
||||
+
|
||||
+ By default, files are downloaded if the model returns URLs instead of binary data.
|
||||
+ */
|
||||
+ experimental_download?: DownloadFunction | undefined;
|
||||
}): Promise<GenerateImageResult>;
|
||||
|
||||
/**
|
||||
diff --git a/dist/index.d.ts b/dist/index.d.ts
|
||||
index 00a528a923dfd113b43bbbc99fd4218121690e2e..825db3a81a0e7c3e109c2226a8c74e58f29c3427 100644
|
||||
--- a/dist/index.d.ts
|
||||
+++ b/dist/index.d.ts
|
||||
@@ -4259,6 +4259,12 @@ declare function generateImage({ model: modelArg, prompt: promptArg, n, maxImage
|
||||
Only applicable for HTTP-based providers.
|
||||
*/
|
||||
headers?: Record<string, string>;
|
||||
+ /**
|
||||
+ Custom download function to use for URLs.
|
||||
+
|
||||
+ By default, files are downloaded if the model returns URLs instead of binary data.
|
||||
+ */
|
||||
+ experimental_download?: DownloadFunction | undefined;
|
||||
}): Promise<GenerateImageResult>;
|
||||
|
||||
/**
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index b72259eb460c25f479cf9087a5c3c2e39bacf060..4d43da4202089bff9d774894888463f137f09358 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -8237,7 +8237,8 @@ async function generateImage({
|
||||
providerOptions,
|
||||
maxRetries: maxRetriesArg,
|
||||
abortSignal,
|
||||
- headers
|
||||
+ headers,
|
||||
+ experimental_download: download2
|
||||
}) {
|
||||
var _a14, _b;
|
||||
const model = resolveImageModel(modelArg);
|
||||
@@ -8286,21 +8287,33 @@ async function generateImage({
|
||||
outputTokens: void 0,
|
||||
totalTokens: void 0
|
||||
};
|
||||
+ const downloadFn = download2 ?? createDefaultDownloadFunction();
|
||||
for (const result of results) {
|
||||
- images.push(
|
||||
- ...result.images.map(
|
||||
- (image) => {
|
||||
- var _a15;
|
||||
- return new DefaultGeneratedFile({
|
||||
- data: image,
|
||||
- mediaType: (_a15 = detectMediaType({
|
||||
- data: image,
|
||||
- signatures: imageMediaTypeSignatures
|
||||
- })) != null ? _a15 : "image/png"
|
||||
- });
|
||||
+ const processedImages = await Promise.all(
|
||||
+ result.images.map(async (image) => {
|
||||
+ var _a15;
|
||||
+ // 检查是否为 URL 字符串
|
||||
+ if (typeof image === "string" && (image.startsWith("http://") || image.startsWith("https://"))) {
|
||||
+ const downloaded = await downloadFn([{ url: new URL(image), isUrlSupportedByModel: false }]);
|
||||
+ const downloadedData = downloaded[0];
|
||||
+ if (downloadedData) {
|
||||
+ return new DefaultGeneratedFile({
|
||||
+ data: downloadedData.data,
|
||||
+ mediaType: downloadedData.mediaType ?? "image/png"
|
||||
+ });
|
||||
+ }
|
||||
}
|
||||
- )
|
||||
+ // 原有逻辑:base64/Uint8Array
|
||||
+ return new DefaultGeneratedFile({
|
||||
+ data: image,
|
||||
+ mediaType: (_a15 = detectMediaType({
|
||||
+ data: image,
|
||||
+ signatures: imageMediaTypeSignatures
|
||||
+ })) != null ? _a15 : "image/png"
|
||||
+ });
|
||||
+ })
|
||||
);
|
||||
+ images.push(...processedImages);
|
||||
warnings.push(...result.warnings);
|
||||
if (result.usage != null) {
|
||||
totalUsage = addImageModelUsage(totalUsage, result.usage);
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index a6538d517173782802369ef4f90dac2a8cf9d75a..0977129233232d8bfcfa3c4936f5e9addeb8ca4a 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -8177,7 +8177,8 @@ async function generateImage({
|
||||
providerOptions,
|
||||
maxRetries: maxRetriesArg,
|
||||
abortSignal,
|
||||
- headers
|
||||
+ headers,
|
||||
+ experimental_download: download2
|
||||
}) {
|
||||
var _a14, _b;
|
||||
const model = resolveImageModel(modelArg);
|
||||
@@ -8226,21 +8227,33 @@ async function generateImage({
|
||||
outputTokens: void 0,
|
||||
totalTokens: void 0
|
||||
};
|
||||
+ const downloadFn = download2 ?? createDefaultDownloadFunction();
|
||||
for (const result of results) {
|
||||
- images.push(
|
||||
- ...result.images.map(
|
||||
- (image) => {
|
||||
- var _a15;
|
||||
- return new DefaultGeneratedFile({
|
||||
- data: image,
|
||||
- mediaType: (_a15 = detectMediaType({
|
||||
- data: image,
|
||||
- signatures: imageMediaTypeSignatures
|
||||
- })) != null ? _a15 : "image/png"
|
||||
- });
|
||||
+ const processedImages = await Promise.all(
|
||||
+ result.images.map(async (image) => {
|
||||
+ var _a15;
|
||||
+ // 检查是否为 URL 字符串
|
||||
+ if (typeof image === "string" && (image.startsWith("http://") || image.startsWith("https://"))) {
|
||||
+ const downloaded = await downloadFn([{ url: new URL(image), isUrlSupportedByModel: false }]);
|
||||
+ const downloadedData = downloaded[0];
|
||||
+ if (downloadedData) {
|
||||
+ return new DefaultGeneratedFile({
|
||||
+ data: downloadedData.data,
|
||||
+ mediaType: downloadedData.mediaType ?? "image/png"
|
||||
+ });
|
||||
+ }
|
||||
}
|
||||
- )
|
||||
+ // 原有逻辑:base64/Uint8Array
|
||||
+ return new DefaultGeneratedFile({
|
||||
+ data: image,
|
||||
+ mediaType: (_a15 = detectMediaType({
|
||||
+ data: image,
|
||||
+ signatures: imageMediaTypeSignatures
|
||||
+ })) != null ? _a15 : "image/png"
|
||||
+ });
|
||||
+ })
|
||||
);
|
||||
+ images.push(...processedImages);
|
||||
warnings.push(...result.warnings);
|
||||
if (result.usage != null) {
|
||||
totalUsage = addImageModelUsage(totalUsage, result.usage);
|
||||
1233
pnpm-lock.yaml
generated
1233
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,11 @@ import { SystemProviderIds } from '@types'
|
||||
|
||||
import { formatVertexApiHost } from './utils/api'
|
||||
|
||||
type HostFormatter = {
|
||||
match: (provider: Provider) => boolean
|
||||
format: (provider: Provider, appendApiVersion: boolean) => string | Promise<string>
|
||||
}
|
||||
|
||||
/**
|
||||
* Format and normalize the API host URL for a provider.
|
||||
* Handles provider-specific URL formatting rules (e.g., appending version paths, Azure formatting).
|
||||
@@ -23,37 +28,40 @@ import { formatVertexApiHost } from './utils/api'
|
||||
* @returns A new provider instance with the formatted API host.
|
||||
*/
|
||||
export async function formatProviderApiHost(provider: Provider): Promise<Provider> {
|
||||
// WARNING: if any changes are made here, please sync it to src/renderer/src/aiCore/provider/providerConfig.ts:formatProviderApiHost
|
||||
// NOTE: It's async to support Vertex API host formatting
|
||||
const formatted = { ...provider }
|
||||
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
|
||||
|
||||
if (formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
|
||||
}
|
||||
|
||||
// Anthropic is special: uses anthropicApiHost as source and syncs both fields
|
||||
if (isAnthropicProvider(provider)) {
|
||||
const baseHost = formatted.anthropicApiHost || formatted.apiHost
|
||||
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
|
||||
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
|
||||
if (!formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatted.apiHost
|
||||
}
|
||||
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
||||
} else if (isOllamaProvider(formatted)) {
|
||||
formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
|
||||
} else if (isGeminiProvider(formatted)) {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta')
|
||||
} else if (isAzureOpenAIProvider(formatted)) {
|
||||
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
|
||||
} else if (isVertexProvider(formatted)) {
|
||||
formatted.apiHost = await formatVertexApiHost(formatted.apiHost)
|
||||
} else if (isCherryAIProvider(formatted)) {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
||||
} else if (isPerplexityProvider(formatted)) {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
||||
} else {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion)
|
||||
return formatted
|
||||
}
|
||||
|
||||
const formatters: HostFormatter[] = [
|
||||
{
|
||||
match: (p) => p.id === SystemProviderIds.copilot || p.id === SystemProviderIds.github,
|
||||
format: (p) => formatApiHost(p.apiHost, false)
|
||||
},
|
||||
{ match: isCherryAIProvider, format: (p) => formatApiHost(p.apiHost, false) },
|
||||
{ match: isPerplexityProvider, format: (p) => formatApiHost(p.apiHost, false) },
|
||||
{ match: isOllamaProvider, format: (p) => formatOllamaApiHost(p.apiHost) },
|
||||
{ match: isGeminiProvider, format: (p, av) => formatApiHost(p.apiHost, av, 'v1beta') },
|
||||
{ match: isAzureOpenAIProvider, format: (p) => formatAzureOpenAIApiHost(p.apiHost) },
|
||||
{ match: isVertexProvider, format: (p) => formatVertexApiHost(p.apiHost) }
|
||||
]
|
||||
|
||||
const formatter = formatters.find((f) => f.match(provider))
|
||||
formatted.apiHost = formatter
|
||||
? await formatter.format(formatted, appendApiVersion)
|
||||
: formatApiHost(formatted.apiHost, appendApiVersion)
|
||||
|
||||
return formatted
|
||||
}
|
||||
|
||||
@@ -671,9 +671,10 @@ class FileStorage {
|
||||
|
||||
const parseResult = parseDataUrl(base64Data)
|
||||
const base64String = parseResult?.data ?? base64Data
|
||||
const ext = parseResult?.mediaType ? this.getExtensionFromMimeType(parseResult.mediaType) : '.png'
|
||||
|
||||
const buffer = Buffer.from(base64String, 'base64')
|
||||
const uuid = uuidv4()
|
||||
const ext = '.png'
|
||||
const destPath = path.join(this.storageDir, uuid + ext)
|
||||
|
||||
logger.debug('Saving base64 image:', {
|
||||
@@ -1560,6 +1561,8 @@ class FileStorage {
|
||||
'image/jpeg': '.jpg',
|
||||
'image/png': '.png',
|
||||
'image/gif': '.gif',
|
||||
'image/webp': '.webp',
|
||||
'image/bmp': '.bmp',
|
||||
'application/pdf': '.pdf',
|
||||
'text/plain': '.txt',
|
||||
'application/msword': '.doc',
|
||||
|
||||
514
src/renderer/src/aiCore/AiProvider.ts
Normal file
514
src/renderer/src/aiCore/AiProvider.ts
Normal file
@@ -0,0 +1,514 @@
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import type { generateImageResult } from '@cherrystudio/ai-core/core/runtime/types'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { normalizeGatewayModels } from '@renderer/services/models/ModelAdapter'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import {
|
||||
type Assistant,
|
||||
type EditImageParams,
|
||||
type GenerateImageParams,
|
||||
type Model,
|
||||
type Provider,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
|
||||
import { gateway } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { adaptProvider, getActualProvider, providerToAiSdkConfig } from './provider/providerConfig'
|
||||
import { listModels } from './services/listModels'
|
||||
import type { AppProviderSettingsMap, CompletionsResult, ProviderConfig } from './types'
|
||||
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
|
||||
export type AiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
|
||||
export default class AiProvider {
|
||||
private config?: ProviderConfig
|
||||
private actualProvider: Provider
|
||||
private model?: Model
|
||||
|
||||
/**
|
||||
* Constructor for AiProvider
|
||||
*
|
||||
* @param modelOrProvider - Model or Provider object
|
||||
* @param provider - Optional Provider object (only used when first param is Model)
|
||||
*
|
||||
* @remarks
|
||||
* **Important behavior notes**:
|
||||
*
|
||||
* 1. When called with `(model)`:
|
||||
* - Calls `getActualProvider(model)` to retrieve and format the provider
|
||||
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
|
||||
*
|
||||
* 2. When called with `(model, provider)`:
|
||||
* - The provided provider will be adapted via `adaptProvider`
|
||||
* - URL formatting behavior depends on the adapted result
|
||||
*
|
||||
* 3. When called with `(provider)`:
|
||||
* - The provider will be adapted via `adaptProvider`
|
||||
* - Used for operations that don't need a model (e.g., fetchModels)
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // Recommended: Auto-format URL
|
||||
* const ai = new AiProvider(model)
|
||||
*
|
||||
* // Provider will be adapted
|
||||
* const ai = new AiProvider(model, customProvider)
|
||||
*
|
||||
* // For operations that don't need a model
|
||||
* const ai = new AiProvider(provider)
|
||||
* ```
|
||||
*/
|
||||
constructor(model: Model, provider?: Provider)
|
||||
constructor(provider: Provider)
|
||||
constructor(modelOrProvider: Model | Provider, provider?: Provider)
|
||||
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
|
||||
if (this.isModel(modelOrProvider)) {
|
||||
// 传入的是 Model
|
||||
this.model = modelOrProvider
|
||||
this.actualProvider = provider
|
||||
? adaptProvider({ provider, model: modelOrProvider })
|
||||
: getActualProvider(modelOrProvider)
|
||||
// 注意:config 可能是同步值或 Promise,在 completions() 中会统一处理
|
||||
const configOrPromise = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
|
||||
this.config = configOrPromise instanceof Promise ? undefined : configOrPromise
|
||||
} else {
|
||||
// 传入的是 Provider
|
||||
this.actualProvider = adaptProvider({ provider: modelOrProvider })
|
||||
// model为可选,某些操作(如fetchModels)不需要model
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
|
||||
*/
|
||||
private isModel(obj: Model | Provider): obj is Model {
|
||||
return 'provider' in obj && typeof obj.provider === 'string'
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
return this.actualProvider
|
||||
}
|
||||
|
||||
public async completions(modelId: string, params: StreamTextParams, middlewareConfig: AiProviderConfig) {
|
||||
// 检查model是否存在
|
||||
if (!this.model) {
|
||||
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||
}
|
||||
|
||||
// Config is now set in constructor, ApiService handles key rotation before passing provider
|
||||
if (!this.config) {
|
||||
// If config wasn't set in constructor (when provider only), generate it now
|
||||
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, this.model))
|
||||
}
|
||||
logger.debug('Using provider config for completions', this.config)
|
||||
|
||||
// 注意:模型对象将由 createExecutor 内部处理,不再需要预先创建
|
||||
|
||||
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
|
||||
// 类型守卫:确保 system 是 string、Array 或 undefined
|
||||
const system = params.system
|
||||
let systemParam: string | Array<any> | undefined
|
||||
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
|
||||
systemParam = system
|
||||
} else {
|
||||
// SystemModelMessage 类型,转换为 string
|
||||
systemParam = undefined
|
||||
}
|
||||
|
||||
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
|
||||
params.system = undefined // 清除原有system,避免重复
|
||||
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
|
||||
}
|
||||
|
||||
if (middlewareConfig.topicId && getEnableDeveloperMode()) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
...middlewareConfig,
|
||||
topicId: middlewareConfig.topicId
|
||||
}
|
||||
return await this._completionsForTrace(modelId, params, traceConfig, this.config)
|
||||
} else {
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig, this.config)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 带trace支持的completions方法
|
||||
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||
*/
|
||||
private async _completionsForTrace(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiProviderConfig & { topicId: string },
|
||||
providerConfig: ProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const traceName = `${this.actualProvider.name}.${modelId}.${middlewareConfig.callType}`
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelName: middlewareConfig.assistant.model?.name, // 使用modelId而不是provider名称
|
||||
inputs: params
|
||||
}
|
||||
|
||||
logger.info('Starting AI SDK trace span', {
|
||||
traceName,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolNames: params.tools ? Object.keys(params.tools) : []
|
||||
})
|
||||
|
||||
const span = addSpan(traceParams)
|
||||
if (!span) {
|
||||
logger.warn('Failed to create span, falling back to regular completions', {
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
traceName
|
||||
})
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info('Created parent span, now calling completions', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
parentSpanCreated: true
|
||||
})
|
||||
|
||||
const result = await this.modernCompletions(modelId, params, middlewareConfig, providerConfig)
|
||||
|
||||
logger.info('Completions finished, ending parent span', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId,
|
||||
resultLength: result.getText().length
|
||||
})
|
||||
|
||||
// 标记span完成
|
||||
endSpan({
|
||||
topicId: middlewareConfig.topicId,
|
||||
outputs: result,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: middlewareConfig.topicId,
|
||||
modelId
|
||||
})
|
||||
|
||||
// 标记span出错
|
||||
endSpan({
|
||||
topicId: middlewareConfig.topicId,
|
||||
error: error as Error,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
*/
|
||||
/**
|
||||
* Note: This implementation always uses `executor.streamText` and never
|
||||
* calls `generateText`, even when `onChunk` is not provided.
|
||||
*/
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiProviderConfig,
|
||||
providerConfig: ProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const plugins = buildPlugins({
|
||||
provider: this.actualProvider,
|
||||
model: this.model!,
|
||||
config: middlewareConfig
|
||||
})
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
providerConfig.providerId,
|
||||
providerConfig.providerSettings,
|
||||
plugins
|
||||
)
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (middlewareConfig.onChunk) {
|
||||
const accumulate = this.model!.supported_text_delta !== false // true and undefined
|
||||
const adapter = new AiSdkToChunkAdapter(
|
||||
middlewareConfig.onChunk,
|
||||
middlewareConfig.mcpTools,
|
||||
accumulate,
|
||||
middlewareConfig.enableWebSearch,
|
||||
undefined,
|
||||
undefined,
|
||||
providerConfig.providerId
|
||||
)
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model: modelId,
|
||||
experimental_context: { onChunk: middlewareConfig.onChunk }
|
||||
})
|
||||
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// Since no onChunk is provided, the external consumer would not handle error chunk.
|
||||
// So we need to capture the actual stream error so we can throw it instead of the
|
||||
// generic NoTextGeneratedError ("No output generated. Check the stream
|
||||
// for errors.") that AI SDK raises when streamResult.text is accessed
|
||||
// after a failed stream.
|
||||
let streamError: unknown = undefined
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model: modelId,
|
||||
onError({ error }) {
|
||||
streamError = error
|
||||
}
|
||||
})
|
||||
|
||||
// 强制消费流,不然await streamResult.text会阻塞
|
||||
await streamResult?.consumeStream({
|
||||
onError(error) {
|
||||
if (!streamError) {
|
||||
streamError = error
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
try {
|
||||
const finalText = await streamResult.text
|
||||
const usage = await streamResult.totalUsage
|
||||
|
||||
return {
|
||||
getText: () => finalText,
|
||||
usage
|
||||
}
|
||||
} catch (error) {
|
||||
// If we captured the real stream error, throw that instead of the
|
||||
// generic NoTextGeneratedError so callers get actionable diagnostics.
|
||||
if (streamError) {
|
||||
throw streamError
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型列表
|
||||
* 使用 ModelListService 统一处理各 Provider 的模型列表获取
|
||||
*/
|
||||
public async models(): Promise<Model[]> {
|
||||
// Gateway provider 使用 AI SDK 的 gateway API
|
||||
if (this.actualProvider.id === SystemProviderIds.gateway) {
|
||||
const gatewayModels = (await gateway.getAvailableModels()).models
|
||||
return normalizeGatewayModels(this.actualProvider, gatewayModels)
|
||||
}
|
||||
|
||||
// 使用新的 ModelListService
|
||||
return await listModels(this.actualProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取嵌入模型的维度
|
||||
* 使用 AI SDK embedMany 测试获取维度
|
||||
*/
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
// 确保 config 已定义
|
||||
if (!this.config) {
|
||||
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
|
||||
}
|
||||
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
this.config.providerId,
|
||||
this.config.providerSettings,
|
||||
[]
|
||||
)
|
||||
|
||||
// 使用 AI SDK embedMany 测试获取维度
|
||||
const result = await executor.embedMany({
|
||||
model: model.id,
|
||||
values: ['test']
|
||||
})
|
||||
|
||||
return result.embeddings[0].length
|
||||
}
|
||||
|
||||
/**
|
||||
* 懒加载初始化 config
|
||||
* 当 constructor 只传入 provider 时,config 不会被初始化
|
||||
* 此方法根据 modelId 从 provider 的 models 中查找真实 Model 并生成 config
|
||||
*/
|
||||
private async ensureConfig(modelId: string): Promise<void> {
|
||||
if (this.config) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从 provider 的 models 中查找真实的 model
|
||||
const model = this.actualProvider.models.find((m) => getLowerBaseModelName(m.id) === getLowerBaseModelName(modelId))
|
||||
if (!model) {
|
||||
throw new Error(`Model "${modelId}" not found in provider "${this.actualProvider.id}"`)
|
||||
}
|
||||
|
||||
this.actualProvider = adaptProvider({ provider: this.actualProvider, model })
|
||||
this.config = await Promise.resolve(providerToAiSdkConfig(this.actualProvider, model))
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像
|
||||
* 使用现代化 AI SDK 实现,不再 fallback 到 legacy
|
||||
*/
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
await this.ensureConfig(params.model)
|
||||
return await this.modernGenerateImage(params, this.config!)
|
||||
}
|
||||
|
||||
/**
|
||||
* 编辑图像 - 基于输入图像和文本提示生成新图像
|
||||
* 内部使用 AI SDK 的 generateImage,通过 prompt.images 参数实现编辑功能
|
||||
*/
|
||||
public async editImage(params: EditImageParams): Promise<string[]> {
|
||||
await this.ensureConfig(params.model)
|
||||
return await this.modernEditImage(params, this.config!)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams, providerConfig: ProviderConfig): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
providerConfig.providerId,
|
||||
providerConfig.providerSettings,
|
||||
[]
|
||||
)
|
||||
const result = await executor.generateImage({
|
||||
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
|
||||
...aiSdkParams
|
||||
})
|
||||
|
||||
return this.convertImageResult(result)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像编辑实现
|
||||
* 通过 AI SDK 的 generateImage 并传入 prompt.images 参数实现编辑功能
|
||||
*/
|
||||
private async modernEditImage(params: EditImageParams, providerConfig: ProviderConfig): Promise<string[]> {
|
||||
const { model, prompt, inputImages, mask, imageSize, signal } = params
|
||||
|
||||
const executor = await createExecutor<AppProviderSettingsMap>(
|
||||
providerConfig.providerId,
|
||||
providerConfig.providerSettings,
|
||||
[]
|
||||
)
|
||||
|
||||
// 使用 AI SDK 的 generateImage,通过 prompt.images 实现编辑
|
||||
const result = await executor.generateImage({
|
||||
model: model,
|
||||
prompt: {
|
||||
text: prompt,
|
||||
images: inputImages, // 输入图像(必需)
|
||||
...(mask && { mask }) // 可选的 mask(用于 inpainting)
|
||||
},
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
...(signal && { abortSignal: signal })
|
||||
})
|
||||
|
||||
return this.convertImageResult(result)
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换图像生成结果格式
|
||||
*/
|
||||
private convertImageResult(result: generateImageResult): string[] {
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if (image.base64) {
|
||||
images.push(`data:${image.mediaType || 'image/png'};base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
return images
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.actualProvider.apiHost || ''
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
const apiKey = this.actualProvider.apiKey
|
||||
if (!apiKey || apiKey.trim() === '') {
|
||||
return ''
|
||||
}
|
||||
|
||||
const keys = apiKey
|
||||
.split(',')
|
||||
.map((key) => key.trim())
|
||||
.filter(Boolean)
|
||||
|
||||
if (keys.length === 0) {
|
||||
return ''
|
||||
}
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
// Multi-key rotation
|
||||
const keyName = `provider:${this.actualProvider.id}:last_used_key`
|
||||
const lastUsedKey = window.keyv.get(keyName)
|
||||
|
||||
if (!lastUsedKey) {
|
||||
window.keyv.set(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
window.keyv.set(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
}
|
||||
@@ -306,25 +306,6 @@ export class AiSdkToChunkAdapter {
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
// case 'start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.BLOCK_CREATED
|
||||
// })
|
||||
// break
|
||||
// case 'step-finish':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.TEXT_COMPLETE,
|
||||
// text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
// })
|
||||
// final.text = ''
|
||||
// break
|
||||
|
||||
case 'finish-step': {
|
||||
const { providerMetadata, finishReason } = chunk
|
||||
// googel web search
|
||||
|
||||
@@ -1,16 +1 @@
|
||||
/**
|
||||
* Cherry Studio AI Core - 统一入口点
|
||||
*
|
||||
* 这是新的统一入口,保持向后兼容性
|
||||
* 默认导出legacy AiProvider以保持现有代码的兼容性
|
||||
*/
|
||||
|
||||
// 导出Legacy AiProvider作为默认导出(保持向后兼容)
|
||||
export { default } from './legacy/index'
|
||||
|
||||
// 同时导出Modern AiProvider供新代码使用
|
||||
export { default as ModernAiProvider } from './index_new'
|
||||
|
||||
// 导出一些常用的类型和工具
|
||||
export * from './legacy/clients/types'
|
||||
export * from './legacy/middleware/schemas'
|
||||
export { default as AiProvider, type AiProviderConfig } from './AiProvider'
|
||||
|
||||
@@ -1,609 +0,0 @@
|
||||
/**
|
||||
* Cherry Studio AI Core - 新版本入口
|
||||
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
|
||||
*
|
||||
* 融合方案:简化实现,专注于核心功能
|
||||
* 1. 优先使用新AI SDK
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import type { AiSdkModel } from '@cherrystudio/ai-core'
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { SUPPORTED_IMAGE_ENDPOINT_LIST } from '@renderer/utils'
|
||||
import type { IdleTimeoutHandle } from '@renderer/utils/IdleTimeoutController'
|
||||
import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic'
|
||||
import { gateway, type LanguageModel, type Provider as AiSdkProvider } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import type { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
import {
|
||||
adaptProvider,
|
||||
getActualProvider,
|
||||
isModernSdkSupported,
|
||||
prepareSpecialProviderConfig,
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
import type { AiSdkConfig } from './types'
|
||||
import type { AiSdkMiddlewareConfig } from './types/middlewareConfig'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
idleTimeout?: IdleTimeoutHandle
|
||||
}
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private config?: AiSdkConfig
|
||||
private actualProvider: Provider
|
||||
private model?: Model
|
||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||
|
||||
/**
|
||||
* Constructor for ModernAiProvider
|
||||
*
|
||||
* @param modelOrProvider - Model or Provider object
|
||||
* @param provider - Optional Provider object (only used when first param is Model)
|
||||
*
|
||||
* @remarks
|
||||
* **Important behavior notes**:
|
||||
*
|
||||
* 1. When called with `(model)`:
|
||||
* - Calls `getActualProvider(model)` to retrieve and format the provider
|
||||
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
|
||||
*
|
||||
* 2. When called with `(model, provider)`:
|
||||
* - The provided provider will be adapted via `adaptProvider`
|
||||
* - URL formatting behavior depends on the adapted result
|
||||
*
|
||||
* 3. When called with `(provider)`:
|
||||
* - The provider will be adapted via `adaptProvider`
|
||||
* - Used for operations that don't need a model (e.g., fetchModels)
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // Recommended: Auto-format URL
|
||||
* const ai = new ModernAiProvider(model)
|
||||
*
|
||||
* // Provider will be adapted
|
||||
* const ai = new ModernAiProvider(model, customProvider)
|
||||
*
|
||||
* // For operations that don't need a model
|
||||
* const ai = new ModernAiProvider(provider)
|
||||
* ```
|
||||
*/
|
||||
constructor(model: Model, provider?: Provider)
|
||||
constructor(provider: Provider)
|
||||
constructor(modelOrProvider: Model | Provider, provider?: Provider)
|
||||
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
|
||||
if (this.isModel(modelOrProvider)) {
|
||||
// 传入的是 Model
|
||||
this.model = modelOrProvider
|
||||
this.actualProvider = provider
|
||||
? adaptProvider({ provider, model: modelOrProvider })
|
||||
: getActualProvider(modelOrProvider)
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
|
||||
} else {
|
||||
// 传入的是 Provider
|
||||
this.actualProvider = adaptProvider({ provider: modelOrProvider })
|
||||
// model为可选,某些操作(如fetchModels)不需要model
|
||||
}
|
||||
|
||||
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
|
||||
*/
|
||||
private isModel(obj: Model | Provider): obj is Model {
|
||||
return 'provider' in obj && typeof obj.provider === 'string'
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
return this.actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* Note: This method routes text completions through `modernCompletions`,
|
||||
* which only calls `streamText` (no `generateText` path).
|
||||
*/
|
||||
public async completions(modelId: string, params: StreamTextParams, providerConfig: ModernAiProviderConfig) {
|
||||
// 检查model是否存在
|
||||
if (!this.model) {
|
||||
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||
}
|
||||
|
||||
// Config is now set in constructor, ApiService handles key rotation before passing provider
|
||||
if (!this.config) {
|
||||
// If config wasn't set in constructor (when provider only), generate it now
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||
}
|
||||
logger.debug('Using provider config for completions', this.config)
|
||||
|
||||
// 检查 config 是否存在
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with completions')
|
||||
}
|
||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
||||
providerConfig.isImageGenerationEndpoint = true
|
||||
}
|
||||
// 准备特殊配置
|
||||
await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||
|
||||
// 提前创建本地 provider 实例
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
}
|
||||
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
|
||||
// 根据endpoint类型创建对应的模型
|
||||
let model: AiSdkModel | undefined
|
||||
if (providerConfig.isImageGenerationEndpoint) {
|
||||
model = this.localProvider.imageModel(modelId)
|
||||
} else {
|
||||
model = this.localProvider.languageModel(modelId)
|
||||
}
|
||||
|
||||
if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') {
|
||||
// 类型守卫:确保 system 是 string、Array 或 undefined
|
||||
const system = params.system
|
||||
let systemParam: string | Array<any> | undefined
|
||||
if (typeof system === 'string' || Array.isArray(system) || system === undefined) {
|
||||
systemParam = system
|
||||
} else {
|
||||
// SystemModelMessage 类型,转换为 string
|
||||
systemParam = undefined
|
||||
}
|
||||
|
||||
const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(systemParam)
|
||||
params.system = undefined // 清除原有system,避免重复
|
||||
params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])]
|
||||
}
|
||||
|
||||
if (providerConfig.topicId && getEnableDeveloperMode()) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
...providerConfig,
|
||||
topicId: providerConfig.topicId
|
||||
}
|
||||
return await this._completionsForTrace(model, params, traceConfig)
|
||||
} else {
|
||||
return await this._completionsOrImageGeneration(model, params, providerConfig)
|
||||
}
|
||||
}
|
||||
|
||||
private async _completionsOrImageGeneration(
|
||||
model: AiSdkModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
// ai-gateway不是image/generation 端点,所以就先不走legacy了
|
||||
if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
|
||||
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
|
||||
if (!config.uiMessages) {
|
||||
throw new Error('uiMessages is required for image generation endpoint')
|
||||
}
|
||||
|
||||
const legacyParams: CompletionsParams = {
|
||||
callType: 'chat',
|
||||
messages: config.uiMessages, // 使用原始的 UI 消息格式
|
||||
assistant: config.assistant,
|
||||
streamOutput: config.streamOutput ?? true,
|
||||
onChunk: config.onChunk,
|
||||
topicId: config.topicId,
|
||||
mcpTools: config.mcpTools,
|
||||
enableWebSearch: config.enableWebSearch
|
||||
}
|
||||
|
||||
// 调用 legacy 的 completions,会自动使用 ImageGenerationMiddleware
|
||||
return await this.legacyProvider.completions(legacyParams)
|
||||
}
|
||||
|
||||
return await this.modernCompletions(model as LanguageModel, params, config)
|
||||
}
|
||||
|
||||
/**
|
||||
* 带trace支持的completions方法
|
||||
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||
*/
|
||||
private async _completionsForTrace(
|
||||
model: AiSdkModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig & { topicId: string }
|
||||
): Promise<CompletionsResult> {
|
||||
const modelId = this.model!.id
|
||||
const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: config.topicId,
|
||||
modelName: config.assistant.model?.name, // 使用modelId而不是provider名称
|
||||
inputs: params
|
||||
}
|
||||
|
||||
logger.info('Starting AI SDK trace span', {
|
||||
traceName,
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolNames: params.tools ? Object.keys(params.tools) : [],
|
||||
isImageGeneration: config.isImageGenerationEndpoint
|
||||
})
|
||||
|
||||
const span = addSpan(traceParams)
|
||||
if (!span) {
|
||||
logger.warn('Failed to create span, falling back to regular completions', {
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
traceName
|
||||
})
|
||||
return await this._completionsOrImageGeneration(model, params, config)
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info('Created parent span, now calling completions', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
parentSpanCreated: true
|
||||
})
|
||||
|
||||
const result = await this._completionsOrImageGeneration(model, params, config)
|
||||
|
||||
logger.info('Completions finished, ending parent span', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
resultLength: result.getText().length
|
||||
})
|
||||
|
||||
// 标记span完成
|
||||
endSpan({
|
||||
topicId: config.topicId,
|
||||
outputs: result,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: config.topicId,
|
||||
modelId
|
||||
})
|
||||
|
||||
// 标记span出错
|
||||
endSpan({
|
||||
topicId: config.topicId,
|
||||
error: error as Error,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
*/
|
||||
/**
|
||||
* Note: This implementation always uses `executor.streamText` and never
|
||||
* calls `generateText`, even when `onChunk` is not provided.
|
||||
*/
|
||||
private async modernCompletions(
|
||||
model: LanguageModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
// 根据条件构建插件数组
|
||||
const plugins = buildPlugins({
|
||||
provider: this.actualProvider,
|
||||
model: this.model!,
|
||||
config
|
||||
})
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, plugins)
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (config.onChunk) {
|
||||
const accumulate = this.model!.supported_text_delta !== false // true and undefined
|
||||
const adapter = new AiSdkToChunkAdapter(
|
||||
config.onChunk,
|
||||
config.mcpTools,
|
||||
accumulate,
|
||||
config.enableWebSearch,
|
||||
undefined,
|
||||
undefined,
|
||||
this.config!.providerId,
|
||||
config.idleTimeout
|
||||
)
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model,
|
||||
experimental_context: { onChunk: config.onChunk }
|
||||
})
|
||||
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// Since no onChunk is provided, the external consumer would not handle error chunk.
|
||||
// So we need to capture the actual stream error so we can throw it instead of the
|
||||
// generic NoTextGeneratedError ("No output generated. Check the stream
|
||||
// for errors.") that AI SDK raises when streamResult.text is accessed
|
||||
// after a failed stream.
|
||||
let streamError: unknown = undefined
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model,
|
||||
onError({ error }) {
|
||||
streamError = error
|
||||
}
|
||||
})
|
||||
|
||||
// 强制消费流,不然await streamResult.text会阻塞
|
||||
await streamResult?.consumeStream({
|
||||
onError(error) {
|
||||
if (!streamError) {
|
||||
streamError = error
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
try {
|
||||
const finalText = await streamResult.text
|
||||
const usage = await streamResult.totalUsage
|
||||
|
||||
return {
|
||||
getText: () => finalText,
|
||||
usage
|
||||
}
|
||||
} catch (error) {
|
||||
// If we captured the real stream error, throw that instead of the
|
||||
// generic NoTextGeneratedError so callers get actionable diagnostics.
|
||||
if (streamError) {
|
||||
throw streamError
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// /**
|
||||
// * 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||
// * @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能
|
||||
// */
|
||||
/*
|
||||
private async modernImageGeneration(
|
||||
model: ImageModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const { onChunk } = config
|
||||
|
||||
try {
|
||||
// 检查 messages 是否存在
|
||||
if (!params.messages || params.messages.length === 0) {
|
||||
throw new Error('No messages provided for image generation.')
|
||||
}
|
||||
|
||||
// 从最后一条用户消息中提取 prompt
|
||||
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
// 直接使用消息内容,避免类型转换问题
|
||||
const prompt =
|
||||
typeof lastUserMessage.content === 'string'
|
||||
? lastUserMessage.content
|
||||
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('No prompt found in user message.')
|
||||
}
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 发送图像生成开始事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||
}
|
||||
|
||||
// 构建图像生成参数
|
||||
const imageParams = {
|
||||
prompt,
|
||||
size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
|
||||
n: 1,
|
||||
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||
}
|
||||
|
||||
// 调用新 AI SDK 的图像生成功能
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const result = await executor.generateImage({
|
||||
model,
|
||||
...imageParams
|
||||
})
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
const imageType: 'url' | 'base64' = 'base64'
|
||||
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:${image.mediaType};base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 发送图像生成完成事件
|
||||
if (onChunk && images.length > 0) {
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: imageType, images }
|
||||
})
|
||||
}
|
||||
|
||||
// 发送块完成事件(类似于 modernCompletions 的处理)
|
||||
if (onChunk) {
|
||||
const usage = {
|
||||
prompt_tokens: prompt.length, // 估算的 token 数量
|
||||
completion_tokens: 0, // 图像生成没有 completion tokens
|
||||
total_tokens: prompt.length
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 发送 LLM 响应完成事件
|
||||
onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => '' // 图像生成不返回文本
|
||||
}
|
||||
} catch (error) {
|
||||
// 发送错误事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.ERROR, error: error as any })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
if (this.actualProvider.id === SystemProviderIds.gateway) {
|
||||
const gatewayModels = (await gateway.getAvailableModels()).models
|
||||
return normalizeGatewayModels(this.actualProvider, gatewayModels)
|
||||
}
|
||||
const sdkModels = await this.legacyProvider.models()
|
||||
return normalizeSdkModels(this.actualProvider, sdkModels)
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
return this.legacyProvider.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
// 确保 config 已定义
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with generateImage')
|
||||
}
|
||||
|
||||
// 确保本地provider已创建
|
||||
if (!this.localProvider && this.config) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
}
|
||||
|
||||
const result = await this.modernGenerateImage(params)
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
|
||||
// fallback 到传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
}
|
||||
|
||||
// 直接使用传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const result = await executor.generateImage({
|
||||
model: model, // 直接使用 model ID 字符串,由 executor 内部解析
|
||||
...aiSdkParams
|
||||
})
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return images
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.legacyProvider.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.legacyProvider.getApiKey()
|
||||
}
|
||||
}
|
||||
|
||||
// 为了方便调试,导出一些工具函数
|
||||
export { isModernSdkSupported, providerToAiSdkConfig }
|
||||
@@ -1,110 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { isNewApiProvider } from '@renderer/utils/provider'
|
||||
|
||||
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { AwsBedrockAPIClient } from './aws/AwsBedrockAPIClient'
|
||||
import type { BaseApiClient } from './BaseApiClient'
|
||||
import { CherryAiAPIClient } from './cherryai/CherryAiAPIClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from './gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './newapi/NewAPIClient'
|
||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||
import { OVMSClient } from './ovms/OVMSClient'
|
||||
import { PoeAPIClient } from './poe/PoeAPIClient'
|
||||
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
|
||||
import { ZhipuAPIClient } from './zhipu/ZhipuAPIClient'
|
||||
|
||||
const logger = loggerService.withContext('ApiClientFactory')
|
||||
|
||||
/**
|
||||
* Factory for creating ApiClient instances based on provider configuration
|
||||
* 根据提供者配置创建ApiClient实例的工厂
|
||||
*/
|
||||
export class ApiClientFactory {
|
||||
/**
|
||||
* Create an ApiClient instance for the given provider
|
||||
* 为给定的提供者创建ApiClient实例
|
||||
*/
|
||||
static create(provider: Provider): BaseApiClient {
|
||||
logger.debug(`Creating ApiClient for provider:`, {
|
||||
id: provider.id,
|
||||
type: provider.type
|
||||
})
|
||||
|
||||
let instance: BaseApiClient
|
||||
|
||||
// 首先检查特殊的 Provider ID
|
||||
if (provider.id === 'cherryai') {
|
||||
instance = new CherryAiAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
logger.debug(`Creating AihubmixAPIClient for provider: ${provider.id}`)
|
||||
instance = new AihubmixAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
if (isNewApiProvider(provider)) {
|
||||
logger.debug(`Creating NewAPIClient for provider: ${provider.id}`)
|
||||
instance = new NewAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
if (provider.id === 'ppio') {
|
||||
logger.debug(`Creating PPIOAPIClient for provider: ${provider.id}`)
|
||||
instance = new PPIOAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
if (provider.id === 'zhipu') {
|
||||
instance = new ZhipuAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
if (provider.id === 'ovms') {
|
||||
logger.debug(`Creating OVMSClient for provider: ${provider.id}`)
|
||||
instance = new OVMSClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
if (provider.id === 'poe') {
|
||||
logger.debug(`Creating PoeAPIClient for provider: ${provider.id}`)
|
||||
instance = new PoeAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
// 然后检查标准的 Provider Type
|
||||
switch (provider.type) {
|
||||
case 'openai':
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'azure-openai':
|
||||
case 'openai-response':
|
||||
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'gemini':
|
||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'vertexai':
|
||||
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
|
||||
instance = new VertexAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'anthropic':
|
||||
instance = new AnthropicAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'aws-bedrock':
|
||||
instance = new AwsBedrockAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
default:
|
||||
logger.debug(`Using default OpenAIApiClient for provider: ${provider.id}`)
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
}
|
||||
|
||||
return instance
|
||||
}
|
||||
}
|
||||
@@ -1,489 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
getModelSupportedVerbosity,
|
||||
isFunctionCallingModel,
|
||||
isOpenAIModel,
|
||||
isSupportFlexServiceTierModel,
|
||||
isSupportTemperatureModel,
|
||||
isSupportTopPModel
|
||||
} from '@renderer/config/models'
|
||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import type { RootState } from '@renderer/store'
|
||||
import type {
|
||||
Assistant,
|
||||
GenerateImageParams,
|
||||
KnowledgeReference,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
MemoryItem,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchProviderResponse,
|
||||
WebSearchResponse
|
||||
} from '@renderer/types'
|
||||
import {
|
||||
FILE_TYPE,
|
||||
GroqServiceTiers,
|
||||
isGroqServiceTier,
|
||||
isOpenAIServiceTier,
|
||||
OpenAIServiceTiers,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import type {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkModel,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
|
||||
import { DEFAULT_TIMEOUT } from '@shared/config/constant'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import type { CompletionsContext } from '../middleware/types'
|
||||
import type { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
const logger = loggerService.withContext('BaseApiClient')
|
||||
|
||||
/**
|
||||
* Abstract base class for API clients.
|
||||
* Provides common functionality and structure for specific client implementations.
|
||||
*/
|
||||
export abstract class BaseApiClient<
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
|
||||
{
|
||||
public provider: Provider
|
||||
protected host: string
|
||||
protected sdkInstance?: TSdkInstance
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
this.host = this.getBaseURL()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current API key with rotation support
|
||||
* This getter ensures API keys rotate on each access when multiple keys are configured
|
||||
*/
|
||||
protected get apiKey(): string {
|
||||
return this.getApiKey()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取客户端的兼容性类型
|
||||
* 用于判断客户端是否支持特定功能,避免instanceof检查的类型收窄问题
|
||||
* 对于装饰器模式的客户端(如AihubmixAPIClient),应该返回其内部实际使用的客户端类型
|
||||
*/
|
||||
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
public getClientCompatibilityType(_model?: Model): string[] {
|
||||
// 默认返回类的名称
|
||||
return [this.constructor.name]
|
||||
}
|
||||
|
||||
// // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符
|
||||
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
|
||||
|
||||
/**
|
||||
* 核心API Endpoint
|
||||
**/
|
||||
|
||||
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
|
||||
|
||||
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
|
||||
|
||||
abstract getEmbeddingDimensions(model?: Model): Promise<number>
|
||||
|
||||
abstract listModels(): Promise<SdkModel[]>
|
||||
|
||||
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||
|
||||
/**
|
||||
* 中间件
|
||||
**/
|
||||
|
||||
// 在 CoreRequestToSdkParamsMiddleware中使用
|
||||
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
/**
|
||||
* 工具转换
|
||||
**/
|
||||
|
||||
// Optional tool conversion methods - implement if needed by the specific provider
|
||||
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
|
||||
|
||||
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
|
||||
|
||||
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
|
||||
|
||||
abstract buildSdkMessages(
|
||||
currentReqMessages: TMessageParam[],
|
||||
output: TRawOutput | string | undefined,
|
||||
toolResults: TMessageParam[],
|
||||
toolCalls?: TToolCall[]
|
||||
): TMessageParam[]
|
||||
|
||||
abstract estimateMessageTokens(message: TMessageParam): number
|
||||
|
||||
abstract convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): TMessageParam | undefined
|
||||
|
||||
/**
|
||||
* 从SDK载荷中提取消息数组(用于中间件中的类型安全访问)
|
||||
* 不同的提供商可能使用不同的字段名(如messages、history等)
|
||||
*/
|
||||
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||
|
||||
/**
|
||||
* 通用函数
|
||||
**/
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
public getApiKey() {
|
||||
const keys = this.provider.apiKey.split(',').map((key) => key.trim())
|
||||
const keyName = `provider:${this.provider.id}:last_used_key`
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const lastUsedKey = window.keyv.get(keyName)
|
||||
if (!lastUsedKey) {
|
||||
window.keyv.set(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
window.keyv.set(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
|
||||
public defaultHeaders() {
|
||||
return {
|
||||
...defaultAppHeaders(),
|
||||
'X-Api-Key': this.apiKey
|
||||
}
|
||||
}
|
||||
|
||||
public get keepAliveTime() {
|
||||
return this.provider.id === 'lmstudio' ? getLMStudioKeepAliveTime() : undefined
|
||||
}
|
||||
|
||||
public getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (!isSupportTemperatureModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const assistantSettings = getAssistantSettings(assistant)
|
||||
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
|
||||
}
|
||||
|
||||
public getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (!isSupportTopPModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const assistantSettings = getAssistantSettings(assistant)
|
||||
return assistantSettings?.enableTopP ? assistantSettings?.topP : undefined
|
||||
}
|
||||
|
||||
// NOTE: 这个也许可以迁移到OpenAIBaseClient
|
||||
protected getServiceTier(model: Model) {
|
||||
const serviceTierSetting = this.provider.serviceTier
|
||||
|
||||
if (!isSupportServiceTierProvider(this.provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 处理不同供应商需要 fallback 到默认值的情况
|
||||
if (this.provider.id === SystemProviderIds.groq) {
|
||||
if (
|
||||
!isGroqServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
} else {
|
||||
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
||||
if (
|
||||
!isOpenAIServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
return serviceTierSetting
|
||||
}
|
||||
|
||||
protected getVerbosity(model?: Model): OpenAIVerbosity {
|
||||
try {
|
||||
const state = window.store?.getState() as RootState
|
||||
const verbosity = state?.settings?.openAI?.verbosity
|
||||
|
||||
// If model is provided, check if the verbosity is supported by the model
|
||||
if (model) {
|
||||
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||
// Use user's verbosity if supported, otherwise use the first supported option
|
||||
return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
|
||||
}
|
||||
return verbosity
|
||||
} catch (error) {
|
||||
logger.warn('Failed to get verbosity from state. Fallback to undefined.', error as Error)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
protected getTimeout(model: Model) {
|
||||
if (isSupportFlexServiceTierModel(model)) {
|
||||
return 15 * 1000 * 60
|
||||
}
|
||||
return DEFAULT_TIMEOUT
|
||||
}
|
||||
|
||||
public async getMessageContent(
|
||||
message: Message
|
||||
): Promise<{ textContent: string; imageContents: { fileId: string; fileExt: string }[] }> {
|
||||
const content = getMainTextContent(message)
|
||||
|
||||
if (isEmpty(content)) {
|
||||
return {
|
||||
textContent: '',
|
||||
imageContents: []
|
||||
}
|
||||
}
|
||||
|
||||
const webSearchReferences = await this.getWebSearchReferencesFromCache(message)
|
||||
const knowledgeReferences = await this.getKnowledgeBaseReferencesFromCache(message)
|
||||
const memoryReferences = this.getMemoryReferencesFromCache(message)
|
||||
|
||||
const knowledgeTextReferences = knowledgeReferences.filter((k) => k.metadata?.type !== 'image')
|
||||
const knowledgeImageReferences = knowledgeReferences.filter((k) => k.metadata?.type === 'image')
|
||||
|
||||
// 添加偏移量以避免ID冲突
|
||||
const reindexedKnowledgeReferences = knowledgeTextReferences.map((ref) => ({
|
||||
...ref,
|
||||
id: ref.id + webSearchReferences.length // 为知识库引用的ID添加网络搜索引用的数量作为偏移量
|
||||
}))
|
||||
|
||||
const allReferences = [...webSearchReferences, ...reindexedKnowledgeReferences, ...memoryReferences]
|
||||
|
||||
logger.debug(`Found ${allReferences.length} references for ID: ${message.id}`, allReferences)
|
||||
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(allReferences, null, 2)}\n\`\`\``
|
||||
const imageReferences = knowledgeImageReferences.map((r) => {
|
||||
return { fileId: r.metadata?.id, fileExt: r.metadata?.ext }
|
||||
})
|
||||
|
||||
return {
|
||||
textContent: isEmpty(allReferences)
|
||||
? content
|
||||
: REFERENCE_PROMPT.replace('{question}', content).replace('{references}', referenceContent),
|
||||
imageContents: isEmpty(knowledgeImageReferences) ? [] : imageReferences
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the file content from the message
|
||||
* @param message - The message
|
||||
* @returns The file content
|
||||
*/
|
||||
protected async extractFileContent(message: Message) {
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
if (fileBlocks.length > 0) {
|
||||
const textFileBlocks = fileBlocks.filter(
|
||||
(fb) => fb.file && [FILE_TYPE.TEXT, FILE_TYPE.DOCUMENT].some((type) => fb.file.type === type)
|
||||
)
|
||||
|
||||
if (textFileBlocks.length > 0) {
|
||||
let text = ''
|
||||
const divider = '\n\n---\n\n'
|
||||
|
||||
for (const fileBlock of textFileBlocks) {
|
||||
const file = fileBlock.file
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext, true)).trim()
|
||||
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
|
||||
text = text + fileNameRow + fileContent + divider
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
private getMemoryReferencesFromCache(message: Message) {
|
||||
const memories = window.keyv.get(`memory-search-${message.id}`) as MemoryItem[] | undefined
|
||||
if (memories) {
|
||||
const memoryReferences: KnowledgeReference[] = memories.map((mem, index) => ({
|
||||
id: index + 1,
|
||||
content: `${mem.memory} -- Created at: ${mem.createdAt}`,
|
||||
sourceUrl: '',
|
||||
type: 'memory'
|
||||
}))
|
||||
return memoryReferences
|
||||
}
|
||||
return []
|
||||
}
|
||||
|
||||
private async getWebSearchReferencesFromCache(message: Message) {
|
||||
const content = getMainTextContent(message)
|
||||
if (isEmpty(content)) {
|
||||
return []
|
||||
}
|
||||
const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`)
|
||||
|
||||
if (webSearch) {
|
||||
window.keyv.remove(`web-search-${message.id}`)
|
||||
return (webSearch.results as WebSearchProviderResponse).results.map(
|
||||
(result, index) =>
|
||||
({
|
||||
id: index + 1,
|
||||
content: result.content,
|
||||
sourceUrl: result.url,
|
||||
type: 'url'
|
||||
}) as KnowledgeReference
|
||||
)
|
||||
}
|
||||
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* 从缓存中获取知识库引用
|
||||
*/
|
||||
private async getKnowledgeBaseReferencesFromCache(message: Message): Promise<KnowledgeReference[]> {
|
||||
const content = getMainTextContent(message)
|
||||
if (isEmpty(content)) {
|
||||
return []
|
||||
}
|
||||
const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`)
|
||||
|
||||
if (!isEmpty(knowledgeReferences)) {
|
||||
window.keyv.remove(`knowledge-search-${message.id}`)
|
||||
logger.debug(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`)
|
||||
return knowledgeReferences
|
||||
}
|
||||
logger.debug(`No knowledge base references found in cache for ID: ${message.id}`)
|
||||
return []
|
||||
}
|
||||
|
||||
protected getCustomParameters(assistant: Assistant) {
|
||||
return (
|
||||
assistant?.settings?.customParameters?.reduce((acc, param) => {
|
||||
if (!param.name?.trim()) {
|
||||
return acc
|
||||
}
|
||||
// Parse JSON type parameters (Legacy API clients)
|
||||
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
|
||||
// The UI stores JSON type params as strings, this function parses them before sending to API
|
||||
if (param.type === 'json') {
|
||||
const value = param.value as string
|
||||
if (value === 'undefined') {
|
||||
return { ...acc, [param.name]: undefined }
|
||||
}
|
||||
return { ...acc, [param.name]: isJSON(value) ? parseJSON(value) : value }
|
||||
}
|
||||
return {
|
||||
...acc,
|
||||
[param.name]: param.value
|
||||
}
|
||||
}, {}) || {}
|
||||
)
|
||||
}
|
||||
|
||||
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||
const abortController = new AbortController()
|
||||
const abortFn = () => abortController.abort()
|
||||
|
||||
if (messageId) {
|
||||
addAbortController(messageId, abortFn)
|
||||
}
|
||||
|
||||
const cleanup = () => {
|
||||
if (messageId) {
|
||||
signalPromise.resolve?.(undefined)
|
||||
removeAbortController(messageId, abortFn)
|
||||
}
|
||||
}
|
||||
const signalPromise: {
|
||||
resolve: (value: unknown) => void
|
||||
promise: Promise<unknown>
|
||||
} = {
|
||||
resolve: () => {},
|
||||
promise: Promise.resolve()
|
||||
}
|
||||
|
||||
if (isAddEventListener) {
|
||||
signalPromise.promise = new Promise((resolve, reject) => {
|
||||
signalPromise.resolve = resolve
|
||||
if (abortController.signal.aborted) {
|
||||
reject(new Error('Request was aborted.'))
|
||||
}
|
||||
// 捕获abort事件,有些abort事件必须
|
||||
abortController.signal.addEventListener('abort', () => {
|
||||
reject(new Error('Request was aborted.'))
|
||||
})
|
||||
})
|
||||
return {
|
||||
abortController,
|
||||
cleanup,
|
||||
signalPromise
|
||||
}
|
||||
}
|
||||
return {
|
||||
abortController,
|
||||
cleanup
|
||||
}
|
||||
}
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: TSdkSpecificTool[]
|
||||
} {
|
||||
const { mcpTools, model, enableToolUse } = params
|
||||
let tools: TSdkSpecificTool[] = []
|
||||
|
||||
// If there are no tools, return an empty array
|
||||
if (!mcpTools?.length) {
|
||||
return { tools }
|
||||
}
|
||||
|
||||
// If the model supports function calling and tool usage is enabled
|
||||
if (isFunctionCallingModel(model) && enableToolUse) {
|
||||
tools = this.convertMcpToolsToSdkTools(mcpTools)
|
||||
}
|
||||
|
||||
return { tools }
|
||||
}
|
||||
}
|
||||
@@ -1,181 +0,0 @@
|
||||
import type {
|
||||
GenerateImageParams,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse
|
||||
} from '@renderer/types'
|
||||
import type {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkModel,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import type { CompletionsContext } from '../middleware/types'
|
||||
import type { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import type { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import type { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||
import type { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||
import type { RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
/**
|
||||
* MixedAPIClient - 适用于可能含有多种接口类型的Provider
|
||||
*/
|
||||
export abstract class MixedBaseAPIClient extends BaseApiClient {
|
||||
// 使用联合类型而不是any,保持类型安全
|
||||
protected abstract clients: Map<
|
||||
string,
|
||||
AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient
|
||||
>
|
||||
protected abstract defaultClient: OpenAIAPIClient
|
||||
protected abstract currentClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override getBaseURL(): string {
|
||||
if (!this.currentClient) {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
return this.currentClient.getBaseURL()
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:确保client是BaseApiClient的实例
|
||||
*/
|
||||
protected isValidClient(client: unknown): client is BaseApiClient {
|
||||
return (
|
||||
client !== null &&
|
||||
client !== undefined &&
|
||||
typeof client === 'object' &&
|
||||
'createCompletions' in client &&
|
||||
'getRequestTransformer' in client &&
|
||||
'getResponseChunkTransformer' in client
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型获取合适的client
|
||||
*/
|
||||
protected abstract getClient(model: Model): BaseApiClient
|
||||
|
||||
/**
|
||||
* 根据模型选择合适的client并委托调用
|
||||
*/
|
||||
public getClientForModel(model: Model): BaseApiClient {
|
||||
this.currentClient = this.getClient(model)
|
||||
return this.currentClient
|
||||
}
|
||||
|
||||
/**
|
||||
* 重写基类方法,返回内部实际使用的客户端类型
|
||||
*/
|
||||
public override getClientCompatibilityType(model?: Model): string[] {
|
||||
if (!model) {
|
||||
return [this.constructor.name]
|
||||
}
|
||||
|
||||
const actualClient = this.getClient(model)
|
||||
return actualClient.getClientCompatibilityType(model)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从SDK payload中提取模型ID
|
||||
*/
|
||||
protected extractModelFromPayload(payload: SdkParams): string | null {
|
||||
// 不同的SDK可能有不同的字段名
|
||||
if ('model' in payload && typeof payload.model === 'string') {
|
||||
return payload.model
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
// ============ BaseApiClient 的抽象方法 ============
|
||||
|
||||
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
|
||||
// 尝试从payload中提取模型信息来选择client
|
||||
const modelId = this.extractModelFromPayload(payload)
|
||||
if (modelId) {
|
||||
const modelObj = { id: modelId } as Model
|
||||
const targetClient = this.getClient(modelObj)
|
||||
return targetClient.createCompletions(payload, options)
|
||||
}
|
||||
|
||||
// 如果无法从payload中提取模型,使用当前设置的client
|
||||
return this.currentClient.createCompletions(payload, options)
|
||||
}
|
||||
|
||||
async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.currentClient.generateImage(params)
|
||||
}
|
||||
|
||||
async getEmbeddingDimensions(model?: Model): Promise<number> {
|
||||
const client = model ? this.getClient(model) : this.currentClient
|
||||
return client.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
async listModels(): Promise<SdkModel[]> {
|
||||
// 可以聚合所有client的模型,或者使用默认client
|
||||
return this.defaultClient.listModels()
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<SdkInstance> {
|
||||
return this.currentClient.getSdkInstance()
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
|
||||
return this.currentClient.getRequestTransformer()
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
|
||||
return this.currentClient.getResponseChunkTransformer(ctx)
|
||||
}
|
||||
|
||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
||||
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
}
|
||||
|
||||
buildSdkMessages(
|
||||
currentReqMessages: SdkMessageParam[],
|
||||
output: SdkRawOutput | string,
|
||||
toolResults: SdkMessageParam[],
|
||||
toolCalls?: SdkToolCall[]
|
||||
): SdkMessageParam[] {
|
||||
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||
}
|
||||
|
||||
estimateMessageTokens(message: SdkMessageParam): number {
|
||||
return this.currentClient.estimateMessageTokens(message)
|
||||
}
|
||||
|
||||
convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): SdkMessageParam | undefined {
|
||||
const client = this.getClient(model)
|
||||
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
}
|
||||
|
||||
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
|
||||
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
|
||||
}
|
||||
}
|
||||
@@ -1,220 +0,0 @@
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { AihubmixAPIClient } from '../aihubmix/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '../ApiClientFactory'
|
||||
import { AwsBedrockAPIClient } from '../aws/AwsBedrockAPIClient'
|
||||
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '../gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '../newapi/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
|
||||
import { PPIOAPIClient } from '../ppio/PPIOAPIClient'
|
||||
|
||||
// 为工厂测试创建最小化 provider 的辅助函数
|
||||
// ApiClientFactory 只使用 'id' 和 'type' 字段来决定创建哪个客户端
|
||||
// 其他字段会传递给客户端构造函数,但不影响工厂逻辑
|
||||
const createTestProvider = (id: string, type: string): Provider => ({
|
||||
id,
|
||||
type: type as Provider['type'],
|
||||
name: '',
|
||||
apiKey: '',
|
||||
apiHost: '',
|
||||
models: []
|
||||
})
|
||||
|
||||
// Mock 所有客户端模块
|
||||
vi.mock('../aihubmix/AihubmixAPIClient', () => ({
|
||||
AihubmixAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../anthropic/AnthropicAPIClient', () => ({
|
||||
AnthropicAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../anthropic/AnthropicVertexClient', () => ({
|
||||
AnthropicVertexClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../gemini/GeminiAPIClient', () => ({
|
||||
GeminiAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../gemini/VertexAPIClient', () => ({
|
||||
VertexAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../newapi/NewAPIClient', () => ({
|
||||
NewAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../openai/OpenAIApiClient', () => ({
|
||||
OpenAIAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../openai/OpenAIResponseAPIClient', () => ({
|
||||
OpenAIResponseAPIClient: vi.fn().mockImplementation(() => ({
|
||||
getClient: vi.fn().mockReturnThis()
|
||||
}))
|
||||
}))
|
||||
vi.mock('../ppio/PPIOAPIClient', () => ({
|
||||
PPIOAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../aws/AwsBedrockAPIClient', () => ({
|
||||
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService.ts', () => ({
|
||||
getDefaultAssistant: () => {
|
||||
return {
|
||||
id: 'default',
|
||||
name: 'default',
|
||||
emoji: '😀',
|
||||
prompt: '',
|
||||
topics: [],
|
||||
messages: [],
|
||||
type: 'assistant',
|
||||
regularPhrases: [],
|
||||
settings: {}
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the models config to prevent circular dependency issues
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
findTokenLimit: vi.fn(),
|
||||
isReasoningModel: vi.fn(),
|
||||
isOpenAILLMModel: vi.fn(),
|
||||
SYSTEM_MODELS: {
|
||||
silicon: [],
|
||||
defaultModel: []
|
||||
},
|
||||
isOpenAIModel: vi.fn(() => false),
|
||||
qwenModel: {}
|
||||
}))
|
||||
|
||||
describe('ApiClientFactory', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('create', () => {
|
||||
// 测试特殊 ID 的客户端创建
|
||||
it('should create AihubmixAPIClient for aihubmix provider', () => {
|
||||
const provider = createTestProvider('aihubmix', 'openai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create NewAPIClient for new-api provider', () => {
|
||||
const provider = createTestProvider('new-api', 'openai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(NewAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create PPIOAPIClient for ppio provider', () => {
|
||||
const provider = createTestProvider('ppio', 'openai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(PPIOAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
// 测试标准类型的客户端创建
|
||||
it('should create OpenAIAPIClient for openai type', () => {
|
||||
const provider = createTestProvider('custom-openai', 'openai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create OpenAIResponseAPIClient for azure-openai type', () => {
|
||||
const provider = createTestProvider('azure-openai', 'azure-openai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create OpenAIResponseAPIClient for openai-response type', () => {
|
||||
const provider = createTestProvider('response', 'openai-response')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create GeminiAPIClient for gemini type', () => {
|
||||
const provider = createTestProvider('gemini', 'gemini')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(GeminiAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create VertexAPIClient for vertexai type', () => {
|
||||
const provider = createTestProvider('vertex', 'vertexai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(VertexAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create AnthropicAPIClient for anthropic type', () => {
|
||||
const provider = createTestProvider('anthropic', 'anthropic')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(AnthropicAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
it('should create AwsBedrockAPIClient for aws-bedrock type', () => {
|
||||
const provider = createTestProvider('aws-bedrock', 'aws-bedrock')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(AwsBedrockAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
// 测试默认情况
|
||||
it('should create OpenAIAPIClient as default for unknown type', () => {
|
||||
const provider = createTestProvider('unknown', 'unknown-type')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
// 测试边界条件
|
||||
it('should handle provider with minimal configuration', () => {
|
||||
const provider = createTestProvider('minimal', 'openai')
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
expect(OpenAIAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
|
||||
// 测试特殊 ID 优先级高于类型
|
||||
it('should prioritize special ID over type', () => {
|
||||
const provider = createTestProvider('aihubmix', 'anthropic') // 即使类型是 anthropic
|
||||
|
||||
const client = ApiClientFactory.create(provider)
|
||||
|
||||
// 应该创建 AihubmixAPIClient 而不是 AnthropicAPIClient
|
||||
expect(AihubmixAPIClient).toHaveBeenCalledWith(provider)
|
||||
expect(AnthropicAPIClient).not.toHaveBeenCalled()
|
||||
expect(client).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,38 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { normalizeAzureOpenAIEndpoint } from '../openai/azureOpenAIEndpoint'
|
||||
|
||||
describe('normalizeAzureOpenAIEndpoint', () => {
|
||||
it.each([
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai/',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai/v1',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/openai/v1/',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
},
|
||||
{
|
||||
apiHost: 'https://example.openai.azure.com/OPENAI/V1',
|
||||
expectedEndpoint: 'https://example.openai.azure.com'
|
||||
}
|
||||
])('strips trailing /openai from $apiHost', ({ apiHost, expectedEndpoint }) => {
|
||||
expect(normalizeAzureOpenAIEndpoint(apiHost)).toBe(expectedEndpoint)
|
||||
})
|
||||
})
|
||||
@@ -1,353 +0,0 @@
|
||||
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
|
||||
import type { EndpointType, Model, Provider } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
SYSTEM_MODELS: {
|
||||
defaultModel: [
|
||||
{ id: 'gpt-4', name: 'GPT-4' },
|
||||
{ id: 'gpt-4', name: 'GPT-4' },
|
||||
{ id: 'gpt-4', name: 'GPT-4' }
|
||||
],
|
||||
zhipu: [],
|
||||
silicon: [],
|
||||
openai: [],
|
||||
anthropic: [],
|
||||
gemini: []
|
||||
},
|
||||
isOpenAIModel: vi.fn().mockReturnValue(true),
|
||||
isOpenAILLMModel: vi.fn().mockReturnValue(true),
|
||||
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
|
||||
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
|
||||
isGeminiLLMModel: vi.fn().mockReturnValue(false),
|
||||
isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false),
|
||||
isVisionModel: vi.fn().mockReturnValue(false),
|
||||
isClaudeReasoningModel: vi.fn().mockReturnValue(false),
|
||||
isReasoningModel: vi.fn().mockReturnValue(false),
|
||||
isWebSearchModel: vi.fn().mockReturnValue(false),
|
||||
findTokenLimit: vi.fn().mockReturnValue(4096),
|
||||
isFunctionCallingModel: vi.fn().mockReturnValue(false),
|
||||
DEFAULT_MAX_TOKENS: 4096,
|
||||
qwenModel: {}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/FileManager', () => ({
|
||||
default: class {
|
||||
static async read() {
|
||||
return 'test content'
|
||||
}
|
||||
static async write() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/TokenService', () => ({
|
||||
estimateTextTokens: vi.fn().mockReturnValue(100)
|
||||
}))
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: vi.fn().mockReturnValue({
|
||||
debug: vi.fn(),
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
silly: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// 到底是谁想出来的在服务层调用 React Hook ?????????
|
||||
// Mock additional services and hooks that might be imported
|
||||
vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
|
||||
getVertexAIProjectId: vi.fn().mockReturnValue('test-project'),
|
||||
getVertexAIServiceAccount: vi.fn().mockReturnValue({
|
||||
privateKey: 'test-key',
|
||||
clientEmail: 'test@example.com'
|
||||
}),
|
||||
isVertexAIConfigured: vi.fn().mockReturnValue(true),
|
||||
isVertexProvider: vi.fn().mockReturnValue(true)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn().mockReturnValue({}),
|
||||
useSettings: vi.fn().mockReturnValue([{}, vi.fn()])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: {},
|
||||
settingsSlice: {
|
||||
name: 'settings',
|
||||
reducer: vi.fn(),
|
||||
actions: {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/abortController', () => ({
|
||||
addAbortController: vi.fn(),
|
||||
removeAbortController: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@anthropic-ai/sdk', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@anthropic-ai/vertex-sdk', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({})),
|
||||
AzureOpenAI: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@google/generative-ai', () => ({
|
||||
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@google-cloud/vertexai', () => ({
|
||||
VertexAI: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
|
||||
vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => {
|
||||
const MockAnthropicVertexClient = vi.fn()
|
||||
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
|
||||
return {
|
||||
AnthropicVertexClient: MockAnthropicVertexClient
|
||||
}
|
||||
})
|
||||
|
||||
// Helper to create test provider
|
||||
const createTestProvider = (id: string, type: string): Provider => ({
|
||||
id,
|
||||
type: type as Provider['type'],
|
||||
name: 'Test Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.test.com',
|
||||
models: []
|
||||
})
|
||||
|
||||
// Helper to create test model
|
||||
const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({
|
||||
id,
|
||||
name: 'Test Model',
|
||||
provider: provider || 'test',
|
||||
type: [],
|
||||
group: 'test',
|
||||
endpoint_type: endpointType as EndpointType
|
||||
})
|
||||
|
||||
describe('Client Compatibility Types', () => {
|
||||
let openaiProvider: Provider
|
||||
let anthropicProvider: Provider
|
||||
let geminiProvider: Provider
|
||||
let azureProvider: Provider
|
||||
let aihubmixProvider: Provider
|
||||
let newApiProvider: Provider
|
||||
let vertexProvider: Provider
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
openaiProvider = createTestProvider('openai', 'openai')
|
||||
anthropicProvider = createTestProvider('anthropic', 'anthropic')
|
||||
geminiProvider = createTestProvider('gemini', 'gemini')
|
||||
azureProvider = createTestProvider('azure-openai', 'azure-openai')
|
||||
aihubmixProvider = createTestProvider('aihubmix', 'openai')
|
||||
newApiProvider = createTestProvider('new-api', 'openai')
|
||||
vertexProvider = createTestProvider('vertex', 'vertexai')
|
||||
})
|
||||
|
||||
describe('Direct API Clients', () => {
|
||||
it('should return correct compatibility type for OpenAIAPIClient', () => {
|
||||
const client = new OpenAIAPIClient(openaiProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
|
||||
})
|
||||
|
||||
it('should return correct compatibility type for AnthropicAPIClient', () => {
|
||||
const client = new AnthropicAPIClient(anthropicProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['AnthropicAPIClient'])
|
||||
})
|
||||
|
||||
it('should return correct compatibility type for GeminiAPIClient', () => {
|
||||
const client = new GeminiAPIClient(geminiProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['GeminiAPIClient'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Decorator Pattern API Clients', () => {
|
||||
it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => {
|
||||
const client = new OpenAIResponseAPIClient(azureProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => {
|
||||
const client = new OpenAIResponseAPIClient(azureProvider)
|
||||
const testModel = createTestModel('gpt-4', 'azure-openai')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClient(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return OpenAIResponseAPIClient for non-chat-completion-only models
|
||||
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
|
||||
})
|
||||
|
||||
it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => {
|
||||
const client = new AihubmixAPIClient(aihubmixProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['AihubmixAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for AihubmixAPIClient with model', () => {
|
||||
const client = new AihubmixAPIClient(aihubmixProvider)
|
||||
const testModel = createTestModel('gpt-4', 'openai')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClientForModel(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix)
|
||||
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
|
||||
})
|
||||
|
||||
it('should return NewAPIClient for NewAPIClient without model', () => {
|
||||
const client = new NewAPIClient(newApiProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['NewAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for NewAPIClient with model', () => {
|
||||
const client = new NewAPIClient(newApiProvider)
|
||||
const testModel = createTestModel('gpt-4', 'openai', 'openai-response')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClientForModel(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return the actual underlying client type based on model
|
||||
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
|
||||
})
|
||||
|
||||
it('should return VertexAPIClient for VertexAPIClient without model', () => {
|
||||
const client = new VertexAPIClient(vertexProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['VertexAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for VertexAPIClient with model', () => {
|
||||
const client = new VertexAPIClient(vertexProvider)
|
||||
const testModel = createTestModel('claude-3-5-sonnet', 'vertexai')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClient(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient)
|
||||
expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Middleware Compatibility Logic', () => {
|
||||
it('should correctly identify OpenAI compatible clients', () => {
|
||||
const openaiClient = new OpenAIAPIClient(openaiProvider)
|
||||
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
|
||||
|
||||
const openaiTypes = openaiClient.getClientCompatibilityType()
|
||||
const responseTypes = openaiResponseClient.getClientCompatibilityType()
|
||||
|
||||
// Test the logic from completions method line 94
|
||||
const isOpenAICompatible = (types: string[]) =>
|
||||
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isOpenAICompatible(openaiTypes)).toBe(true)
|
||||
expect(isOpenAICompatible(responseTypes)).toBe(true)
|
||||
})
|
||||
|
||||
it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => {
|
||||
const anthropicClient = new AnthropicAPIClient(anthropicProvider)
|
||||
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
|
||||
const openaiClient = new OpenAIAPIClient(openaiProvider)
|
||||
|
||||
const anthropicTypes = anthropicClient.getClientCompatibilityType()
|
||||
const responseTypes = openaiResponseClient.getClientCompatibilityType()
|
||||
const openaiTypes = openaiClient.getClientCompatibilityType()
|
||||
|
||||
// Test the logic from completions method line 101
|
||||
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
|
||||
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle non-compatible clients correctly', () => {
|
||||
const geminiClient = new GeminiAPIClient(geminiProvider)
|
||||
const geminiTypes = geminiClient.getClientCompatibilityType()
|
||||
|
||||
// Test that Gemini is not OpenAI compatible
|
||||
const isOpenAICompatible = (types: string[]) =>
|
||||
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
// Test that Gemini is not Anthropic/OpenAIResponse compatible
|
||||
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
|
||||
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isOpenAICompatible(geminiTypes)).toBe(false)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Factory Integration', () => {
|
||||
it('should return correct compatibility types for factory-created clients', () => {
|
||||
const testCases = [
|
||||
{ provider: openaiProvider, expectedType: 'OpenAIAPIClient' },
|
||||
{ provider: anthropicProvider, expectedType: 'AnthropicAPIClient' },
|
||||
{ provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' },
|
||||
{ provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' },
|
||||
{ provider: newApiProvider, expectedType: 'NewAPIClient' },
|
||||
{ provider: vertexProvider, expectedType: 'VertexAPIClient' }
|
||||
]
|
||||
|
||||
testCases.forEach(({ provider, expectedType }) => {
|
||||
const client = ApiClientFactory.create(provider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toContain(expectedType)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,96 +0,0 @@
|
||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
|
||||
import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient'
|
||||
import type { BaseApiClient } from '../BaseApiClient'
|
||||
import { GeminiAPIClient } from '../gemini/GeminiAPIClient'
|
||||
import { MixedBaseAPIClient } from '../MixedBaseApiClient'
|
||||
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient'
|
||||
|
||||
/**
|
||||
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
|
||||
* 使用装饰器模式实现,在ApiClient层面进行模型路由
|
||||
*/
|
||||
export class AihubmixAPIClient extends MixedBaseAPIClient {
|
||||
// 使用联合类型而不是any,保持类型安全
|
||||
protected clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
||||
new Map()
|
||||
protected defaultClient: OpenAIAPIClient
|
||||
protected currentClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
|
||||
const providerExtraHeaders = {
|
||||
...provider,
|
||||
extra_headers: {
|
||||
...provider.extra_headers,
|
||||
'APP-Code': 'MLTG2087'
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化各个client - 现在有类型安全
|
||||
const claudeClient = new AnthropicAPIClient(providerExtraHeaders)
|
||||
const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' })
|
||||
const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders)
|
||||
const defaultClient = new OpenAIAPIClient(providerExtraHeaders)
|
||||
|
||||
this.clients.set('claude', claudeClient)
|
||||
this.clients.set('gemini', geminiClient)
|
||||
this.clients.set('openai', openaiClient)
|
||||
this.clients.set('default', defaultClient)
|
||||
|
||||
// 设置默认client
|
||||
this.defaultClient = defaultClient
|
||||
this.currentClient = this.defaultClient as BaseApiClient
|
||||
}
|
||||
|
||||
override getBaseURL(): string {
|
||||
if (!this.currentClient) {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
return this.currentClient.getBaseURL()
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型获取合适的client
|
||||
*/
|
||||
protected getClient(model: Model): BaseApiClient {
|
||||
const id = model.id.toLowerCase()
|
||||
|
||||
// claude开头
|
||||
if (id.startsWith('claude')) {
|
||||
const client = this.clients.get('claude')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('Claude client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// gemini开头 且不以-nothink、-search结尾
|
||||
if (
|
||||
(id.startsWith('gemini') || id.startsWith('imagen')) &&
|
||||
!id.endsWith('-nothink') &&
|
||||
!id.endsWith('-search') &&
|
||||
!id.includes('embedding')
|
||||
) {
|
||||
const client = this.clients.get('gemini')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('Gemini client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// OpenAI系列模型 不包含gpt-oss
|
||||
if (isOpenAILLMModel(model) && !model.id.includes('gpt-oss')) {
|
||||
const client = this.clients.get('openai')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('OpenAI client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
return this.defaultClient as BaseApiClient
|
||||
}
|
||||
}
|
||||
@@ -1,788 +0,0 @@
|
||||
import type Anthropic from '@anthropic-ai/sdk'
|
||||
import type {
|
||||
Base64ImageSource,
|
||||
ImageBlockParam,
|
||||
MessageParam,
|
||||
TextBlockParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUseBlock,
|
||||
WebSearchTool20250305
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import type {
|
||||
ContentBlock,
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsBase,
|
||||
RedactedThinkingBlockParam,
|
||||
ServerToolUseBlockParam,
|
||||
ThinkingBlockParam,
|
||||
ThinkingConfigParam,
|
||||
ToolUnion,
|
||||
ToolUseBlockParam,
|
||||
WebSearchResultBlock,
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import type AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import type {
|
||||
Assistant,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse
|
||||
} from '@renderer/types'
|
||||
import { EFFORT_RATIO, FILE_TYPE, WEB_SEARCH_SOURCE } from '@renderer/types'
|
||||
import type {
|
||||
ErrorChunk,
|
||||
LLMWebSearchCompleteChunk,
|
||||
LLMWebSearchInProgressChunk,
|
||||
MCPToolCreatedChunk,
|
||||
TextDeltaChunk,
|
||||
TextStartChunk,
|
||||
ThinkingDeltaChunk,
|
||||
ThinkingStartChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { type Message } from '@renderer/types/newMessage'
|
||||
import type {
|
||||
AnthropicSdkMessageParam,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkRawOutput
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import type { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import type { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
const logger = loggerService.withContext('AnthropicAPIClient')
|
||||
|
||||
export class AnthropicAPIClient extends BaseApiClient<
|
||||
Anthropic | AnthropicVertex,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawOutput,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkMessageParam,
|
||||
ToolUseBlock,
|
||||
ToolUnion
|
||||
> {
|
||||
oauthToken: string | undefined = undefined
|
||||
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<Anthropic | AnthropicVertex> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
if (this.provider.authType === 'oauth') {
|
||||
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
|
||||
}
|
||||
this.sdkInstance = getSdkClient(this.provider, this.oauthToken)
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: AnthropicSdkParams,
|
||||
options?: Anthropic.RequestOptions
|
||||
): Promise<AnthropicSdkRawOutput> {
|
||||
if (this.provider.authType === 'oauth') {
|
||||
payload.system = buildClaudeCodeSystemMessage(payload.system)
|
||||
}
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
return sdk.messages.create(payload, options)
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
// prevent auto appended /v1. It's included in baseUrl.
|
||||
const response = await sdk.models.list({ path: '/models' })
|
||||
return response.data
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
override async getEmbeddingDimensions(): Promise<number> {
|
||||
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return super.getTemperature(assistant, model)
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return super.getTopP(assistant, model)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
private static isValidBase64ImageMediaType(mime: string): mime is Base64ImageSource['media_type'] {
|
||||
return ['image/jpeg', 'image/png', 'image/gif', 'image/webp'].includes(mime)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||
const { textContent, imageContents } = await this.getMessageContent(message)
|
||||
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
type: 'text',
|
||||
text: textContent
|
||||
}
|
||||
]
|
||||
|
||||
if (imageContents.length > 0) {
|
||||
for (const imageContent of imageContents) {
|
||||
const base64Data = await window.api.file.base64Image(imageContent.fileId + imageContent.fileExt)
|
||||
base64Data.mime = base64Data.mime.replace('jpg', 'jpeg')
|
||||
if (AnthropicAPIClient.isValidBase64ImageMediaType(base64Data.mime)) {
|
||||
parts.push({
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime,
|
||||
type: 'base64'
|
||||
}
|
||||
})
|
||||
} else {
|
||||
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get and process image blocks
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
// Handle uploaded file
|
||||
const file = imageBlock.file
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||
type: 'base64'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// Get and process file blocks
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const { file } = fileBlock
|
||||
if ([FILE_TYPE.TEXT, FILE_TYPE.DOCUMENT].some((type) => file.type === type)) {
|
||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||
const base64Data = await FileManager.readBase64File(file)
|
||||
parts.push({
|
||||
type: 'document',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'application/pdf',
|
||||
data: base64Data
|
||||
}
|
||||
})
|
||||
} else {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext, true)).trim()
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
|
||||
return mcpToolsToAnthropicTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): AnthropicSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
|
||||
} else if ('toolCallId' in mcpToolResponse) {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: mcpToolResponse.toolCallId!,
|
||||
content: resp.content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return {
|
||||
type: 'text',
|
||||
text: item.text || ''
|
||||
} satisfies TextBlockParam
|
||||
}
|
||||
if (item.type === 'image') {
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
data: item.data || '',
|
||||
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
|
||||
type: 'base64'
|
||||
}
|
||||
} satisfies ImageBlockParam
|
||||
}
|
||||
return
|
||||
})
|
||||
.filter((n) => typeof n !== 'undefined'),
|
||||
is_error: resp.isError
|
||||
} satisfies ToolResultBlockParam
|
||||
]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Implementing abstract methods from BaseApiClient
|
||||
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
|
||||
// This might need adjustment based on how tool calls are specifically handled in the new structure
|
||||
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
|
||||
return mcpTool
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
|
||||
return {
|
||||
id: toolCall.id,
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: toolCall.input as Record<string, unknown>,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
override buildSdkMessages(
|
||||
currentReqMessages: AnthropicSdkMessageParam[],
|
||||
output: Anthropic.Message,
|
||||
toolResults: AnthropicSdkMessageParam[]
|
||||
): AnthropicSdkMessageParam[] {
|
||||
const assistantMessage: AnthropicSdkMessageParam = {
|
||||
role: output.role,
|
||||
content: convertContentBlocksToParams(output.content)
|
||||
}
|
||||
|
||||
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
|
||||
if (toolResults && toolResults.length > 0) {
|
||||
newMessages.push(...toolResults)
|
||||
}
|
||||
return newMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
|
||||
if (typeof message.content === 'string') {
|
||||
return estimateTextTokens(message.content)
|
||||
}
|
||||
return message.content
|
||||
.map((content) => {
|
||||
switch (content.type) {
|
||||
case 'text':
|
||||
return estimateTextTokens(content.text)
|
||||
case 'image':
|
||||
if (content.source.type === 'base64') {
|
||||
return estimateTextTokens(content.source.data)
|
||||
} else {
|
||||
return estimateTextTokens(content.source.url)
|
||||
}
|
||||
case 'tool_use':
|
||||
return estimateTextTokens(JSON.stringify(content.input))
|
||||
case 'tool_result':
|
||||
return estimateTextTokens(JSON.stringify(content.content))
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
.reduce((acc, curr) => acc + curr, 0)
|
||||
}
|
||||
|
||||
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
|
||||
const messageParam: AnthropicSdkMessageParam = {
|
||||
role: message.role,
|
||||
content: convertContentBlocksToParams(message.content)
|
||||
}
|
||||
return messageParam
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
|
||||
return sdkPayload.messages || []
|
||||
}
|
||||
|
||||
/**
|
||||
* Anthropic专用的原始流监听器
|
||||
* 处理MessageStream对象的特定事件
|
||||
*/
|
||||
attachRawStreamListener(
|
||||
rawOutput: AnthropicSdkRawOutput,
|
||||
listener: RawStreamListener<AnthropicSdkRawChunk>
|
||||
): AnthropicSdkRawOutput {
|
||||
logger.debug(`Attaching stream listener to raw output`)
|
||||
// 专用的Anthropic事件处理
|
||||
const anthropicListener = listener as AnthropicStreamListener
|
||||
// 检查是否为MessageStream
|
||||
if (rawOutput instanceof MessageStream) {
|
||||
logger.debug(`Detected Anthropic MessageStream, attaching specialized listener`)
|
||||
|
||||
if (listener.onStart) {
|
||||
listener.onStart()
|
||||
}
|
||||
|
||||
if (listener.onChunk) {
|
||||
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
|
||||
listener.onChunk!(event)
|
||||
})
|
||||
}
|
||||
|
||||
if (anthropicListener.onContentBlock) {
|
||||
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
|
||||
}
|
||||
|
||||
if (anthropicListener.onMessage) {
|
||||
rawOutput.on('finalMessage', anthropicListener.onMessage)
|
||||
}
|
||||
|
||||
if (listener.onEnd) {
|
||||
rawOutput.on('end', () => {
|
||||
listener.onEnd!()
|
||||
})
|
||||
}
|
||||
|
||||
if (listener.onError) {
|
||||
rawOutput.on('error', (error: Error) => {
|
||||
listener.onError!(error)
|
||||
})
|
||||
}
|
||||
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
if (anthropicListener.onMessage) {
|
||||
anthropicListener.onMessage(rawOutput)
|
||||
}
|
||||
|
||||
// 对于非MessageStream响应
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
|
||||
if (!isWebSearchModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5
|
||||
} as WebSearchTool20250305
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: AnthropicSdkParams
|
||||
messages: AnthropicSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemPrompt = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
? { type: 'text', text: systemPrompt }
|
||||
: undefined
|
||||
|
||||
// 3. 处理用户消息
|
||||
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
sdkMessages.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
sdkMessages.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
const webSearchTool = await this.getWebSearchParams(model)
|
||||
if (webSearchTool) {
|
||||
tools.push(webSearchTool)
|
||||
}
|
||||
}
|
||||
|
||||
const commonParams: MessageCreateParamsBase = {
|
||||
model: model.id,
|
||||
messages:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: sdkMessages,
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
system: systemMessage ? [systemMessage] : undefined,
|
||||
thinking: this.getBudgetToken(assistant, model),
|
||||
tools: tools.length > 0 ? tools : undefined,
|
||||
stream: streamOutput,
|
||||
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
|
||||
// 注意:用户自定义参数总是应该覆盖其他参数
|
||||
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
|
||||
return () => {
|
||||
let accumulatedJson = ''
|
||||
const toolCalls: Record<number, ToolUseBlock> = {}
|
||||
return {
|
||||
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
if (typeof rawChunk === 'string') {
|
||||
try {
|
||||
rawChunk = JSON.parse(rawChunk)
|
||||
} catch (error) {
|
||||
logger.error('invalid chunk', { rawChunk, error })
|
||||
throw new Error(t('error.chat.chunk.non_json'))
|
||||
}
|
||||
}
|
||||
switch (rawChunk.type) {
|
||||
case 'message': {
|
||||
let i = 0
|
||||
let hasTextContent = false
|
||||
let hasThinkingContent = false
|
||||
|
||||
for (const content of rawChunk.content) {
|
||||
switch (content.type) {
|
||||
case 'text': {
|
||||
if (!hasTextContent) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_START
|
||||
} as TextStartChunk)
|
||||
hasTextContent = true
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: content.text
|
||||
} as TextDeltaChunk)
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
toolCalls[i] = content
|
||||
i++
|
||||
break
|
||||
}
|
||||
case 'thinking': {
|
||||
if (!hasThinkingContent) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_START
|
||||
} as ThinkingStartChunk)
|
||||
hasThinkingContent = true
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: content.thinking
|
||||
} as ThinkingDeltaChunk)
|
||||
break
|
||||
}
|
||||
case 'web_search_tool_result': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: content.content,
|
||||
source: WEB_SEARCH_SOURCE.ANTHROPIC
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if (i > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: Object.values(toolCalls)
|
||||
} as MCPToolCreatedChunk)
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: rawChunk.usage.input_tokens || 0,
|
||||
completion_tokens: rawChunk.usage.output_tokens || 0,
|
||||
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'content_block_start': {
|
||||
const contentBlock = rawChunk.content_block
|
||||
switch (contentBlock.type) {
|
||||
case 'server_tool_use': {
|
||||
if (contentBlock.name === 'web_search') {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
|
||||
} as LLMWebSearchInProgressChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'web_search_tool_result': {
|
||||
if (
|
||||
contentBlock.content &&
|
||||
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
|
||||
) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
code: (contentBlock.content as WebSearchToolResultError).error_code,
|
||||
message: (contentBlock.content as WebSearchToolResultError).error_code
|
||||
}
|
||||
} as ErrorChunk)
|
||||
} else {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: contentBlock.content as Array<WebSearchResultBlock>,
|
||||
source: WEB_SEARCH_SOURCE.ANTHROPIC
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
toolCalls[rawChunk.index] = contentBlock
|
||||
break
|
||||
}
|
||||
case 'text': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_START
|
||||
} as TextStartChunk)
|
||||
break
|
||||
}
|
||||
case 'thinking':
|
||||
case 'redacted_thinking': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_START
|
||||
} as ThinkingStartChunk)
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_delta': {
|
||||
const messageDelta = rawChunk.delta
|
||||
switch (messageDelta.type) {
|
||||
case 'text_delta': {
|
||||
if (messageDelta.text) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: messageDelta.text
|
||||
} as TextDeltaChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'thinking_delta': {
|
||||
if (messageDelta.thinking) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: messageDelta.thinking
|
||||
} as ThinkingDeltaChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'input_json_delta': {
|
||||
if (messageDelta.partial_json) {
|
||||
accumulatedJson += messageDelta.partial_json
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_stop': {
|
||||
const toolCall = toolCalls[rawChunk.index]
|
||||
if (toolCall) {
|
||||
try {
|
||||
toolCall.input = accumulatedJson ? JSON.parse(accumulatedJson) : {}
|
||||
logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [toolCall]
|
||||
} as MCPToolCreatedChunk)
|
||||
} catch (error) {
|
||||
logger.error('Error parsing tool call input:', error as Error)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'message_delta': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: rawChunk.usage.input_tokens || 0,
|
||||
completion_tokens: rawChunk.usage.output_tokens || 0,
|
||||
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
|
||||
* 去除服务器生成的额外字段,只保留发送给API所需的字段
|
||||
*/
|
||||
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
|
||||
return contentBlocks.map((block): ContentBlockParam => {
|
||||
switch (block.type) {
|
||||
case 'text':
|
||||
// TextBlock -> TextBlockParam,去除 citations 等服务器字段
|
||||
return {
|
||||
type: 'text',
|
||||
text: block.text
|
||||
} satisfies TextBlockParam
|
||||
case 'tool_use':
|
||||
// ToolUseBlock -> ToolUseBlockParam
|
||||
return {
|
||||
type: 'tool_use',
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input
|
||||
} satisfies ToolUseBlockParam
|
||||
case 'thinking':
|
||||
// ThinkingBlock -> ThinkingBlockParam
|
||||
return {
|
||||
type: 'thinking',
|
||||
thinking: block.thinking,
|
||||
signature: block.signature
|
||||
} satisfies ThinkingBlockParam
|
||||
case 'redacted_thinking':
|
||||
// RedactedThinkingBlock -> RedactedThinkingBlockParam
|
||||
return {
|
||||
type: 'redacted_thinking',
|
||||
data: block.data
|
||||
} satisfies RedactedThinkingBlockParam
|
||||
case 'server_tool_use':
|
||||
// ServerToolUseBlock -> ServerToolUseBlockParam
|
||||
return {
|
||||
type: 'server_tool_use',
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input
|
||||
} satisfies ServerToolUseBlockParam
|
||||
case 'web_search_tool_result':
|
||||
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
|
||||
return {
|
||||
type: 'web_search_tool_result',
|
||||
tool_use_id: block.tool_use_id,
|
||||
content: block.content
|
||||
} satisfies WebSearchToolResultBlockParam
|
||||
default:
|
||||
return block as ContentBlockParam
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
import type Anthropic from '@anthropic-ai/sdk'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { AnthropicAPIClient } from './AnthropicAPIClient'
|
||||
|
||||
const logger = loggerService.withContext('AnthropicVertexClient')
|
||||
|
||||
export class AnthropicVertexClient extends AnthropicAPIClient {
|
||||
sdkInstance: AnthropicVertex | undefined = undefined
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
private formatApiHost(host: string): string {
|
||||
const forceUseOriginalHost = () => {
|
||||
return host.endsWith('/')
|
||||
}
|
||||
|
||||
if (!host) {
|
||||
return host
|
||||
}
|
||||
|
||||
return forceUseOriginalHost() ? host : `${host}/v1/`
|
||||
}
|
||||
|
||||
override getBaseURL() {
|
||||
return this.formatApiHost(this.provider.apiHost)
|
||||
}
|
||||
|
||||
override async getSdkInstance(): Promise<AnthropicVertex> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const location = getVertexAILocation()
|
||||
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||
throw new Error('Vertex AI settings are not configured')
|
||||
}
|
||||
|
||||
const authHeaders = await this.getServiceAccountAuthHeaders()
|
||||
|
||||
this.sdkInstance = new AnthropicVertex({
|
||||
projectId: projectId,
|
||||
region: location,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: authHeaders,
|
||||
baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
|
||||
})
|
||||
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
throw new Error('Vertex AI does not support listModels method.')
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||
*/
|
||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
|
||||
// 检查是否配置了 service account
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 检查是否已有有效的认证头(提前 5 分钟过期)
|
||||
const now = Date.now()
|
||||
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
|
||||
return this.authHeaders
|
||||
}
|
||||
|
||||
try {
|
||||
// 从主进程获取认证头
|
||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||
projectId,
|
||||
serviceAccount: {
|
||||
privateKey: serviceAccount.privateKey,
|
||||
clientEmail: serviceAccount.clientEmail
|
||||
}
|
||||
})
|
||||
|
||||
// 设置过期时间(通常认证头有效期为 1 小时)
|
||||
this.authHeadersExpiry = now + 60 * 60 * 1000
|
||||
|
||||
return this.authHeaders
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get auth headers:', error)
|
||||
throw new Error(`Service Account authentication failed: ${error.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,51 +0,0 @@
|
||||
import type OpenAI from '@cherrystudio/openai'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import type { OpenAISdkParams, OpenAISdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
|
||||
|
||||
export class CherryAiAPIClient extends OpenAIAPIClient {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: OpenAISdkParams,
|
||||
options?: OpenAI.RequestOptions
|
||||
): Promise<OpenAISdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
options = options || {}
|
||||
options.headers = options.headers || {}
|
||||
|
||||
const signature = await window.api.cherryai.generateSignature({
|
||||
method: 'POST',
|
||||
path: '/chat/completions',
|
||||
query: '',
|
||||
body: payload
|
||||
})
|
||||
|
||||
options.headers = {
|
||||
...options.headers,
|
||||
...signature
|
||||
}
|
||||
|
||||
// @ts-ignore - SDK参数可能有额外的字段
|
||||
return await sdk.chat.completions.create(payload, options)
|
||||
}
|
||||
|
||||
override getClientCompatibilityType(): string[] {
|
||||
return ['CherryAiAPIClient']
|
||||
}
|
||||
|
||||
public async listModels(): Promise<OpenAI.Models.Model[]> {
|
||||
const models = ['Qwen/Qwen3-8B', 'Qwen/Qwen3-Next-80B-A3B-Instruct']
|
||||
|
||||
const created = Date.now()
|
||||
return models.map((id) => ({
|
||||
id,
|
||||
owned_by: 'cherryai',
|
||||
object: 'model' as const,
|
||||
created
|
||||
}))
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user