mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-07-03 12:27:41 +08:00
refactor(ai-service): consolidate AI runtime to main process (#14911)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: fullex <106392080+0xfullex@users.noreply.github.com> Signed-off-by: suyao <sy20010504@gmail.com>
This commit is contained in:
5
.changeset/aicore-remove-prompt-tool-use.md
Normal file
5
.changeset/aicore-remove-prompt-tool-use.md
Normal file
@@ -0,0 +1,5 @@
|
||||
---
|
||||
'@cherrystudio/ai-core': patch
|
||||
---
|
||||
|
||||
Remove the prompt-based tool-use plugin end-to-end. Tool use now relies solely on native provider tool calling, so `promptToolUsePlugin` (with its `StreamEventManager`, `ToolExecutor`, and tag-extraction helpers) and the public exports `ToolUseRequestContext` and `AiRequestMetadata.isPromptToolUse` are gone. Also switch the provider cache from `lru-cache` to `quick-lru`.
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -85,3 +85,4 @@ test-results
|
||||
YOUR_MEMORY_FILE_PATH
|
||||
|
||||
.sessions/
|
||||
.devtools
|
||||
|
||||
@@ -24,7 +24,10 @@
|
||||
|
||||
| Document | Description |
|
||||
|----------|-------------|
|
||||
| [AI Core Architecture](./references/ai-core-architecture.md) | Complete data flow and architecture from user input to LLM response |
|
||||
| [AI Reference](./references/ai/README.md) | Main-process AI pipeline: stream manager, agent loop, providers, tools |
|
||||
| [Core Architecture](./references/ai/core-architecture.md) | End-to-end call flow from user input to LLM response |
|
||||
| [Stream Manager](./references/ai/stream-manager.md) | Active-stream registry, broker, reconnect, persistence |
|
||||
| [Adapter Family](./references/ai/adapter-family.md) | How endpoint → `@ai-sdk/*` package routing is decided |
|
||||
|
||||
### Data System
|
||||
|
||||
|
||||
BIN
docs/image-1.png
Normal file
BIN
docs/image-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 98 KiB |
BIN
docs/image.png
Normal file
BIN
docs/image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 125 KiB |
File diff suppressed because it is too large
Load Diff
152
docs/references/ai/README.md
Normal file
152
docs/references/ai/README.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# AI Reference
|
||||
|
||||
This is the entry point for the AI pipeline in Cherry Studio v2 — the
|
||||
main-process service that owns every LLM call (chat streams, agent loops,
|
||||
translate, summarisation) and the renderer-side transport that connects to it.
|
||||
|
||||
## Quick navigation
|
||||
|
||||
### Top-level architecture
|
||||
|
||||
| Document | What it covers |
|
||||
|---|---|
|
||||
| [Core Architecture](./core-architecture.md) | End-to-end call flow: `Ai_Stream_Open` IPC → context provider → AiStreamManager → Agent loop → `@ai-sdk/*` → broadcast / persist |
|
||||
| [Stream Manager](./stream-manager.md) | Active-stream registry, listeners, reconnect, abort, abort-and-restart steering, persistence backends |
|
||||
| [Agent Session Runtime](./agent-session-runtime.md) | Agent-session host/driver split, `pendingTurns` follow-up queue, resume token persistence, Claude Code driver fallback |
|
||||
| [Adapter Family](./adapter-family.md) | How `provider.endpointConfigs[ep].adapterFamily` picks the right `@ai-sdk/*` package per request |
|
||||
|
||||
### Subsystems
|
||||
|
||||
| Document | What it covers |
|
||||
|---|---|
|
||||
| [Agent Loop](./agent-loop.md) | Main-process `Agent.stream()`: single-pass stream, hook composition, observer pattern, error/abort semantics |
|
||||
| [Params Pipeline](./params-pipeline.md) | `buildAgentParams` + `RequestFeature` model: how capabilities, plugins, tools, and provider-specific quirks are composed |
|
||||
| [Tool Registry](./tool-registry.md) | Built-in tools (knowledge / web search), MCP tools, meta-tools (`tool_search` / `tool_inspect` / `tool_invoke` / `tool_exec`), deferred exposition |
|
||||
| [Provider Resolution](./provider-resolution.md) | `Provider.endpointConfigs` schema, endpoint resolution chain, variant suffixes, custom provider extensions (aihubmix, newapi) |
|
||||
| [Observability (trace / telemetry)](./observability.md) | `AiSdkSpanAdapter`, root span propagation, OTel attribute shape, local span projection, sinks |
|
||||
|
||||
### Renderer-side glue
|
||||
|
||||
| Document | What it covers |
|
||||
|---|---|
|
||||
| [IPC Transport](./ipc-transport.md) | `useChat` + `IpcChatTransport`: `sendMessages` / `reconnectToStream`, dispatch coordinator, topic-status mirror |
|
||||
| [Execution Overlay](./execution-overlay.md) | `TopicStreamSubscription` + `useExecutionOverlay`: ref-counted attach, per-execution demux, one-shot `readUIMessageStream` per turn (the renderer half of the same merge function Main uses) |
|
||||
| [Tool Approval](./tool-approval.md) | Approval registry, Main-as-writer model, persistent decisions, `useToolApproval` hook |
|
||||
|
||||
## Where the code lives
|
||||
|
||||
> **Scope of the focused docs.** The reference documents in this folder map
|
||||
> the **chat / stream pipeline** (dispatch → stream manager → runtime →
|
||||
> tools → persistence → renderer transport). The `agents/`, `channels/`,
|
||||
> `skills/`, and `mcp/` subsystems are mapped in the tree below but do not
|
||||
> yet have dedicated deep-dive docs.
|
||||
|
||||
```
|
||||
src/main/ai/
|
||||
├── AiService.ts ← lifecycle owner, IPC handlers (generate / translate / approval)
|
||||
├── runtime/ ← AI execution backends + runtime registry
|
||||
│ ├── aiSdk/ ← Agent class, loop, observers, params/features, prompts/
|
||||
│ └── claudeCode/ ← Claude Code driver, warm query, SDK adapter
|
||||
├── agentSession/ ← agent-session topic host
|
||||
│ └── AgentSessionRuntimeService.ts
|
||||
├── agents/ ← AgentJobsService, AgentTaskJobHandler, runAgentTask, builtin/, cherryclaw/
|
||||
├── channels/ ← ChannelManager + IM adapters (discord/feishu/qq/slack/telegram/wechat) + security/
|
||||
├── streamManager/ ← AiStreamManager + listeners + persistence backends
|
||||
│ ├── AiStreamManager.ts ← registers the stream IPC (Open/Attach/Detach/Abort)
|
||||
│ ├── context/ ← ChatContextProvider implementations + dispatch
|
||||
│ ├── lifecycle/ ← chat / prompt-only stream lifecycles
|
||||
│ ├── listeners/ ← WebContents / Persistence / SSE / channel-adapter
|
||||
│ ├── persistence/ ← MessageService / TemporaryChat / Translation backends
|
||||
│ └── pipeStreamLoop.ts ← shared chunk-pipe primitive
|
||||
├── provider/ ← provider config, endpoint resolution, custom providers
|
||||
│ ├── custom/ ← aihubmix, newapi
|
||||
│ ├── config.ts ← providerToAiSdkConfig (builder table)
|
||||
│ ├── endpoint.ts ← resolveEffectiveEndpoint + adapterFamily routing
|
||||
│ ├── extensions/ ← ProviderExtension registrations
|
||||
│ └── listModels.ts ← per-provider model listing
|
||||
├── mcp/ ← McpRuntimeService / McpCatalogService, oauth/, built-in servers
|
||||
│ └── servers/ ← in-memory MCP server implementations (browser, filesystem)
|
||||
├── skills/ ← SkillService, SkillInstaller
|
||||
├── tools/ ← unified tool registry
|
||||
│ └── adapters/
|
||||
│ ├── aiSdk/ ← registry.ts, repair.ts; builtin/ (web__search/web__fetch/kb__*),
|
||||
│ │ mcp/ (server → ToolEntry sync), meta/ (tool_search/inspect/invoke;
|
||||
│ │ tool_exec defined but not injected), exposition/ (shouldDefer + applyDefer)
|
||||
│ └── claudeCode/ ← agentTools.ts (registry → Claude Code runtime)
|
||||
├── observability/ ← AI trace adapters (aiSdk / claudeCode), local projection, sinks
|
||||
├── messages/ ← UI part → AI SDK part conversion
|
||||
├── types/ ← AppProviderId, merged extension types, request types
|
||||
└── utils/ ← reasoning / model parameters / options / websearch helpers
|
||||
```
|
||||
|
||||
## How a chat turn flows
|
||||
|
||||
1. Renderer `useChat({ transport: IpcChatTransport })` calls `sendMessages` →
|
||||
IPC `Ai_Stream_Open` (`{ topicId, trigger, userMessageParts, parentAnchorId?, mentionedModelIds? }`).
|
||||
2. `AiStreamManager.onInit` registered the `Ai_Stream_Open` handler; it
|
||||
wraps the sender in a `WebContentsListener` and calls
|
||||
`dispatchStreamRequest(manager, subscriber, req)`. (The stream IPC —
|
||||
`Open`/`Attach`/`Detach`/`Abort` — lives on `AiStreamManager`, not
|
||||
`AiService`.)
|
||||
3. `dispatchStreamRequest` picks the first `ChatContextProvider` whose
|
||||
`canHandle(topicId)` matches (persistent chat / temporary / agent
|
||||
session) and calls `prepareDispatch` — that resolves models, persists
|
||||
the user message, builds listeners, and returns a `PreparedDispatch`.
|
||||
4. `AiStreamManager.send(input)` **starts** a turn (no active stream): creates
|
||||
an `ActiveStream`, launches one `StreamExecution` per model. (A chat
|
||||
resubmit on a live topic is restarted upstream — `dispatch` calls
|
||||
`abortAndAwait` first; only an agent-session follow-up takes the
|
||||
**inject** path, which just upserts listeners.)
|
||||
5. Each execution's `runExecutionLoop` calls `AiService.streamText(request,
|
||||
signal)`, which builds params (`buildAgentParams`) and constructs an `Agent`
|
||||
composing hooks from `RequestFeature[]` (anthropic cache, gateway usage
|
||||
normalisation, reasoning extraction, …), then calls `agent.stream(messages,
|
||||
signal)` to open the AI SDK stream and yield `UIMessageChunk`s.
|
||||
Agent-session runtime requests are the exception: `AiService.streamText`
|
||||
routes them to `AgentSessionRuntimeService.openTurnStream()` so the
|
||||
registered driver can own the concrete agent runtime.
|
||||
6. `pipeStreamLoop` tees the chunk stream: one branch broadcasts to listeners
|
||||
(WebContents / SSE / channel-adapter / persistence), one branch runs
|
||||
`readUIMessageStream` to accumulate a `CherryUIMessage` snapshot.
|
||||
7. On terminal (done / error / aborted / paused-for-approval), listeners get
|
||||
a typed terminal callback. `PersistenceListener` writes the final
|
||||
message via the appropriate `PersistenceBackend`.
|
||||
8. Renderer reads the persisted row through `useQuery('/topics/:id/messages')`
|
||||
and disposes its overlay.
|
||||
|
||||
## Key invariants
|
||||
|
||||
- **Topic-level addressing.** Every IPC and broadcast is keyed by `topicId`.
|
||||
A topic has at most one active stream; subscribers are equal — there's no
|
||||
"owner" window.
|
||||
- **Main owns persistence.** Renderer closing or crashing does not abort the
|
||||
stream and does not lose data — `PersistenceListener` writes on terminal
|
||||
regardless of who is listening.
|
||||
- **Tool approval is Main-authoritative.** The renderer never writes
|
||||
`approved`/`denied` parts. It posts the decision over IPC and re-reads the
|
||||
authoritative row. See [Tool Approval](./tool-approval.md).
|
||||
- **Adapter family per endpoint, not per provider.** Multi-endpoint relays
|
||||
(MiniMax, Silicon, AiHubMix, …) carry one `adapterFamily` per endpoint.
|
||||
Picking the SDK package never reads `apiHost` or provider id heuristics
|
||||
at request time. See [Adapter Family](./adapter-family.md).
|
||||
|
||||
## Related references
|
||||
|
||||
- [Service Lifecycle](../lifecycle/README.md) — `AiService` extends `BaseService`
|
||||
- [Data Layer](../data/README.md) — `MessageService`, `ModelService`,
|
||||
`ProviderService` (called from main-side AI code)
|
||||
- [Messaging](../messaging/message-system.md) — `CherryMessagePart`,
|
||||
`CherryUIMessage`, parts model
|
||||
- [Window Manager](../window-manager/README.md) — `WebContentsListener`
|
||||
attaches to whatever windows are open
|
||||
|
||||
## v2 refactor
|
||||
|
||||
The AI domain is the largest single area of the v2 refactor: the v1
|
||||
renderer aiCore tree (formerly `src/renderer/src/aiCore/`, pre-v2 layout)
|
||||
is fully deleted, with logic ported into `src/main/ai/`.
|
||||
|
||||
These reference docs are **self-contained** — they do not depend on the
|
||||
throwaway `v2-refactor-temp/` tree. (The reviewer-facing change-cluster
|
||||
narratives that live there are review logistics for the in-flight PR, and
|
||||
are removed when the v2 AI refactor merges.)
|
||||
115
docs/references/ai/adapter-family.md
Normal file
115
docs/references/ai/adapter-family.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# Adapter Family
|
||||
|
||||
`adapterFamily` is the optional field on each `EndpointConfig` that picks
|
||||
the `@ai-sdk/*` package implementing that endpoint's protocol. The runtime
|
||||
resolver reads it; the catalog seeder and the v1→v2 migrator write it. The
|
||||
schema declares it `optional`, and the resolver has a total fallback
|
||||
(`openai-compatible`) for endpoints that omit it — so no write path is
|
||||
obligated to set it.
|
||||
|
||||
## Identity stack
|
||||
|
||||
| Layer | Example | Role |
|
||||
|---|---|---|
|
||||
| `provider.id` | `minimax`, `silicon`, `my-relay` | User-facing identity, UI label, routing key |
|
||||
| `endpointType` | `openai-chat-completions`, `anthropic-messages` | URL path template + protocol family |
|
||||
| `adapterFamily` | `openai-compatible`, `anthropic`, `azure-responses` | Which `@ai-sdk/*` package implements this protocol |
|
||||
|
||||
Multi-endpoint relays (MiniMax, Silicon, AiHubMix) carry one
|
||||
`adapterFamily` per endpoint under the same `provider.id` — different
|
||||
endpoints on the same provider can route to different SDK packages.
|
||||
|
||||
## Runtime resolver
|
||||
|
||||
`src/main/ai/provider/endpoint.ts`:
|
||||
|
||||
```ts
|
||||
export function resolveAiSdkProviderId(provider, endpointType) {
|
||||
const adapterFamily = endpointType
|
||||
? provider.endpointConfigs?.[endpointType]?.adapterFamily
|
||||
: undefined
|
||||
if (adapterFamily && adapterFamily in appProviderIds) {
|
||||
return resolveProviderVariant(appProviderIds[adapterFamily], endpointType)
|
||||
}
|
||||
return appProviderIds['openai-compatible']
|
||||
}
|
||||
```
|
||||
|
||||
One signal, no heuristics. Tested with 54 cases in
|
||||
`provider/__tests__/endpoint.test.ts`.
|
||||
|
||||
## Write paths
|
||||
|
||||
`adapterFamily` is a derived value computed at row-write time, never at
|
||||
request time. One shared inference function lives at
|
||||
`packages/provider-registry/src/registry-utils.ts`:
|
||||
|
||||
```ts
|
||||
export function inferAdapterFamily(endpointType, catalogConfig?): string {
|
||||
if (catalogConfig?.adapterFamily) return catalogConfig.adapterFamily
|
||||
return ENDPOINT_TYPE_TO_DEFAULT_ADAPTER_FAMILY[endpointType] ?? 'openai-compatible'
|
||||
}
|
||||
```
|
||||
|
||||
### Endpoint-type defaults
|
||||
|
||||
| endpoint type | default adapter |
|
||||
|---|---|
|
||||
| `anthropic-messages` | `anthropic` |
|
||||
| `google-generate-content` | `google` |
|
||||
| `ollama-chat` / `ollama-generate` | `ollama` |
|
||||
| `jina-rerank` | `jina-rerank` |
|
||||
| `openai-responses` | `openai` |
|
||||
| everything else | `openai-compatible` (terminal fallback) |
|
||||
|
||||
### Write paths
|
||||
|
||||
Only two paths write `adapterFamily`; both run in the **main** process at
|
||||
row-write time:
|
||||
|
||||
1. **Catalog (new installs)** — `packages/provider-registry/data/providers.json`
|
||||
declares `adapterFamily` per endpoint per provider. The seeder copies
|
||||
it through via `buildRuntimeEndpointConfigs`.
|
||||
2. **v1 → v2 migration (existing users)** —
|
||||
`src/main/data/migration/v2/migrators/mappings/ProviderModelMappings.ts`
|
||||
looks up the catalog by legacy id and, on a miss, calls
|
||||
`inferAdapterFamily(endpointType)` for the endpoint-type default
|
||||
(`ProviderModelMigrator.ts` carries a preset's `adapterFamily` forward
|
||||
on merge). The `ANTHROPIC_MESSAGES` endpoint skips the legacy-type hint
|
||||
because v1 custom anthropic relays carried `legacy.type='openai'` even
|
||||
when the endpoint was anthropic-format.
|
||||
|
||||
The renderer's custom-provider form does **not** set `adapterFamily`:
|
||||
`ProviderEditorDrawer.tsx` writes only `baseUrl` into the endpoint config,
|
||||
leaving the field absent so the resolver's `openai-compatible` fallback
|
||||
applies. `inferAdapterFamily` has **no renderer callers** — it is invoked
|
||||
only by the migrator above.
|
||||
|
||||
## Schema
|
||||
|
||||
`src/shared/data/types/provider.ts::EndpointConfigSchema`:
|
||||
|
||||
```ts
|
||||
EndpointConfigSchema = z.object({
|
||||
baseUrl: z.string().optional(),
|
||||
adapterFamily: z.string().optional(), // optional — resolver falls back to openai-compatible
|
||||
// ... other endpoint-config fields
|
||||
})
|
||||
```
|
||||
|
||||
`packages/provider-registry/src/schemas/provider.ts::RegistryEndpointConfigSchema`
|
||||
mirrors this for catalog entries.
|
||||
|
||||
## Tests
|
||||
|
||||
| Target | File | Cases |
|
||||
|---|---|---|
|
||||
| `inferAdapterFamily` | `packages/provider-registry/src/__tests__/registry-utils.test.ts` | 5 |
|
||||
| Migrator backfill | `src/main/data/migration/v2/migrators/mappings/__tests__/ProviderModelMappings.test.ts` | 4 |
|
||||
| Runtime resolver | `src/main/ai/provider/__tests__/endpoint.test.ts` | 54 |
|
||||
| `buildRuntimeEndpointConfigs` | `packages/provider-registry/src/__tests__/registry-utils.test.ts` | 10 |
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Runtime usage: [Provider Resolution](./provider-resolution.md)
|
||||
- Catalog: `packages/provider-registry/data/providers.json`
|
||||
116
docs/references/ai/agent-loop.md
Normal file
116
docs/references/ai/agent-loop.md
Normal file
@@ -0,0 +1,116 @@
|
||||
# Agent Loop
|
||||
|
||||
## What it is
|
||||
|
||||
`Agent` (`src/main/ai/runtime/aiSdk/Agent.ts`) wraps `@cherrystudio/ai-core`'s
|
||||
`createAgent(...).stream()` (built on the AI SDK's `ToolLoopAgent`) with a
|
||||
`composeHooks` pipeline that folds N
|
||||
independent hook contributors (per-feature plugins, AiService analytics,
|
||||
internal observers) into a single `AgentLoopHooks` object with deterministic
|
||||
ordering, then bridges one streaming pass to a `ReadableStream<UIMessageChunk>`
|
||||
with a stable id for the first emitted message.
|
||||
|
||||
The stream is **single-pass**: `Agent.stream` runs the AI SDK stream exactly
|
||||
once and pipes it through. There is no mid-stream message injection — steering
|
||||
a chat turn is handled upstream by abort-and-restart (see
|
||||
[Stream Manager](./stream-manager.md#steering)).
|
||||
|
||||
`Agent` does not know about topics, IPC, persistence, or multi-model
|
||||
fan-out. Those concerns live in the stream manager — see
|
||||
[Stream Manager](./stream-manager.md).
|
||||
|
||||
## API
|
||||
|
||||
```ts
|
||||
const agent = new Agent({
|
||||
providerId, providerSettings, modelId,
|
||||
plugins, tools, system, options,
|
||||
hookParts, // RequestFeature contributions
|
||||
messageId // stable id for the first emitted UIMessage
|
||||
})
|
||||
|
||||
const stream: ReadableStream<UIMessageChunk> = agent.stream(initialMessages, signal)
|
||||
// or (non-streaming; input is { prompt } | { messages })
|
||||
const result = await agent.generate({ messages }, signal)
|
||||
|
||||
// internal observers can also register on the agent:
|
||||
const dispose = agent.on('onStepFinish', step => { … })
|
||||
```
|
||||
|
||||
`stream()` and `generate()` share the underlying agent — only the AI SDK
|
||||
call differs. Future `runToCompletion()` / `toTool()` are placeholders;
|
||||
they don't ship in this PR.
|
||||
|
||||
## Hooks model
|
||||
|
||||
```ts
|
||||
interface AgentLoopHooks {
|
||||
onStart?: () => Promise<void> | void
|
||||
prepareStep?: PrepareStepFunction // chained
|
||||
onStepFinish?: (step) => Promise<void> | void // void-fan-out
|
||||
onToolExecutionStart?: (event) => Promise<void> | void
|
||||
onToolExecutionEnd?: (event) => Promise<void> | void
|
||||
onFinish?: () => Promise<void> | void
|
||||
onError?: (ctx) => 'retry' | 'abort'
|
||||
}
|
||||
```
|
||||
|
||||
Hook contributions come from three sources, all folded by `composeHooks`:
|
||||
|
||||
1. **Internal observers** (`Agent.on(key, fn)`) — `attachUsageObserver`
|
||||
(injects `message-metadata` chunks carrying token usage).
|
||||
2. **Feature contributions** (`hookParts` param) — each `RequestFeature`'s
|
||||
`contributeHooks(scope)` (see [Params Pipeline](./params-pipeline.md)).
|
||||
3. **Caller hooks** — `AiService` adds the analytics hook only (token-usage
|
||||
accounting via `onStepFinish` / `onFinish`). It does *not* contribute a
|
||||
root-span/trace lifecycle hook — the OTel root span is owned by
|
||||
`AiStreamManager.runExecutionLoop`.
|
||||
|
||||
Composition rules per hook key:
|
||||
|
||||
| key | rule |
|
||||
|---|---|
|
||||
| `onStart`, `onFinish`, `onStepFinish`, `onToolExecutionStart/End` | `chainVoid` — sequential `for`-loop await; per-hook throws logged and swallowed, chain continues |
|
||||
| `prepareStep` | chained — each invocation receives the previous return value |
|
||||
| `onError` | every handler invoked sequentially; any `'retry'` makes the result `'retry'`; default `abort` |
|
||||
|
||||
All void hooks share the same `chainVoid` helper in `composeHooks.ts` —
|
||||
there is no `Promise.allSettled` / parallel path.
|
||||
|
||||
Tool execution events (`onToolExecutionStart/End`) are emitted by a
|
||||
wrapper around each tool's `execute`. No released AI SDK version brackets a
|
||||
single tool's execution: v6 exposes call-level (`experimental_onToolCallStart`)
|
||||
and input-level (`onInputStart` / `onInputDelta` / `onInputAvailable`) hooks, but
|
||||
nothing around `execute` itself — so we wrap. A future SDK version may add
|
||||
Agent-level execution hooks with the same shape, at which point the wrapper is
|
||||
removed and hook signatures stay stable.
|
||||
|
||||
## Steering
|
||||
|
||||
There is no in-loop steering. `Agent.stream` makes a single AI SDK pass and
|
||||
never folds a mid-flight follow-up into the running turn — doing so mutated
|
||||
in-flight history and had no clean turn boundary. A new chat submission to a
|
||||
live topic is handled one level up by the stream manager: the dispatcher
|
||||
aborts the running turn, waits for it to persist as `paused`, and starts a
|
||||
fresh one — see [Stream Manager → Steering](./stream-manager.md#steering).
|
||||
|
||||
Agent-session runtimes are different: they queue their own follow-ups on the
|
||||
session's `pendingTurns` and interrupt between turns rather than restarting —
|
||||
see [Agent Session Runtime](./agent-session-runtime.md#live-follow-up).
|
||||
|
||||
## Error and abort
|
||||
|
||||
- `signal.aborted` is honoured throughout; aborted streams settle with
|
||||
the accumulated chunks already broadcast.
|
||||
- Thrown errors are caught and routed through `onError`. Returning
|
||||
`'retry'` is reserved for a future implementation — today the loop
|
||||
logs and aborts.
|
||||
- The writer is settled exactly once via the `then`/`catch` of the
|
||||
internal IIFE — listeners never see a half-closed stream.
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Code: `src/main/ai/runtime/aiSdk/`
|
||||
- Tests: `src/main/ai/runtime/aiSdk/loop/__tests__/agentLoop.test.ts`
|
||||
- Stream manager integration: [Stream Manager](./stream-manager.md)
|
||||
- Hook contributors: [Params Pipeline](./params-pipeline.md)
|
||||
194
docs/references/ai/agent-session-runtime.md
Normal file
194
docs/references/ai/agent-session-runtime.md
Normal file
@@ -0,0 +1,194 @@
|
||||
# Agent Session Runtime
|
||||
|
||||
## Purpose
|
||||
|
||||
Agent-session streams need a stable host for UI turns, persistence, live
|
||||
follow-ups, interrupt, and recovery. The host must not know whether the
|
||||
underlying agent uses a long-lived process, a websocket, one HTTP request
|
||||
per turn, or Claude Code's SDK `query`.
|
||||
|
||||
The boundary is:
|
||||
|
||||
- `AgentSessionRuntimeService` owns Cherry's UI/session lifecycle.
|
||||
- `AgentSessionRuntimeDriver` owns the concrete agent-session runtime lifecycle.
|
||||
|
||||
Claude Code is the first driver. Its `query`, warm query, SDK input
|
||||
queue, and `resume` handling are driver internals.
|
||||
|
||||
## Ownership
|
||||
|
||||
| Owner | Responsibility |
|
||||
|---|---|
|
||||
| `AgentChatContextProvider` | Validates the agent session, persists the user row (plus a pending assistant row on a fresh turn), and either starts a turn or enqueues a follow-up through the runtime. |
|
||||
| `AgentSessionRuntimeService` | Owns one runtime entry per session: current UI turn, pending UI queue, runtime connection, latest resume token, terminal listeners, persistence, and idle timer. |
|
||||
| `AgentSessionRuntimeDriver` | Connects to one concrete agent implementation and exposes `send`, optional `interrupt`, `close`, and an event stream. |
|
||||
| `AiStreamManager` | Keeps the normal topic stream contract: start a turn, attach a follow-up subscriber to a live turn, pause the current runtime turn, and start the next runtime turn. |
|
||||
| `AiService.streamText()` | Routes `request.runtime.kind === 'agent-session'` to `AgentSessionRuntimeService.openTurnStream()` and rejects agent-session topics that do not carry runtime metadata. |
|
||||
| `ClaudeCodeRuntimeDriver` | Converts Claude SDK messages into generic runtime events and maps opaque resume tokens to Claude SDK `resume`. |
|
||||
|
||||
## Fresh turn
|
||||
|
||||
1. Renderer sends `Ai_Stream_Open` for topic `agent-session:<sessionId>`.
|
||||
2. `AgentChatContextProvider` validates the session:
|
||||
- the session must have an agent and workspace;
|
||||
- the workspace path must pass `assertClaudeCodeWorkspaceDirectory`;
|
||||
- the agent type must have a registered runtime driver;
|
||||
- the agent must have a model.
|
||||
3. The provider atomically saves:
|
||||
- a `user` message with the submitted parts;
|
||||
- a pending `assistant` message with the selected model id.
|
||||
4. The provider calls `AgentSessionRuntimeService.beginTurn(...)`.
|
||||
5. `beginTurn()` returns:
|
||||
- a runtime persistence listener;
|
||||
- a runtime terminal listener;
|
||||
- a trace flush listener for `agent-session:${sessionId}` history files;
|
||||
- a `turnId`.
|
||||
Follow-up messages are not queued here — they live on the session
|
||||
entry's `pendingTurns`, appended by `enqueueUserMessage()`.
|
||||
6. The prepared model request includes:
|
||||
- `runtime: { kind: 'agent-session', sessionId, turnId }`;
|
||||
- `messageId` set to the pending assistant row;
|
||||
- seed `messages`: the user row plus the empty assistant row.
|
||||
7. `AiStreamManager` starts the execution. `AiService.streamText()`
|
||||
detects the runtime metadata and calls `openTurnStream()` instead of
|
||||
building a generic `Agent`.
|
||||
8. `openTurnStream()` ensures there is a runtime connection and admits
|
||||
the turn by calling `connection.send({ message })`.
|
||||
|
||||
## Live follow-up
|
||||
|
||||
If the same topic already has a live stream, `AgentChatContextProvider`
|
||||
does **not** create a new assistant placeholder and does **not** call
|
||||
`beginTurn()` again. It persists the new user row, hands the message to
|
||||
`AgentSessionRuntimeService.enqueueUserMessage(sessionId, message)`, and
|
||||
returns a `PreparedDispatch` with `models: []` so `AiStreamManager.send()`
|
||||
takes the **inject** path — which for agent sessions only upserts the new
|
||||
subscriber onto the running stream (no message is injected into the
|
||||
execution; chat's abort-and-restart does not apply here).
|
||||
|
||||
`enqueueUserMessage()` appends the message to the session entry's
|
||||
`pendingTurns`, then acts on the current turn:
|
||||
|
||||
1. if the turn is already terminal (or absent) — schedules the next turn;
|
||||
2. if the turn is mid-tool-call (`activeToolIds` non-empty) — leaves it
|
||||
alone; the next turn is scheduled when the turn settles;
|
||||
3. otherwise — requests an interrupt when safe (`connection.interrupt()`
|
||||
if the driver supports it), which terminalizes the current UI turn.
|
||||
|
||||
Once the current turn is paused or terminal, `startNextTurn()` drains the
|
||||
next message off `pendingTurns` and starts a fresh runtime turn (below).
|
||||
This keeps the renderer protocol unchanged while each driver decides how
|
||||
to interrupt its own runtime.
|
||||
|
||||
## Starting the next runtime turn
|
||||
|
||||
When a paused, aborted, or completed runtime turn still has queued
|
||||
follow-ups, `AgentSessionRuntimeService.startNextTurn()`:
|
||||
|
||||
1. shifts the next user message off the session entry's `pendingTurns`;
|
||||
2. saves a new pending assistant row;
|
||||
3. creates a fresh `turnId`;
|
||||
4. calls `AiStreamManager.startRuntimeTurn(...)` with:
|
||||
- the same topic id and model id;
|
||||
- `runtime: { kind: 'agent-session', sessionId, turnId }`;
|
||||
- seed messages containing the user row and empty assistant row.
|
||||
|
||||
The runtime connection may stay on the entry. What that means is driver
|
||||
specific: Claude Code keeps its SDK query/input queue, while another
|
||||
driver could keep a websocket or reconnect per turn.
|
||||
|
||||
## Resume token persistence
|
||||
|
||||
Drivers may emit:
|
||||
|
||||
```ts
|
||||
{ type: 'resume-token'; token: string }
|
||||
```
|
||||
|
||||
The host treats the value as opaque. It stores it as
|
||||
`entry.lastResumeToken` and passes `runtimeResumeToken` to
|
||||
`AgentSessionMessageBackend`, so the final assistant row receives the
|
||||
latest resume token at terminal time.
|
||||
|
||||
This also covers error turns: if a driver emitted a resume token and then
|
||||
failed, the assistant error row still records that token so the next
|
||||
connection can recover from the newest driver-known state.
|
||||
|
||||
User rows do not need a resume token. The durable recovery anchor is the
|
||||
latest assistant row with `runtimeResumeToken`.
|
||||
|
||||
For Claude Code, the resume token is the SDK `session_id`. The driver
|
||||
maps it to `options.resume`. This is separate from the SDK's file
|
||||
checkpointing / `rewindFiles()` feature, which uses user-message UUIDs
|
||||
to restore files.
|
||||
|
||||
## Claude Code driver
|
||||
|
||||
Normal multi-turn chat does not use `continue: true` and does not rely
|
||||
on cwd-based session discovery.
|
||||
|
||||
When `ClaudeCodeRuntimeDriver.connect()` needs to create a query, it
|
||||
asks `buildClaudeCodeQueryRequestForAgentSession(sessionId, resumeToken)`.
|
||||
The builder uses the first available value:
|
||||
|
||||
1. explicit resume token from the host;
|
||||
2. latest persisted agent-session resume token from
|
||||
`agentSessionMessageService.getLastRuntimeResumeToken(session.id)`;
|
||||
3. no resume id for a brand-new SDK session.
|
||||
|
||||
The query may come from `ClaudeCodeWarmQueryManager.consume(...)` if a
|
||||
prewarmed query is available. Otherwise the driver starts a new SDK
|
||||
query with `createClaudeQuery({ prompt: driverSdkInputQueue, options })`.
|
||||
|
||||
The driver converts Claude SDK messages into runtime events:
|
||||
|
||||
- `stream_event` / assistant/user messages -> `chunk`;
|
||||
- `system/init` -> `resume-token`;
|
||||
- `result` -> `resume-token` and `turn-complete`;
|
||||
- thrown errors -> `error`.
|
||||
|
||||
## Idle and shutdown
|
||||
|
||||
After a turn reaches terminal state, the runtime entry becomes `idle`.
|
||||
For a short idle window it keeps:
|
||||
|
||||
- the runtime connection, if it is still alive;
|
||||
- `lastResumeToken`;
|
||||
- the session entry's `pendingTurns`.
|
||||
|
||||
If a new turn arrives during that window, `beginTurn()` reuses the same
|
||||
entry and only swaps the current UI turn plus the UI pending queue.
|
||||
|
||||
When the idle timer expires, the runtime closes the entry:
|
||||
|
||||
- clears `pendingTurns`;
|
||||
- closes the runtime connection;
|
||||
- prewarms Claude Code when a latest resume token is known.
|
||||
|
||||
Service stop and destroy close all runtime entries.
|
||||
|
||||
## Removed old path
|
||||
|
||||
Claude Code is not a normal provider extension anymore:
|
||||
|
||||
- no `createClaudeCode`;
|
||||
- no `ClaudeCodeLanguageModel`;
|
||||
- no `ClaudeCodeProviderSettings`;
|
||||
- no `injectedMessageSource` in provider settings;
|
||||
- no `providerToAiSdkConfig(..., { runtimeResumeToken })` branch.
|
||||
|
||||
Any `agent-session:*` stream that reaches `AiService.streamText()`
|
||||
without runtime metadata is rejected. That fail-fast rule prevents a
|
||||
regression back to one CLI process per turn without the long-lived SDK
|
||||
input queue inside the Claude Code driver.
|
||||
|
||||
## Verification
|
||||
|
||||
Focused tests:
|
||||
|
||||
- `src/main/ai/streamManager/context/__tests__/AgentChatContextProvider.test.ts`
|
||||
- `src/main/ai/agentSession/__tests__/AgentSessionRuntimeService.test.ts`
|
||||
- `src/main/ai/runtime/claudeCode/__tests__/ClaudeCodeRuntimeDriver.test.ts`
|
||||
- `src/main/ai/__tests__/AiService.test.ts`
|
||||
- `src/main/ai/runtime/claudeCode/__tests__/streamAdapter.test.ts`
|
||||
- `src/main/ai/runtime/claudeCode/__tests__/ClaudeCodeWarmQueryManager.test.ts`
|
||||
180
docs/references/ai/core-architecture.md
Normal file
180
docs/references/ai/core-architecture.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# Core Architecture
|
||||
|
||||
End-to-end view of how a Cherry chat turn moves from user input to LLM
|
||||
response and back to UI, with pointers to the focused references for
|
||||
each subsystem.
|
||||
|
||||
## Layered view
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ Renderer │
|
||||
│ │
|
||||
│ useChat({ id: topicId, transport: IpcChatTransport }) │
|
||||
│ ├─ sendMessages → window.api.ai.streamOpen │
|
||||
│ ├─ reconnectToStream → window.api.ai.streamAttach │
|
||||
│ └─ abort signal → window.api.ai.streamAbort │
|
||||
│ │
|
||||
│ History: useQuery('/topics/:id/messages') → DataApi │
|
||||
│ Topic-level state: useTopicStreamStatus → shared cache │
|
||||
│ Approval bridge: useToolApprovalBridge → window.api.ai.toolApproval│
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
↕ IPC (keyed by topicId)
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ Main │
|
||||
│ │
|
||||
│ AiStreamManager (lifecycle service) — registers in onInit: │
|
||||
│ ├─ ipcHandle('Ai_Stream_Open', → dispatchStreamRequest) │
|
||||
│ ├─ ipcHandle('Ai_Stream_Attach', → this.attach) │
|
||||
│ ├─ ipcHandle('Ai_Stream_Detach', → this.detach) │
|
||||
│ └─ ipcHandle('Ai_Stream_Abort', → this.abort) │
|
||||
│ │
|
||||
│ AiService (lifecycle service) — registers: │
|
||||
│ ├─ ipcHandle('Ai_ToolApproval_Respond', <inline handler>) │
|
||||
│ └─ ipcHandle('Ai_GenerateText' / 'Ai_Translate_Open' / …) │
|
||||
│ │
|
||||
│ dispatch (src/main/ai/streamManager/context/dispatch.ts) │
|
||||
│ pick ChatContextProvider → prepareDispatch → manager.send(...) │
|
||||
│ │
|
||||
│ AiStreamManager │
|
||||
│ activeStreams: Map<topicId, ActiveStream> │
|
||||
│ listeners + executions │
|
||||
│ runs N StreamExecution loops, fan-out per chunk to listeners │
|
||||
│ │
|
||||
│ runExecutionLoop (AiStreamManager) → AiService.streamText(req,signal)│
|
||||
│ buildAgentParams: registry.selectActive + applyDeferExposition │
|
||||
│ new Agent({tools, hookParts}) — composeHooks runs inside Agent │
|
||||
│ → agent.stream(messages, signal) │
|
||||
│ pipeStreamLoop tees: │
|
||||
│ • broadcast → WebContents / SSE / channel-adapter / persistence │
|
||||
│ • readUIMessageStream → CherryUIMessage snapshot │
|
||||
│ │
|
||||
│ Terminal listeners: │
|
||||
│ PersistenceListener → MessageService / TemporaryChat / Translation
|
||||
│ WebContentsListener → wc.send(Ai_StreamDone) │
|
||||
│ ChannelAdapterListener → adapter.onStreamComplete │
|
||||
│ SseListener → res.write('[DONE]') │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
@ai-sdk/* package
|
||||
↓
|
||||
LLM provider API
|
||||
```
|
||||
|
||||
## Sequence: a fresh chat turn
|
||||
|
||||
1. User hits send. `useChat.sendMessages` calls `IpcChatTransport.sendMessages`.
|
||||
2. Transport packages `AiStreamOpenRequest`, dispatches via
|
||||
`streamDispatchCoordinator` over IPC `Ai_Stream_Open`.
|
||||
3. `AiStreamManager`'s `Ai_Stream_Open` handler (registered in `onInit`)
|
||||
wraps the sender in a `WebContentsListener` and calls
|
||||
`dispatchStreamRequest(manager, subscriber, request)`.
|
||||
4. `dispatchStreamRequest` picks the first `ChatContextProvider` whose
|
||||
`canHandle(topicId)` matches and asks it to `prepareDispatch`.
|
||||
5. The provider resolves models, persists the user message (chat) or skips
|
||||
persistence (temporary / translate), creates `PersistenceListener` per
|
||||
execution, returns `PreparedDispatch`.
|
||||
6. `dispatch` reconciles any live stream, then calls `manager.send(input)`:
|
||||
- **chat resubmit** (topic already streaming): `manager.abortAndAwait(topicId)`
|
||||
aborts the running executions and waits for their loops to settle (the
|
||||
partial persists as `paused`) before `send()` **starts** a fresh turn —
|
||||
steering is abort-and-restart, not mid-turn injection.
|
||||
- **agent-session follow-up**: the stream is left running and `send()`
|
||||
**injects** — it upserts `listeners` onto the running stream, `models`
|
||||
ignored (the message was already enqueued on the session's `pendingTurns`).
|
||||
- **no live stream**: `send()` **starts** — evict any grace-period stream,
|
||||
create an `ActiveStream`, launch one `StreamExecution` per model.
|
||||
7. For each `StreamExecution`, `AiStreamManager`'s private `runExecutionLoop`
|
||||
calls `AiService.streamText(request, signal)`, which builds params
|
||||
(`buildAgentParamsFor → buildAgentParams`: `registry.selectActive` +
|
||||
`applyDeferExposition` + per-feature hooks), constructs an `Agent`
|
||||
(`composeHooks` folds observers + caller + features inside `Agent`), and
|
||||
calls `agent.stream(messages, signal)` — which opens AI SDK's stream and
|
||||
yields `UIMessageChunk`s. Agent-session runtime requests skip the generic
|
||||
agent loop here: `AiService.streamText()` calls
|
||||
`AgentSessionRuntimeService.openTurnStream()` so the registered driver
|
||||
can own the concrete agent runtime.
|
||||
8. `pipeStreamLoop` reads the chunk stream once, tees: broadcast to
|
||||
listeners, accumulate via `readUIMessageStream`.
|
||||
9. On terminal (`done` / `error` / `aborted` / `awaiting-approval`):
|
||||
- `PersistenceListener` writes the final assistant message.
|
||||
- `WebContentsListener` broadcasts `Ai_StreamDone` to subscribed windows.
|
||||
- Shared-cache `topic.stream.statuses.<topicId>` flips to the terminal status.
|
||||
10. Renderer's `useQuery('/topics/:id/messages')` revalidates; the
|
||||
optimistic overlay is disposed.
|
||||
|
||||
## Sequence: tool approval pause + resume
|
||||
|
||||
1. AI SDK calls `tool.execute(args, toolCallContext)`. The wrapper sees
|
||||
`needsApproval(args)` returns true and the assistant's auto-approve
|
||||
policy says "ask". It writes an `approval-requested` part on the
|
||||
accumulated message and holds the promise.
|
||||
2. Manager flips status to `awaiting-approval` on the shared cache.
|
||||
3. Renderer's `useTopicAwaitingApproval(topicId)` returns true; the UI
|
||||
shows the approval card.
|
||||
4. User decides → `useToolApprovalBridge` → `Ai_ToolApproval_Respond`.
|
||||
5. Main applies the decision to the anchor row, resumes the stream
|
||||
(Claude-Agent: resolves the `canUseTool` promise; MCP: dispatches a
|
||||
`continue-conversation` so the existing stream rebroadcasts).
|
||||
6. Status flips back to `streaming`; UI hides the card.
|
||||
|
||||
See [Tool Approval](./tool-approval.md) for invariants and the
|
||||
overlay-vs-persist conditional write.
|
||||
|
||||
## Key subsystems
|
||||
|
||||
| Subsystem | Reference |
|
||||
|---|---|
|
||||
| Active-stream registry, listeners, persistence backends, reconnect, abort, grace-period eviction | [Stream Manager](./stream-manager.md) |
|
||||
| Claude Code agent-session long-lived runtime, SDK input queue, resume fallback | [Agent Session Runtime](./agent-session-runtime.md) |
|
||||
| `Agent.stream` single-pass loop, hooks model, error/abort | [Agent Loop](./agent-loop.md) |
|
||||
| `buildAgentParams`, `RequestFeature` composition, `INTERNAL_FEATURES` order | [Params Pipeline](./params-pipeline.md) |
|
||||
| Tool registry, MCP sync, meta-tools (`tool_search` / `tool_inspect` / `tool_invoke` / `tool_exec`), defer exposition | [Tool Registry](./tool-registry.md) |
|
||||
| `Provider.endpointConfigs`, `endpointType` resolution, variant suffixes, custom providers | [Provider Resolution](./provider-resolution.md) |
|
||||
| `adapterFamily` field, runtime resolver, write paths (catalog / migrator) | [Adapter Family](./adapter-family.md) |
|
||||
| OTel span tree, `AdapterTracer`, `AiSdkSpanAdapter`, dev-tools view | [Observability](./observability.md) |
|
||||
| `IpcChatTransport`, dispatch coordinator, per-execution demux | [IPC Transport](./ipc-transport.md) |
|
||||
| Approval flow, Main-as-writer invariant, persistent decisions | [Tool Approval](./tool-approval.md) |
|
||||
|
||||
## Invariants
|
||||
|
||||
- **Topic-level addressing.** Every IPC, broadcast, and shared-cache
|
||||
entry is keyed by `topicId`. A topic has at most one active stream;
|
||||
subscribers are equal — there is no "owner" window.
|
||||
- **Main owns persistence.** Renderer closing or crashing does not abort
|
||||
the stream or lose data. `PersistenceListener` writes on terminal
|
||||
regardless of subscriber state.
|
||||
- **Main owns approval state.** The renderer is never a writer.
|
||||
- **Adapter family is per-endpoint.** Multi-endpoint relays may use
|
||||
different `@ai-sdk/*` packages on different endpoints under the same
|
||||
`provider.id`.
|
||||
- **`tools/applies` predicates are pure.** They run on every
|
||||
`selectActive` pass; side effects there break tool selection
|
||||
determinism.
|
||||
- **Features must not mutate `RequestScope`.** It is shared across all
|
||||
features for a single request.
|
||||
|
||||
## Code map
|
||||
|
||||
```
|
||||
src/main/ai/
|
||||
├── AiService.ts ← lifecycle owner, IPC entry (generate / translate / approval)
|
||||
├── runtime/ ← execution backends: runtime/aiSdk (Agent + params), runtime/claudeCode
|
||||
├── agentSession/ ← agent-session topic host
|
||||
├── agents/ ← AgentJobsService, AgentTaskJobHandler, runAgentTask, cherryclaw
|
||||
├── channels/ ← ChannelManager + IM adapters (discord/feishu/qq/slack/telegram/wechat) + security/
|
||||
├── streamManager/ ← AiStreamManager, listeners, persistence (registers the stream IPC)
|
||||
├── provider/ ← provider config, endpoint resolution, custom providers
|
||||
├── mcp/ ← McpRuntimeService / McpCatalogService, oauth, built-in servers
|
||||
├── skills/ ← SkillService, SkillInstaller
|
||||
├── tools/ ← unified tool registry (adapters/aiSdk + adapters/claudeCode)
|
||||
├── observability/ ← AI trace adapters, local projection, sinks
|
||||
├── messages/ ← UI part → AI SDK part conversion
|
||||
├── types/ ← AppProviderId, merged types, request types
|
||||
└── utils/ ← reasoning / model parameters / options / websearch
|
||||
|
||||
src/renderer/transport/ ← IpcChatTransport, dispatch coordinator
|
||||
src/renderer/hooks/ ← useChatWithHistory, useToolApprovalBridge, useTopicStreamStatus
|
||||
packages/aiCore/ ← @cherrystudio/ai-core (Agent + plugins + provider extensions)
|
||||
packages/provider-registry/ ← provider catalog, registry-utils (adapterFamily inference)
|
||||
```
|
||||
223
docs/references/ai/execution-overlay.md
Normal file
223
docs/references/ai/execution-overlay.md
Normal file
@@ -0,0 +1,223 @@
|
||||
# Execution Overlay
|
||||
|
||||
The renderer-side counterpart of Main's `pipeStreamLoop`. Both sides
|
||||
use the **same pure assembler** —
|
||||
[AI SDK's `readUIMessageStream`](https://ai-sdk.dev/docs/reference/ai-sdk-ui/read-ui-message-stream) —
|
||||
to turn the chunk stream into a `CherryUIMessage`. Main writes the
|
||||
result to disk; the renderer paints it onto the chat surface as an
|
||||
overlay above the SWR-backed history.
|
||||
|
||||
## Why the same merge function on both sides
|
||||
|
||||
`UIMessageChunk` assembly is non-trivial: text deltas merge by `id`,
|
||||
reasoning blocks have their own start/delta/end, tool calls go through
|
||||
`tool-input-start` / `tool-input-delta` / `tool-input-available` /
|
||||
`tool-output-available`, dynamic data parts merge by key, multi-step
|
||||
turns carry step boundaries. Re-implementing any of this on the
|
||||
renderer would mean a second source of truth that *had* to track AI SDK
|
||||
upstream, with two ways to disagree about partial state.
|
||||
|
||||
Running the same `readUIMessageStream` on the same `UIMessageChunk`
|
||||
stream — once on Main (writing to `exec.finalMessage`), once on the
|
||||
renderer (driving the overlay) — guarantees structural agreement.
|
||||
What persists is exactly what the user saw streaming.
|
||||
|
||||
```
|
||||
Main: pipeStreamLoop(stream)
|
||||
tee()
|
||||
├─ branch A → broadcast to listeners → WebContentsListener → IPC chunks
|
||||
└─ branch B → readUIMessageStream → exec.finalMessage (writes to DB)
|
||||
▲
|
||||
│ (DB write)
|
||||
│
|
||||
Renderer: TopicStreamSubscription ┌──── readUIMessageStream → snapshot
|
||||
│ │ │ ▲
|
||||
│ ▼ │ │
|
||||
│ routes chunks by │ fed by branch stream
|
||||
│ executionId into │
|
||||
│ per-execution branches ─────┘
|
||||
▼
|
||||
branch ReadableStream → useExecutionOverlay (per execution)
|
||||
```
|
||||
|
||||
## TopicStreamSubscription
|
||||
|
||||
`src/renderer/transport/TopicStreamSubscription.ts`. A renderer
|
||||
class that owns:
|
||||
|
||||
- **One IPC attach per topic.** `attach` is ref-counted — every
|
||||
execution that calls `register(executionId)` increments the count;
|
||||
the last `unregister` triggers `detach` (deferred one microtask so a
|
||||
transient `activeExecutions` flicker doesn't detach-then-reattach).
|
||||
- **Per-execution demux.** Each `register(executionId)` returns a
|
||||
`ReadableStream<UIMessageChunk>` that contains only the chunks tagged
|
||||
with that `executionId` by Main. Multi-model parallel responses each
|
||||
get their own branch.
|
||||
- **Synchronous controller creation.** The branch's
|
||||
`ReadableStreamDefaultController` is created during the
|
||||
`new ReadableStream({ start })` call (synchronous), so chunks that
|
||||
arrived between `register` and the reader's first `read()` are
|
||||
already buffered in the stream's internal queue — late readers never
|
||||
miss replayed chunks.
|
||||
- **Terminal demux.** `Ai_StreamDone` / `Ai_StreamError` close the
|
||||
matching branch and fan out an `ExecutionTerminal` (`{ isAbort,
|
||||
isError }`) to listeners; if the payload carries `isTopicDone` or no
|
||||
`executionId`, every branch terminates together.
|
||||
|
||||
### Cancellation layering — do not conflate
|
||||
|
||||
| Layer | Owner | Action |
|
||||
|---|---|---|
|
||||
| Renderer-local subscription | `TopicStreamSubscription.unregister` / `dispose` | Closes the branch reader, drops listener ref count; Main keeps generating |
|
||||
| Generation abort | Main (via `useChatWithHistory.stop` → Chat → `Ai_Stream_Abort`) | Stops the LLM |
|
||||
|
||||
`TopicStreamSubscription` NEVER aborts the LLM. Closing all branches
|
||||
is the renderer equivalent of `streamDetach` — Main keeps streaming,
|
||||
other windows keep observing.
|
||||
|
||||
### Defensive routing
|
||||
|
||||
A chunk without `executionId` is unexpected — Main always tags chat
|
||||
chunks. As a defensive fallback, if exactly one branch is registered
|
||||
the chunk routes there; otherwise it's dropped with a warning.
|
||||
|
||||
## useExecutionOverlay
|
||||
|
||||
`src/renderer/hooks/useExecutionOverlay.ts`. The per-execution
|
||||
overlay, built on `useTopicStreamSubscription`.
|
||||
|
||||
```ts
|
||||
const { overlay, liveAssistants, disposeOverlay, reset } = useExecutionOverlay(
|
||||
topicId,
|
||||
activeExecutions, // ActiveExecution[] from useTopicStreamStatus
|
||||
uiMessages, // current DB snapshot
|
||||
{ onFinish }
|
||||
)
|
||||
```
|
||||
|
||||
### One reader per turn, zero cross-turn state
|
||||
|
||||
Each execution gets a **one-shot `readUIMessageStream` reader** per
|
||||
turn, not a stateful AI SDK `Chat`. A `Chat` carries
|
||||
`state.messages` across turns; reusing it made a new turn resume from
|
||||
the previous turn's finished assistant ("previous answer + new
|
||||
stream"). A fresh reader per turn structurally cannot pollute.
|
||||
|
||||
### The seed rule (continue-safe)
|
||||
|
||||
```ts
|
||||
function pickSeed(uiMessages, anchorMessageId): CherryUIMessage | undefined {
|
||||
if (!anchorMessageId) return undefined
|
||||
const found = uiMessages.find((m) => m.id === anchorMessageId)
|
||||
if (!found) return { id: anchorMessageId, role: 'assistant', parts: [] }
|
||||
// `readUIMessageStream` mutates `message.parts` in place, and `found` is the live
|
||||
// SWR-derived row — clone the parts so the reader only ever writes to a throwaway.
|
||||
return { ...found, parts: structuredClone(found.parts ?? []) }
|
||||
}
|
||||
```
|
||||
|
||||
The reader is seeded with the message whose id is the execution's
|
||||
`anchorMessageId`, taken from the **current DB truth** at reader-start
|
||||
time. Two cases:
|
||||
|
||||
- **Fresh placeholder** — the SQLite row has empty parts; the seed is
|
||||
effectively empty and the reader builds the message from scratch.
|
||||
- **Tool-approval / continue-conversation** — the row already carries
|
||||
the prior assistant parts (including the unresolved `tool-input` part
|
||||
the approval was on). A streamed `tool-output` chunk then merges
|
||||
cleanly onto its matching `tool-input` because they share the same
|
||||
`toolCallId`.
|
||||
|
||||
The seed is re-derived from DB on every reader start; it never carries
|
||||
across turns, and its `parts` are cloned so the reader's in-place mutation
|
||||
never touches the SWR row. Combined with the fresh reader, this is the
|
||||
**structural** anti-pollution guarantee — not "force empty parts" or "diff
|
||||
against last frame".
|
||||
|
||||
### Lifecycle
|
||||
|
||||
1. **Topic switch** — every reader is cancelled, every branch
|
||||
unregistered, `snapshots` cleared. `prevTopicRef` is checked in the
|
||||
render body so the cleanup runs synchronously before the new topic's
|
||||
readers start.
|
||||
2. **`activeExecutions` change** — diff against the current reader
|
||||
map: cancel + unregister executions no longer in the active list;
|
||||
for newly-active executions, register a branch, clear any retained
|
||||
prior snapshot, kick a new reader.
|
||||
3. **Terminal** — the branch is closed by `TopicStreamSubscription`;
|
||||
the reader's `for await` exits. The `onFinish(executionId, event)`
|
||||
callback fires with the final snapshot + `{ isAbort, isError }`.
|
||||
4. **Unmount** — every reader is cancelled, every branch unregistered.
|
||||
|
||||
### Overlay teardown is monotonic
|
||||
|
||||
`disposeOverlay(messageId)` drops exactly one snapshot entry. The chat
|
||||
shell wires this so the overlay is released **only after** the DB
|
||||
refresh promise resolves (see `.finally(() => disposeOverlay(...))` in
|
||||
`V2ChatContent`). That ordering eliminates the visible flash between
|
||||
"streaming overlay" and "persisted parts": the SWR cache holds the
|
||||
authoritative row before the overlay disappears.
|
||||
|
||||
The renderer never writes streamed parts to SWR — writing them would
|
||||
race the DB-authoritative refresh and cause flicker.
|
||||
|
||||
### Why retained snapshots after terminal
|
||||
|
||||
The hook keeps the final snapshot in `snapshots` until one of:
|
||||
|
||||
- the same execution restarts (next turn clears it),
|
||||
- the caller calls `disposeOverlay(messageId)` (post-persist handoff),
|
||||
- the caller calls `reset()` (e.g. quick-assistant clear),
|
||||
- the topic switches (effect clears all snapshots).
|
||||
|
||||
That retention lets consumers read the final frame for the brief window
|
||||
between stream-end and DB-refresh-complete without going through SWR.
|
||||
|
||||
## React binding
|
||||
|
||||
`useTopicStreamSubscription(topicId)` is the React wrapper:
|
||||
|
||||
- Lazy-init per `topicId` (idiom mirrors `useState(() => ...)`).
|
||||
- Disposed on unmount or topic switch — drops the Main listener and
|
||||
closes every branch.
|
||||
|
||||
Each mounted topic gets one `TopicStreamSubscription` instance, shared
|
||||
by every consumer in that React tree (today: `useExecutionOverlay`).
|
||||
|
||||
## Code map
|
||||
|
||||
```
|
||||
src/renderer/transport/TopicStreamSubscription.ts ← class
|
||||
src/renderer/hooks/useTopicStreamSubscription.ts ← React binding
|
||||
src/renderer/hooks/useExecutionOverlay.ts ← per-execution readers + overlay
|
||||
src/renderer/pages/home/V2ChatContent.tsx ← consumer + dispose-after-refresh
|
||||
```
|
||||
|
||||
## Invariants reviewers should check
|
||||
|
||||
1. **Same merge function on both sides.** Any code that re-implements
|
||||
chunk → message assembly on the renderer (instead of feeding
|
||||
`readUIMessageStream`) is wrong — that's where Main and renderer
|
||||
will diverge first.
|
||||
2. **One reader per turn.** No reader should be reused across
|
||||
`activeExecutions` transitions. Reusing one is what the v1 `Chat`
|
||||
bug was; the structural fix is structural.
|
||||
3. **Seed from current DB.** `pickSeed` reads `uiMessagesRef.current`
|
||||
at reader-start time. Stashing the seed on first mount and reusing
|
||||
it across turns would defeat the continue-conversation case.
|
||||
4. **Overlay disposed after DB refresh.** Any
|
||||
`disposeOverlay(messageId)` call that runs **before** the DB
|
||||
revalidation promise resolves is a flicker bug.
|
||||
5. **`TopicStreamSubscription` never aborts.** It only detaches.
|
||||
Anything in this layer that calls `Ai_Stream_Abort` is in the
|
||||
wrong place — abort belongs to `useChatWithHistory.stop`.
|
||||
6. **Ref-counted attach.** A new attach must NOT fire when another
|
||||
execution is already registered for the same topic. A new detach
|
||||
must NOT fire while any execution still has a branch.
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Main-side accumulator: [Stream Manager — `pipeStreamLoop`](./stream-manager.md#execution-loop--runexecutionloop--pipestreamloop)
|
||||
- IPC envelope: [IPC Transport](./ipc-transport.md)
|
||||
- Topic status / approval-anchor surfacing: [Tool Approval](./tool-approval.md)
|
||||
- AI SDK upstream: [`readUIMessageStream` reference](https://ai-sdk.dev/docs/reference/ai-sdk-ui/read-ui-message-stream)
|
||||
@@ -143,7 +143,7 @@ Empty / `undefined` / `null` values are dropped here — the server applies its
|
||||
own default; no client-side defaults are invented. The `'auto'` sentinel is
|
||||
**not** dropped at this stage: it's carried through and resolved to "omit the
|
||||
field" one stage later by the emitters (e.g. `toDashScopeSize` /
|
||||
`AiProvider.resolveImageSize`).
|
||||
`resolveSizeParameter` in `dashscopeTransport.ts`).
|
||||
|
||||
### 2. Transport routing hint (`paintingPipeline`)
|
||||
|
||||
@@ -154,7 +154,7 @@ field" one stage later by the emitters (e.g. `toDashScopeSize` /
|
||||
|
||||
### 3. providerOptions emitters (`buildImageProviderOptions`)
|
||||
|
||||
[`src/renderer/aiCore/utils/imageOptions.ts`](../../../src/renderer/aiCore/utils/imageOptions.ts) is a **table of per-provider emitters** that map canonical params to each vendor's real wire field names and bag key:
|
||||
[`src/main/ai/utils/imageOptions.ts`](../../../src/main/ai/utils/imageOptions.ts) is a **table of per-provider emitters** that map canonical params to each vendor's real wire field names and bag key:
|
||||
|
||||
```
|
||||
EMITTERS: Record<providerId, Emitter> // unlisted ids → diffusion fallback
|
||||
@@ -170,10 +170,10 @@ nesting, enum casing. Nowhere else.
|
||||
|
||||
### 4. The model itself
|
||||
|
||||
[`AiProvider.modernGeneratePaintingImage`](../../../src/renderer/aiCore/AiProvider.ts) hands `aiSdkParams` + `providerOptions` to the resolved image model. The model is one of two kinds, decided by the provider factory:
|
||||
[`AiService.generateImage`](../../../src/main/ai/AiService.ts) (reached via the `Ai_GenerateImage` IPC) hands `aiSdkParams` + `providerOptions` to the resolved image model. The model is one of two kinds, decided by the provider factory:
|
||||
|
||||
- **Native AI SDK image model** — `OpenAIImageModel`, `@ai-sdk/google` `.image()`, `OpenAICompatibleImageModel`. Spreads `providerOptions[key]` into the request body.
|
||||
- **Custom `ImageGenerationTransport`** — for async submit→poll vendors or non-OpenAI wire shapes (DashScope, PPIO, DMXAPI's Doubao/Wan/async-Qwen families). See [`src/renderer/aiCore/provider/custom/imageGenerationModel.ts`](../../../src/renderer/aiCore/provider/custom/imageGenerationModel.ts); each vendor's transport lives beside its provider in a per-vendor folder (e.g. [`dmxapi/dmxapiTransport.ts`](../../../src/renderer/aiCore/provider/custom/dmxapi/dmxapiTransport.ts)), with shared helpers in [`transportUtils.ts`](../../../src/renderer/aiCore/provider/custom/transportUtils.ts). Multi-backend gateways (DMXAPI) dispatch by a `{match, family}` table on the model id — see [`dmxapi/dmxapiProvider.ts`](../../../src/renderer/aiCore/provider/custom/dmxapi/dmxapiProvider.ts).
|
||||
- **Custom `ImageGenerationTransport`** — for async submit→poll vendors or non-OpenAI wire shapes (DashScope, PPIO, DMXAPI's Doubao/Wan/async-Qwen families). See [`src/main/ai/provider/custom/imageGenerationModel.ts`](../../../src/main/ai/provider/custom/imageGenerationModel.ts); each vendor's transport lives beside its provider in a per-vendor folder (e.g. [`dmxapi/dmxapiTransport.ts`](../../../src/main/ai/provider/custom/dmxapi/dmxapiTransport.ts)), with shared helpers in [`transportUtils.ts`](../../../src/main/ai/provider/custom/transportUtils.ts). Multi-backend gateways (DMXAPI) dispatch by a `{match, family}` table on the model id — see [`dmxapi/dmxapiProvider.ts`](../../../src/main/ai/provider/custom/dmxapi/dmxapiProvider.ts).
|
||||
|
||||
---
|
||||
|
||||
@@ -213,8 +213,8 @@ nesting, enum casing. Nowhere else.
|
||||
| Default population on switch | `src/renderer/pages/paintings/utils/computeModelFieldReset.ts` |
|
||||
| Param partition | `src/renderer/pages/paintings/model/canonicalGenerate.ts` |
|
||||
| Transport hint + requirePrompt | `src/renderer/pages/paintings/model/paintingPipeline.ts` |
|
||||
| providerOptions emitters | `src/renderer/aiCore/utils/imageOptions.ts` |
|
||||
| Custom transport wrapper | `src/renderer/aiCore/provider/custom/imageGenerationModel.ts` |
|
||||
| Vendor provider + transport | `src/renderer/aiCore/provider/custom/<vendor>/{<vendor>Provider,<vendor>Transport}.ts` |
|
||||
| Shared transport helpers | `src/renderer/aiCore/provider/custom/transportUtils.ts` |
|
||||
| providerOptions emitters | `src/main/ai/utils/imageOptions.ts` |
|
||||
| Custom transport wrapper | `src/main/ai/provider/custom/imageGenerationModel.ts` |
|
||||
| Vendor provider + transport | `src/main/ai/provider/custom/<vendor>/{<vendor>Provider,<vendor>Transport}.ts` |
|
||||
| Shared transport helpers | `src/main/ai/provider/custom/transportUtils.ts` |
|
||||
```
|
||||
|
||||
98
docs/references/ai/ipc-transport.md
Normal file
98
docs/references/ai/ipc-transport.md
Normal file
@@ -0,0 +1,98 @@
|
||||
# IPC Transport
|
||||
|
||||
## What it is
|
||||
|
||||
`IpcChatTransport`
|
||||
(`src/renderer/transport/IpcChatTransport.ts`) implements AI SDK's
|
||||
`ChatTransport<CherryUIMessage>` over Electron IPC. The renderer feeds
|
||||
it into `useChat({ id: topicId, transport: ... })`. The `ChatTransport`
|
||||
interface has only two methods — `sendMessages` / `reconnectToStream`;
|
||||
the transport relays each over `window.api.ai.stream*` to Main's
|
||||
`AiStreamManager`. `cancel` is **not** a transport method: it is the
|
||||
`cancel` callback of the `ReadableStream` that `sendMessages` returns
|
||||
(AI SDK invokes it on unmount/disposal), and abort is driven by the
|
||||
request's `abortSignal`.
|
||||
|
||||
```
|
||||
useChat({ id: topicId, transport: new IpcChatTransport(defaultBody) })
|
||||
│ transport methods
|
||||
├─ sendMessages → window.api.ai.streamOpen (Ai_Stream_Open)
|
||||
├─ reconnectToStream → window.api.ai.streamAttach (Ai_Stream_Attach)
|
||||
│ returned-stream / signal callbacks
|
||||
├─ stream cancel() → window.api.ai.streamDetach (Ai_Stream_Detach)
|
||||
└─ request abort signal → window.api.ai.streamAbort (Ai_Stream_Abort)
|
||||
```
|
||||
|
||||
**Detach ≠ abort.** `cancel()` (e.g. unmount/disposal) calls `streamDetach`:
|
||||
it drops *this* subscriber while Main keeps generating and persists the
|
||||
result. Stopping generation is a separate path — the request's `abortSignal`
|
||||
firing calls `streamAbort`. Conflating the two would resurrect the v1
|
||||
"unmount → cancel → upstream abort → lost reply" bug class.
|
||||
|
||||
Per-topic chunks arrive via `onStreamChunk` listeners filtered by
|
||||
`topicId`.
|
||||
|
||||
## Triggers
|
||||
|
||||
`sendMessages` distinguishes two triggers:
|
||||
|
||||
| Trigger | What it does |
|
||||
|---|---|
|
||||
| `submit-message` | Includes `userMessageParts` (the latest message) so Main persists it |
|
||||
| `regenerate-message` | Sends `parentAnchorId` only; Main re-runs from the existing parent |
|
||||
|
||||
Cherry's transport never derives `continue-conversation` from
|
||||
message-state introspection. Approval-driven resumption goes through the
|
||||
explicit `Ai_ToolApproval_Respond` IPC handled by
|
||||
[`useToolApprovalBridge`](./tool-approval.md).
|
||||
|
||||
## Dispatch coordinator
|
||||
|
||||
`streamDispatchCoordinator` (`src/renderer/transport/streamDispatchCoordinator.ts`)
|
||||
sits between the transport and the IPC call so the `Ai_Stream_Open` ack
|
||||
(`userMessageId`, placeholder ids, executionIds) is observable to callers
|
||||
that need to join optimistic UI bubbles, rather than being thrown away by
|
||||
AI SDK's transport interface.
|
||||
|
||||
It does **not** serialize sends — there is no single-in-flight guard in the
|
||||
coordinator. Concurrency for a topic is arbitrated on the Main side: a chat
|
||||
resubmit to a live topic is aborted-and-restarted by `dispatch`
|
||||
(`AiStreamManager.abortAndAwait`), while an agent-session follow-up attaches
|
||||
to the running stream.
|
||||
|
||||
## Per-execution demux
|
||||
|
||||
The chunk stream from Main is keyed by `(topicId, executionId)`.
|
||||
`TopicStreamSubscription`
|
||||
(`src/renderer/transport/TopicStreamSubscription.ts`) owns the
|
||||
topic-level `streamAttach` / `streamDetach` with ref-counted lifecycle
|
||||
and demuxes chunks into per-execution branch `ReadableStream`s, so
|
||||
multi-model parallel responses render as separate AI SDK messages on
|
||||
the same topic. `useExecutionOverlay` consumes each branch through
|
||||
`readUIMessageStream` — the same accumulator Main runs in
|
||||
`pipeStreamLoop`, so the renderer overlay and the persisted message
|
||||
are structurally identical.
|
||||
|
||||
See [Execution Overlay](./execution-overlay.md) for the merge-function
|
||||
symmetry, seed rule, cancellation layering, and lifecycle.
|
||||
|
||||
## Topic-level subscription
|
||||
|
||||
`useTopicStreamStatus(topicId)` reads
|
||||
`topic.stream.statuses.<topicId>` from the shared cache. The cache is
|
||||
the cross-window source of truth for:
|
||||
|
||||
- `pending` / `streaming` / `awaiting-approval` / `done` / `error` / `aborted`
|
||||
- broadcast-completion anchor ids
|
||||
|
||||
`classifyTurn(status)` decodes the status into the `TurnStateFlags`
|
||||
predicates the UI consumes (`isStreamLive`, `isTurnActive`,
|
||||
`isAwaitingApproval`, `isTerminal`).
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Code: `src/renderer/transport/`
|
||||
- Hook glue: `src/renderer/hooks/useChatWithHistory.ts`
|
||||
- Per-execution overlay (renderer assembler): [Execution Overlay](./execution-overlay.md)
|
||||
- Approval bridge: [Tool Approval](./tool-approval.md)
|
||||
- Main side: [Stream Manager](./stream-manager.md)
|
||||
120
docs/references/ai/observability.md
Normal file
120
docs/references/ai/observability.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# Observability
|
||||
|
||||
The `src/main/ai/observability/` subsystem: OTel tracing, the local span
|
||||
projection (trace viewer), and the sink registry. "Trace / telemetry" is the
|
||||
user-facing surface; this doc covers the whole subsystem.
|
||||
|
||||
## What's instrumented
|
||||
|
||||
Every AI SDK call run through Cherry produces an OpenTelemetry span
|
||||
tree:
|
||||
|
||||
```
|
||||
chat.turn (root, created by context provider)
|
||||
├── ai.streamText (AI SDK auto)
|
||||
│ ├── ai.streamText.doStream (AI SDK auto)
|
||||
│ ├── ai.toolCall (per tool invocation) (AI SDK auto)
|
||||
│ └── ai.streamText.<step> (AI SDK auto)
|
||||
└── attributes: topicId, modelName, … (set by AiTurnTrace / AdapterTracer)
|
||||
```
|
||||
|
||||
AI SDK's `experimental_telemetry` produces the inner spans; Cherry owns
|
||||
the root span through `AiTurnTrace` so it lands in the same observability
|
||||
path without going through the AI SDK adapter.
|
||||
|
||||
The main-process observability boundary is `src/main/ai/observability`:
|
||||
|
||||
- `core/` creates Cherry-owned turn roots and common `cs.*` attributes.
|
||||
- `adapters/aiSdk/` interprets AI SDK child spans.
|
||||
- `adapters/claudeCode/` interprets Claude Code OTLP spans and logs.
|
||||
- `cache/` keeps the trace-window projection and JSONL-compatible history.
|
||||
- `sinks/` defines the extension point for local and future external export.
|
||||
|
||||
## Local history flush
|
||||
|
||||
`Message.traceId` is persisted with the assistant message row, but the span
|
||||
tree is first collected in the main-process `SpanCacheService` memory store.
|
||||
The durable history file is written by the stream terminal path, not by the
|
||||
trace UI:
|
||||
|
||||
- `PersistentChatContextProvider` attaches a `TraceFlushListener` to normal
|
||||
chat turns.
|
||||
- `AgentSessionRuntimeService` attaches the same listener to
|
||||
`agent-session:${sessionId}` turns, including queued follow-up turns.
|
||||
- On the topic-level terminal event (`done`, `paused`, or `error`),
|
||||
`TraceFlushListener` calls `SpanCacheService.saveSpans(topicId)`.
|
||||
- Flush errors are logged as warnings and do not affect message completion.
|
||||
|
||||
The trace pane and trace window are viewers only. They read spans through
|
||||
`SpanCacheService.getSpans(...)`, which falls back to the JSONL history file
|
||||
when the in-memory store has already been cleared.
|
||||
|
||||
## AdapterTracer
|
||||
|
||||
`src/main/ai/observability/adapters/aiSdk/adapterTracer.ts` wraps the OTel `Tracer` returned
|
||||
by the global provider. On every `startSpan` / `startActiveSpan` it:
|
||||
|
||||
1. Patches `span.end()` to also call `AiSdkSpanAdapter.convertToSpanEntity(...)`
|
||||
and hand the result to the observability sink registry.
|
||||
2. Stamps `trace.topicId` and `trace.modelName` so spans are queryable
|
||||
per topic in the dev-tools UI.
|
||||
|
||||
`AdapterTracer` is intentionally only for AI SDK child spans:
|
||||
|
||||
- `buildTelemetry` (`runtime/aiSdk/params/buildTelemetry.ts`) — passed to AI
|
||||
SDK as `experimental_telemetry.tracer`. Captures every AI SDK auto-span.
|
||||
Returns `undefined` (no telemetry, no tracer) when there is no `topicId`
|
||||
or developer mode is off — see below.
|
||||
|
||||
## AiSdkSpanAdapter
|
||||
|
||||
`src/main/ai/observability/adapters/aiSdk/aiSdkSpanAdapter.ts` converts an OTel span into the
|
||||
shape the dev-tools UI consumes:
|
||||
|
||||
- Reads span name, attributes, events, status, links.
|
||||
- Recovers AI SDK's hierarchical attribute conventions:
|
||||
`ai.xxx` is a level, `ai.xxx.yyy` is a sub-level under it.
|
||||
- Normalises usage attributes across the base and LLM spans: input from
|
||||
`ai.usage.promptTokens` / `gen_ai.usage.input_tokens`, output from
|
||||
`ai.usage.completionTokens` / `gen_ai.usage.output_tokens`. (There is no
|
||||
reasoning-token extraction.)
|
||||
|
||||
Claude Code Agent SDK spans do not go through `AiSdkSpanAdapter`; they are
|
||||
converted by `src/main/ai/observability/adapters/claudeCode/ClaudeCodeOtlpAdapter.ts`.
|
||||
|
||||
## Sensitive data capture & redaction
|
||||
|
||||
> Cross-referenced from `ClaudeCodeTraceBridgeService.prepareTrace`.
|
||||
|
||||
The Claude Code OTLP bridge runs **only when developer mode is enabled**. When
|
||||
it does, it intentionally turns on verbose Claude Code telemetry:
|
||||
|
||||
- `OTEL_LOG_USER_PROMPTS` — user prompt text
|
||||
- `OTEL_LOG_TOOL_DETAILS` / `OTEL_LOG_TOOL_CONTENT` — tool calls and their content
|
||||
- `OTEL_LOG_RAW_API_BODIES` — raw API request/response bodies
|
||||
|
||||
These payloads land in span attributes that `SpanCacheService` persists as
|
||||
**plaintext JSONL trace files on disk**, so a trace can contain secrets
|
||||
(authorization headers, API keys embedded in raw bodies) alongside the prompt
|
||||
and tool content.
|
||||
|
||||
**Redaction is deliberately not done.** Stripping secrets would mean parsing
|
||||
arbitrary OTLP attribute structures across the ingest path and would risk
|
||||
dropping legitimate trace data. The accepted tradeoff is that capture is
|
||||
**local-only and developer-gated**; turning that into a redaction/threat-model
|
||||
guarantee is a deferred decision. Treat exported trace files as sensitive.
|
||||
|
||||
## Where it shows up in the UI
|
||||
|
||||
Dev mode only. The dev-tools span viewer reads from the local observability
|
||||
projection (`SpanCacheService`) and renders the per-topic tree. Outside
|
||||
developer mode `buildTelemetry` returns `undefined`, so **no tracer is
|
||||
attached at all** and the AI SDK emits no spans — there is nothing to
|
||||
project. Viewer mount, tab close, and window close are not part of the
|
||||
trace persistence lifecycle.
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Code: `src/main/ai/observability/`
|
||||
- Span projection: `src/main/ai/observability/cache/SpanCacheService.ts`
|
||||
- AI SDK telemetry docs: https://ai-sdk.dev/docs/reference/ai-sdk-core/telemetry
|
||||
123
docs/references/ai/params-pipeline.md
Normal file
123
docs/references/ai/params-pipeline.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Params Pipeline
|
||||
|
||||
## What it is
|
||||
|
||||
`buildAgentParams` (`src/main/ai/runtime/aiSdk/params/buildAgentParams.ts`) is the
|
||||
single function that turns a (request, provider, model, assistant) tuple
|
||||
into everything `Agent.stream()` needs:
|
||||
|
||||
```ts
|
||||
interface BuiltAgentParams {
|
||||
sdkConfig: SdkConfig // providerId + providerSettings + modelId
|
||||
tools: ToolSet | undefined // active + meta-tools after defer
|
||||
plugins: AiPlugin<any, any>[] // model-adapter plugins (ordered)
|
||||
system: string | undefined // assembled system prompt
|
||||
options: AgentOptions // headers, providerOptions, stopWhen, repair, telemetry
|
||||
hookParts: ReadonlyArray<Partial<AgentLoopHooks>>
|
||||
}
|
||||
```
|
||||
|
||||
It is a pure async function — no class, no shared state. Callers (chat,
|
||||
agent session, translate, prompt-only) shape their own `AiBaseRequest` and
|
||||
hand it in.
|
||||
|
||||
## RequestFeature
|
||||
|
||||
The composition unit is `RequestFeature`
|
||||
(`src/main/ai/runtime/aiSdk/params/feature.ts`):
|
||||
|
||||
```ts
|
||||
interface RequestFeature {
|
||||
readonly name: string
|
||||
applies?(scope: RequestScope): boolean
|
||||
contributeModelAdapters?(scope: RequestScope): AiPlugin<any, any>[]
|
||||
contributeHooks?(scope: RequestScope): Partial<AgentLoopHooks>
|
||||
}
|
||||
```
|
||||
|
||||
`collectFromFeatures(scope, features)` calls each feature's `applies`
|
||||
(default `true`), then collects its model adapters and hook parts. The
|
||||
result feeds `plugins` and `hookParts` in `BuiltAgentParams`.
|
||||
|
||||
Order matters because AI SDK plugin order is significant. The list lives
|
||||
in `src/main/ai/runtime/aiSdk/params/features/index.ts`:
|
||||
|
||||
```ts
|
||||
export const INTERNAL_FEATURES = [
|
||||
devtoolsFeature,
|
||||
gatewayUsageNormalizeFeature,
|
||||
modelParamsFeature,
|
||||
pdfCompatibilityFeature, // must run before anthropicCacheFeature
|
||||
reasoningExtractionFeature, // must run before simulateStreamingFeature
|
||||
simulateStreamingFeature,
|
||||
anthropicCacheFeature,
|
||||
anthropicHeadersFeature,
|
||||
openrouterReasoningFeature,
|
||||
noThinkFeature,
|
||||
qwenThinkingFeature,
|
||||
skipGeminiThoughtSignatureFeature,
|
||||
providerWebSearchFeature,
|
||||
providerUrlContextFeature
|
||||
]
|
||||
```
|
||||
|
||||
Callers can append per-request `extraFeatures`; those run after the
|
||||
internal set. (AiService's analytics is *not* one of these — it is injected
|
||||
separately as a `hookParts` entry, not a `RequestFeature`.)
|
||||
|
||||
## RequestScope
|
||||
|
||||
All features receive the same read-only scope object built in
|
||||
`buildAgentParams`:
|
||||
|
||||
```ts
|
||||
interface RequestScope extends ToolApplyScope {
|
||||
request, signal, registry, assistant, model, provider,
|
||||
capabilities, // resolveCapabilities — see capabilities.ts
|
||||
sdkConfig, endpointType, aiSdkProviderId,
|
||||
requestContext, // RequestContext for tool execute()
|
||||
mcpToolIds
|
||||
}
|
||||
```
|
||||
|
||||
Features must never mutate the scope. The scope IS shared across all
|
||||
features for a single request, so any added field becomes part of the
|
||||
contract — keep it minimal.
|
||||
|
||||
## Pipeline order
|
||||
|
||||
```
|
||||
buildAgentParams(input)
|
||||
├─ resolveSdkConfig → providerToAiSdkConfig + modelId
|
||||
├─ canModelConsumeTools? → resolveTools (registry sync + defer)
|
||||
│ └─ syncMcpToolsToRegistry (only servers owning a selected tool)
|
||||
│ └─ registry.selectActive (per-entry applies)
|
||||
│ └─ applyDeferExposition (defer pool → meta-tools + system section)
|
||||
├─ resolveCapabilities → enableWebSearch / enableUrlContext / …
|
||||
├─ resolveEffectiveEndpoint → endpointType (model > provider default)
|
||||
├─ resolveAiSdkProviderId → adapter-family routing (see adapter-family.md)
|
||||
├─ collectFromFeatures → plugins + hookParts
|
||||
├─ assembleSystemPrompt → assistant prompt + deferred-tools header
|
||||
└─ buildAgentOptions → providerOptions + customParameters split
|
||||
+ headers + stopWhen + repair + telemetry
|
||||
```
|
||||
|
||||
## customParameters split
|
||||
|
||||
User-supplied `assistant.customParameters` may contain AI-SDK standard
|
||||
params (temperature, topP, etc.) **and** provider-scoped overrides.
|
||||
`extractAiSdkStandardParams` separates them; standard params land on the
|
||||
top-level `AgentOptions` (AI SDK forwards them to the model), provider
|
||||
params merge into `providerOptions[aiSdkProviderId]` (after a
|
||||
`mergeCustomProviderParameters` pass that respects existing capability
|
||||
options).
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Code: `src/main/ai/runtime/aiSdk/params/`
|
||||
- Tests: `src/main/ai/runtime/aiSdk/params/__tests__/` (`assembleSystemPrompt`,
|
||||
`collectFromFeatures`, `composeHooks`),
|
||||
`src/main/ai/runtime/aiSdk/params/features/__tests__/`
|
||||
- Tool defer: [Tool Registry](./tool-registry.md)
|
||||
- Endpoint routing: [Adapter Family](./adapter-family.md)
|
||||
- Hooks: [Agent Loop](./agent-loop.md)
|
||||
137
docs/references/ai/provider-resolution.md
Normal file
137
docs/references/ai/provider-resolution.md
Normal file
@@ -0,0 +1,137 @@
|
||||
# Provider Resolution
|
||||
|
||||
## The problem this solves
|
||||
|
||||
A request needs to know which `@ai-sdk/*` package to import, with which
|
||||
settings, hitting which URL. Three pieces of state determine that:
|
||||
|
||||
| Field | Lives on | Example |
|
||||
|---|---|---|
|
||||
| `provider.id` | `Provider` row | `minimax`, `silicon`, `my-relay` |
|
||||
| `endpointType` | `model.endpointTypes[0]` or `provider.defaultChatEndpoint` | `openai-chat-completions`, `anthropic-messages` |
|
||||
| `adapterFamily` | `provider.endpointConfigs[endpointType].adapterFamily` | `openai-compatible`, `anthropic`, `azure-responses` |
|
||||
|
||||
`adapterFamily` is the actual SDK selector. `provider.id` is the user-facing
|
||||
identity. `endpointType` is the protocol family. The mapping is written
|
||||
once at provider-creation time; runtime resolution is read-only.
|
||||
|
||||
See [Adapter Family](./adapter-family.md) for the full design.
|
||||
|
||||
## Resolver
|
||||
|
||||
`src/main/ai/provider/endpoint.ts` exposes three pure helpers:
|
||||
|
||||
```ts
|
||||
resolveEffectiveEndpoint(provider, model): { endpointType, baseUrl }
|
||||
resolveProviderVariant(baseProviderId, endpointType): AppProviderId
|
||||
resolveAiSdkProviderId(provider, endpointType): AppProviderId
|
||||
```
|
||||
|
||||
`resolveAiSdkProviderId` is the runtime hot-path entry. It reads
|
||||
`provider.endpointConfigs[endpointType].adapterFamily`, applies the
|
||||
variant suffix if the endpoint type has one, falls back to
|
||||
`openai-compatible` when no family is set.
|
||||
|
||||
```ts
|
||||
// Full resolver — 6 lines
|
||||
export function resolveAiSdkProviderId(provider, endpointType) {
|
||||
const adapterFamily = endpointType
|
||||
? provider.endpointConfigs?.[endpointType]?.adapterFamily
|
||||
: undefined
|
||||
if (adapterFamily && adapterFamily in appProviderIds) {
|
||||
return resolveProviderVariant(appProviderIds[adapterFamily], endpointType)
|
||||
}
|
||||
return appProviderIds['openai-compatible']
|
||||
}
|
||||
```
|
||||
|
||||
## Variants
|
||||
|
||||
Some bases expose variant ids (a different endpoint on the same base).
|
||||
`resolveProviderVariant` knows two suffix rules and applies one only when
|
||||
the resulting `<base>-<suffix>` id is actually registered — otherwise it
|
||||
returns the base unchanged:
|
||||
|
||||
| Endpoint type | Suffix tried |
|
||||
|---|---|
|
||||
| `openai-chat-completions`, `ollama-chat` | `-chat` |
|
||||
| `openai-responses` | `-responses` |
|
||||
|
||||
Variants registered today (declared in each provider extension's
|
||||
`variants` array, `packages/aiCore/src/core/providers/core/initialization.ts`):
|
||||
|
||||
| Base | Variant id(s) |
|
||||
|---|---|
|
||||
| `openai` | `openai-chat` (the base `openai` is itself the Responses API) |
|
||||
| `azure` | `azure-responses`, `azure-anthropic` |
|
||||
| `xai` | `xai-responses` |
|
||||
| `cherryin` | `cherryin-chat` |
|
||||
|
||||
`ollama` has no registered variant, so an `ollama-chat` endpoint resolves
|
||||
to the base `ollama`. Likewise there is **no `openai-responses` variant**
|
||||
(the base already is). `azure-anthropic` is not reached through the suffix
|
||||
rule — it is selected inside `buildAzureConfig` when the model is a Claude
|
||||
model (see below). `resolveProviderVariant(baseId, endpointType)` is
|
||||
idempotent when the base id is already a variant.
|
||||
|
||||
## Provider config
|
||||
|
||||
`providerToAiSdkConfig(provider, model)`
|
||||
(`src/main/ai/provider/config.ts`) returns
|
||||
`{ providerId: AppProviderId, providerSettings: AppProviderSettingsMap[id] }`.
|
||||
It calls `resolveAiSdkProviderId` internally, then dispatches through an
|
||||
ordered `{ match, build }` table to build the provider-specific settings
|
||||
object (apiKey, baseURL, organization, headers, ...). There is **no
|
||||
"gateway" branch**.
|
||||
|
||||
The builder table (`config.ts`, first match wins):
|
||||
|
||||
| Match | Builder | Notes |
|
||||
|---|---|---|
|
||||
| `id === copilot` | `buildCopilotConfig` | async — fetches a Copilot token |
|
||||
| `id === 'cherryai'` | `buildCherryAIConfig` | |
|
||||
| `isOllamaProvider` | `buildOllamaConfig` | |
|
||||
| `isAzureOpenAIProvider` | `buildAzureConfig` | returns `azure` / `azure-responses` / `azure-anthropic` (Claude on Azure) |
|
||||
| `id === 'bedrock'` | `buildBedrockConfig` | |
|
||||
| `id === 'google-vertex'` | `buildVertexConfig` | returns `google-vertex` or `google-vertex-anthropic` for Claude; leaves `baseURL` undefined when no host is configured so the SDK derives the aiplatform host |
|
||||
| `provider.id === 'cherryin'` | `buildCherryinConfig` | matches the **provider id**, not the resolved variant — the default chat endpoint resolves to `cherryin-chat`, so an `id === 'cherryin'` check never fires; async — resolves relay base URLs |
|
||||
| `id === 'newapi'` | `buildNewApiConfig` | |
|
||||
| `id === 'aihubmix'` | `buildAiHubMixConfig` | |
|
||||
| _(no match)_ | `buildGenericProviderConfig` / `buildOpenAICompatibleConfig` | generic fallback |
|
||||
|
||||
Several builders are `async` (Copilot token, CherryIN relay URLs), which is
|
||||
why `providerToAiSdkConfig` returns a promise.
|
||||
|
||||
## Custom providers
|
||||
|
||||
`src/main/ai/provider/custom/`:
|
||||
|
||||
- **aihubmix** — multi-vendor relay. `provider.id='aihubmix'` but each
|
||||
model carries `model.provider='aihubmix.<vendor>'`; the registry's
|
||||
aggregator fallback uses the suffix to pick the right `toolFactory`.
|
||||
- **newapi** — same shape, different relay.
|
||||
|
||||
Both register through `ProviderExtension.create(...)` with their own
|
||||
`AppProviderSettings` shape.
|
||||
|
||||
## Provider extensions
|
||||
|
||||
`src/main/ai/provider/extensions/index.ts` registers every
|
||||
`@ai-sdk/*` package Cherry uses with `ProviderExtension.create`. Each
|
||||
extension declares:
|
||||
|
||||
- `name` (the `AppProviderId` for the base)
|
||||
- `aliases` (alternate ids that normalize to `name`)
|
||||
- `variants` (suffix entries — see above)
|
||||
- `create` (the SDK's factory)
|
||||
- `toolFactories` (per-capability factory functions for `webSearch` /
|
||||
`urlContext` etc.)
|
||||
- `supportsImageGeneration` (boolean flag)
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Code: `src/main/ai/provider/`
|
||||
- Tests: `provider/__tests__/endpoint.test.ts` (54 cases)
|
||||
- Migration of legacy provider rows: `src/main/data/migration/v2/migrators/mappings/ProviderModelMappings.ts`
|
||||
- Catalog (new installs): `packages/provider-registry/data/providers.json`
|
||||
- Design: [Adapter Family](./adapter-family.md)
|
||||
838
docs/references/ai/stream-manager.md
Normal file
838
docs/references/ai/stream-manager.md
Normal file
@@ -0,0 +1,838 @@
|
||||
# AiStreamManager
|
||||
|
||||
## What it is
|
||||
|
||||
`AiStreamManager` is the Main-process **active-stream registry** and the
|
||||
broker for every stream event. It owns the full life cycle of an AI
|
||||
streaming reply — from `sendMessages` until the assistant turn finishes
|
||||
persisting — including multicast fan-out, reconnect, abort, abort-and-restart
|
||||
steering, and persistence triggering.
|
||||
|
||||
The renderer no longer holds a direct reference to the stream. Closing a
|
||||
window does not abort the stream; it continues on Main and persists
|
||||
normally. When the user returns, `attach` re-subscribes and the
|
||||
manager replays any chunks that landed in between.
|
||||
|
||||
**Key: `topicId`.** A topic has at most one active stream at a time;
|
||||
"streaming" is one phase of a topic, and every subscriber on a topic is
|
||||
equal — there is no "owner" window.
|
||||
|
||||
## Why it exists
|
||||
|
||||
v1 ran the stream lifecycle, fan-out, and persistence on the **renderer**,
|
||||
which produced three structural bug classes:
|
||||
|
||||
- **Window-bound lifecycle** — unmounting the chat (topic switch, window
|
||||
close, route change) cancelled the transport stream, which aborted the
|
||||
upstream request and dropped the in-flight reply.
|
||||
- **No reconnect** — `reconnectToStream()` always returned `null`, so
|
||||
returning to a topic lost live progress until the row hit the DB.
|
||||
- **Renderer-owned persistence** — the DB write lived in the renderer, so a
|
||||
crash/close between stream-end and commit lost the reply.
|
||||
|
||||
**Goal:** move stream lifecycle, multicast fan-out, and persistence to Main;
|
||||
the renderer's only job is rendering chunks. The sections below are the
|
||||
reference for that Main-side design.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────── Renderer ────────────────────────────────────┐
|
||||
│ │
|
||||
│ useChat({ id: topicId, transport: IpcChatTransport }) │
|
||||
│ ├─ sendMessages → Ai_Stream_Open (topicId, trigger, userMessageParts, …)
|
||||
│ ├─ reconnectToStream → Ai_Stream_Attach ({ topicId }) │
|
||||
│ └─ abort signal → Ai_Stream_Abort ({ topicId }) │
|
||||
│ │
|
||||
│ History: useQuery('/topics/:id/messages') │
|
||||
│ Topic-level state: useTopicStreamStatus → shared cache │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
↕ IPC (all keyed by topicId)
|
||||
┌──────────────── Main ────────────────────────────────────────┐
|
||||
│ │
|
||||
│ dispatchStreamRequest(manager, subscriber, req) │
|
||||
│ │ pick first ChatContextProvider whose canHandle matches │
|
||||
│ │ provider.prepareDispatch(subscriber, req, ctx) │
|
||||
│ └ manager.send(prepared) │
|
||||
│ │
|
||||
│ AiStreamManager │
|
||||
│ ┌────────────────────────────────────────────────────────┐ │
|
||||
│ │ activeStreams: Map<topicId, ActiveStream> │ │
|
||||
│ │ listeners: Map<listenerId, StreamListener> │ │
|
||||
│ │ executions: Map<modelId, StreamExecution> │ │
|
||||
│ │ ├─ abortController / status │ │
|
||||
│ │ └─ buffer (ring) + droppedChunks │ │
|
||||
│ │ lifecycle: StreamLifecycle (chat or prompt) │ │
|
||||
│ └────────────────────────────────────────────────────────┘ │
|
||||
│ ↓ createAndLaunchExecution → runExecutionLoop │
|
||||
│ AiService.streamText(request) → ReadableStream<UIMessageChunk> │
|
||||
│ ↓ pipeStreamLoop (tees: broadcast + readUIMessageStream) │
|
||||
│ │
|
||||
│ terminal → dispatchToListeners → every StreamListener: │
|
||||
│ WebContentsListener → wc.send(Ai_StreamDone) │
|
||||
│ PersistenceListener → PersistenceBackend.persistAssistant
|
||||
│ • MessageServiceBackend (SQLite tree) │
|
||||
│ • TemporaryChatBackend (in-memory) │
|
||||
│ • AgentSessionMessageBackend (agent-session DB) │
|
||||
│ • TranslationBackend (translate row) │
|
||||
│ TraceFlushListener → SpanCacheService.saveSpans(topicId)
|
||||
│ ChannelAdapterListener → adapter.onStreamComplete │
|
||||
│ SseListener → res.write('[DONE]') │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Pub/sub model
|
||||
|
||||
The manager is a broker: one set of producers feeds it, one set of
|
||||
consumers subscribes. The system uses the observer pattern, and splits
|
||||
dispatch into two semantically distinct channels based on **payload
|
||||
volume × audience width**.
|
||||
|
||||
### Producers
|
||||
|
||||
| Producer | Events | Source |
|
||||
|---|---|---|
|
||||
| `StreamExecution` loop | `UIMessageChunk` (per-chunk delta) | `AiService.streamText`'s `ReadableStream` |
|
||||
| `AiStreamManager` (state machine) | topic-level status transitions | `send()` → `pending`, first chunk → `streaming`, three terminal handlers → `done` / `error` / `aborted`, `awaiting-approval` on `tool-approval-request` |
|
||||
|
||||
### Consumers
|
||||
|
||||
| Consumer | Events | Subscription |
|
||||
|---|---|---|
|
||||
| `WebContentsListener` | chunk + terminal | explicit `attach` → `ActiveStream.listeners` |
|
||||
| `PersistenceListener` | terminal | built by the provider and added in `send()` |
|
||||
| `TraceFlushListener` | terminal | built by chat / agent-session turn owners and added in `send()` |
|
||||
| `ChannelAdapterListener` / `SseListener` | chunk + terminal | caller injects into `send()`'s `listeners` |
|
||||
| UI indirect consumers (sidebar indicators, …) | topic status | `useSharedCache('topic.stream.statuses.${topicId}')` |
|
||||
|
||||
### Two channels: targeted listener dispatch vs SharedCache mirror
|
||||
|
||||
| | Targeted listener dispatch | SharedCache mirror |
|
||||
|---|---|---|
|
||||
| Transport | `Ai_StreamChunk` / `Ai_StreamDone` / `Ai_StreamError` | `cacheService.setShared('topic.stream.statuses.${topicId}', …)` → built-in `Cache_Sync` broadcast |
|
||||
| Main-side registry | `ActiveStream.listeners: Map<listenerId, StreamListener>` | none — uses the generic `CacheService` infra |
|
||||
| Subscriber API | `attach` to register, explicit `detach` | `useSharedCache('topic.stream.statuses.${topicId}')` by topicId |
|
||||
| Per-event size | tens of bytes to KBs (10s/s) | tens of bytes (≤ 5 transitions per stream) |
|
||||
| Audience | narrow (one window per listener typically) | wide (every sidebar / indicator across all windows) |
|
||||
| Cost of irrelevant pushes | high (bandwidth + deserialization) | negligible |
|
||||
|
||||
### Channel selection rule
|
||||
|
||||
Choose by **consumer / producer fanout**:
|
||||
|
||||
- chunk stream: one execution produces it, only the window rendering
|
||||
that topic needs it → **targeted listener dispatch**, no irrelevant
|
||||
pushes.
|
||||
- topic status: one transition, every UI mirror wants it → **SharedCache**,
|
||||
reuse generic cache sync, no bespoke IPC.
|
||||
|
||||
### Rules that follow from the channel split
|
||||
|
||||
- **`Ai_Stream_Attach` is required.** The listener channel requires
|
||||
explicit consumer registration; `attach` is the entry point and also
|
||||
returns a compact replay to fill the "before I subscribed" gap.
|
||||
- **Bootstrap needs no extra IPC.** A new window pulls all shared cache
|
||||
entries via `Cache_GetAllShared` on mount; every
|
||||
`topic.stream.statuses.${topicId}` entry comes through without a
|
||||
bespoke snapshot IPC.
|
||||
- **Snapshot vs delta race.** Handled by the shared cache sync layer
|
||||
itself — initial pull and `Cache_Sync` delta share the Main-side
|
||||
source of truth; late arrivals overwrite stale state.
|
||||
- **Grace-period cleanup does NOT clear the SharedCache entry.** Terminal
|
||||
values (`done` / `aborted` / `error`) stay so renderer-side consumers
|
||||
(`useTopicDbRefreshOnTerminal`, `useChatWithHistory`, awaiting-approval
|
||||
indicators, sidebar badges) can observe them. The fulfilled-badge gate
|
||||
is a read-receipt: the entry's `lastCompletedAt` (bumped only on
|
||||
`done`) compared against `topic.stream.last_seen_completion.${topicId}`
|
||||
(cross-window shared cache, written when the user acknowledges).
|
||||
Memory tier — both reset on app restart.
|
||||
- **`PersistenceListener` placement.** Terminal-only consumer — doesn't
|
||||
need chunk bandwidth → not added via `attach`; the provider includes
|
||||
it in the `listeners` array passed to `send()`.
|
||||
- **`TraceFlushListener` placement.** Terminal-only consumer that flushes
|
||||
`SpanCacheService.saveSpans(topicId)` after a chat / agent turn completes.
|
||||
It belongs with the turn owner (`PersistentChatContextProvider` or
|
||||
`AgentSessionRuntimeService`), not inside `AiStreamManager` and not in
|
||||
trace viewer UI.
|
||||
|
||||
## File layout
|
||||
|
||||
```
|
||||
src/main/ai/
|
||||
├── AiService.ts lifecycle service: streamText + non-streaming IPC gateway
|
||||
└── runtime/aiSdk/
|
||||
└── Agent.ts single-pass `Agent.stream` wrapper (see Agent Loop)
|
||||
|
||||
src/main/ai/streamManager/
|
||||
├── AiStreamManager.ts the registry + execution loop + multicast
|
||||
├── pipeStreamLoop.ts shared chunk-pipe primitive (used by AiStreamManager.runExecutionLoop)
|
||||
├── buildCompactReplay.ts attach-time chunk compaction (merge text-delta / reasoning-delta)
|
||||
├── types.ts ActiveStream / StreamExecution / StreamListener / timings
|
||||
├── index.ts barrel
|
||||
│
|
||||
├── context/ per-topicId namespace dispatch
|
||||
│ ├── ChatContextProvider.ts interface + PreparedDispatch
|
||||
│ ├── dispatch.ts single manager.send entry; MainContinueConversationRequest
|
||||
│ ├── PersistentChatContextProvider.ts uuid topics → SQLite tree
|
||||
│ ├── TemporaryChatContextProvider.ts in-memory (TemporaryChatService)
|
||||
│ ├── AgentChatContextProvider.ts `agent-session:` → agents DB
|
||||
│ └── modelResolution.ts resolveModels / siblingsGroupId
|
||||
│
|
||||
├── lifecycle/ strategy: chat vs ad-hoc prompt
|
||||
│ ├── StreamLifecycle.ts interface
|
||||
│ ├── ChatStreamLifecycle.ts cross-window broadcast + 30 s grace period + attach
|
||||
│ ├── PromptStreamLifecycle.ts silent, no attach, immediate eviction
|
||||
│ └── index.ts barrel
|
||||
│
|
||||
├── listeners/
|
||||
│ ├── WebContentsListener.ts chunks → renderer windows
|
||||
│ ├── PersistenceListener.ts observer protocol + delegates to PersistenceBackend
|
||||
│ ├── TraceFlushListener.ts terminal trace-cache flush to local history
|
||||
│ ├── ChannelAdapterListener.ts text → Discord / Slack / Feishu
|
||||
│ └── SseListener.ts UIMessageChunk → SSE response (API server)
|
||||
│
|
||||
└── persistence/
|
||||
├── PersistenceBackend.ts strategy interface + statsFromTerminal projection
|
||||
└── backends/
|
||||
├── MessageServiceBackend.ts finalize a SQLite pending placeholder
|
||||
├── TemporaryChatBackend.ts append to in-memory topic
|
||||
└── TranslationBackend.ts attach `data-translation` part to a target message
|
||||
```
|
||||
|
||||
Agent session persistence is implemented under `agentSession/persistence`
|
||||
because it writes the agent-session domain tables.
|
||||
|
||||
## StreamListener interface
|
||||
|
||||
The manager treats every consumer through one interface; it dispatches
|
||||
each event by calling these methods uniformly:
|
||||
|
||||
```typescript
|
||||
interface StreamListener {
|
||||
readonly id: string
|
||||
onChunk(chunk: UIMessageChunk, sourceModelId?: UniqueModelId): void
|
||||
onDone(result: StreamDoneResult): void | Promise<void> // { finalMessage?, status: 'success', ... }
|
||||
onPaused(result: StreamPausedResult): void | Promise<void> // { finalMessage?, status: 'paused', ... }
|
||||
onError(result: StreamErrorResult): void | Promise<void> // { finalMessage?, error, status: 'error', ... }
|
||||
isAlive(): boolean
|
||||
}
|
||||
```
|
||||
|
||||
All three terminal shapes share the same `finalMessage?` field — the
|
||||
`UIMessage` accumulated by `readUIMessageStream` in the execution loop.
|
||||
Whether the stream ended naturally, was aborted, or errored, it's the
|
||||
same variable, only the stop point differs. Earlier designs called the
|
||||
error-path partial a `partialMessage`; this turned out to be just a
|
||||
`finalMessage` that ended early. Unifying the shape means
|
||||
`PersistenceBackend` needs one `persistAssistant` method, not separate
|
||||
write paths per status.
|
||||
|
||||
### Built-in implementations
|
||||
|
||||
| Listener | Role | id | isAlive |
|
||||
|---|---|---|---|
|
||||
| **WebContentsListener** | chunks → renderer window | `wc:${wc.id}:${topicId}` | `!wc.isDestroyed()` |
|
||||
| **PersistenceListener** | terminal write via strategy | `persistence:${backendKind}:${topicId}:${modelId ?? 'default'}` | always `true` |
|
||||
| **TraceFlushListener** | terminal trace-cache flush | `persistence:trace:${topicId}` | always `true` |
|
||||
| **ChannelAdapterListener** | text → IM platform | `channel:${channelId}:${chatId}` | `adapter.connected` |
|
||||
| **SseListener** | API-server SSE passthrough | `sse:${uuid}` | `!res.writableEnded` |
|
||||
|
||||
### Unified liveness policy
|
||||
|
||||
`AiStreamManager.dispatchToListeners` is the single funnel for terminal
|
||||
events (`onDone` / `onPaused` / `onError`). Per listener it:
|
||||
|
||||
- Calls `listener.isAlive()` before each broadcast — `false` removes the
|
||||
listener from `stream.listeners` (cleans up dead consumers).
|
||||
- Wraps each call in try/catch — one bad listener can't starve the rest.
|
||||
- Logs by event name + listener id for easy triage.
|
||||
|
||||
`onChunk` keeps a synchronous contract (the execution loop can't `await`
|
||||
a listener) so it inlines the loop instead of going through
|
||||
`dispatchToListeners`, but the dead-listener cleanup is the same.
|
||||
|
||||
### PersistenceListener — strategy pattern
|
||||
|
||||
One listener + four backends:
|
||||
|
||||
```typescript
|
||||
interface PersistenceBackend {
|
||||
readonly kind: string // "sqlite" | "temp" | "agents-db" | "translation"
|
||||
persistAssistant(input: {
|
||||
finalMessage?: CherryUIMessage
|
||||
status: 'success' | 'paused' | 'error'
|
||||
modelId?: UniqueModelId
|
||||
stats?: MessageStats
|
||||
}): Promise<void>
|
||||
afterPersist?(finalMessage: CherryUIMessage): Promise<void>
|
||||
}
|
||||
```
|
||||
|
||||
Backends expose **one** write method; the three statuses share its
|
||||
shape. On the `error` branch, `PersistenceListener` folds the
|
||||
`SerializedError` into a trailing `data-error` part on `finalMessage.parts`
|
||||
and then calls `persistAssistant({ status: 'error' })`, so backends never
|
||||
have to know how to encode an error into a UIMessage — they just write.
|
||||
|
||||
The listener owns the observer protocol: filter by `modelId`
|
||||
(multi-model topics have one listener per execution), merge the error
|
||||
part exactly once, swallow exceptions so they don't break downstream
|
||||
dispatch, fire `afterPersist` only when `status === 'success'` and
|
||||
`finalMessage` is present (best-effort). Adding a fifth storage path
|
||||
(e.g. an outbox) is a 60-line backend, no listener boilerplate to copy.
|
||||
|
||||
## ActiveStream & StreamExecution
|
||||
|
||||
```typescript
|
||||
interface ActiveStream {
|
||||
topicId: string
|
||||
executions: Map<UniqueModelId, StreamExecution> // 1 entry single-model, N multi-model
|
||||
listeners: Map<string, StreamListener> // shared across executions
|
||||
// 'pending' on creation; flips to 'streaming' on first chunk; derived
|
||||
// from executions on terminal (done / aborted / error /
|
||||
// awaiting-approval).
|
||||
status: TopicStreamStatus
|
||||
isMultiModel: boolean // fixed at create; tags onChunk's sourceModelId
|
||||
lifecycle: StreamLifecycle // chat or prompt strategy
|
||||
expiresAt?: number
|
||||
cleanupTimer?: ReturnType<typeof setTimeout>
|
||||
}
|
||||
|
||||
interface StreamExecution {
|
||||
modelId: UniqueModelId
|
||||
anchorMessageId?: string // placeholder id for submit/regen, anchor id for continue
|
||||
abortController: AbortController
|
||||
status: 'streaming' | 'done' | 'error' | 'aborted'
|
||||
|
||||
// Per-execution ring buffer for reconnect replay. Hitting
|
||||
// `maxBufferChunks` drops the oldest entry and bumps `droppedChunks`.
|
||||
// Independent buffers prevent a chatty model from evicting a slower
|
||||
// model's replay (a shared buffer would).
|
||||
buffer: StreamChunkPayload[]
|
||||
droppedChunks: number
|
||||
|
||||
finalMessage?: CherryUIMessage
|
||||
|
||||
// Set the moment a `tool-approval-request` chunk arrives, cleared on
|
||||
// response. Read by `resolveTerminalStatus` to surface
|
||||
// `awaiting-approval` on the topic.
|
||||
awaitingApproval?: boolean
|
||||
|
||||
error?: SerializedError
|
||||
siblingsGroupId?: number
|
||||
loopPromise: Promise<void> // awaited by onStop for graceful shutdown
|
||||
|
||||
// Transport-side timings owned by the execution loop — chunk-shape-agnostic.
|
||||
// Semantic timings (firstTextAt / reasoning*) live on the listener
|
||||
// that cares; see "Stats composition" below.
|
||||
timings: TransportTimings
|
||||
|
||||
// OTel root span set as active context around runExecutionLoop so
|
||||
// AI SDK spans become children. Created by the context provider.
|
||||
rootSpan?: Span
|
||||
}
|
||||
|
||||
interface TransportTimings {
|
||||
readonly startedAt: number // execution loop entry
|
||||
completedAt?: number // execution loop exit (both try and catch paths)
|
||||
}
|
||||
|
||||
interface SemanticTimings {
|
||||
firstTextAt?: number // first text-delta chunk (TTFT endpoint)
|
||||
reasoningStartedAt?: number // first reasoning-* chunk
|
||||
reasoningEndedAt?: number // first non-reasoning chunk after reasoning
|
||||
}
|
||||
```
|
||||
|
||||
Topic-level status is derived from executions, with `'pending'` as the
|
||||
initial pre-first-chunk window:
|
||||
|
||||
- Created (`send()` returned) → `'pending'`
|
||||
- Any execution emits its first chunk → `'streaming'`
|
||||
- All terminal, all `done` → `'done'`
|
||||
- All terminal, all `aborted` → `'aborted'`
|
||||
- Has `error`, none `streaming` → `'error'`
|
||||
- Any execution still has `awaitingApproval` true on a terminal topic → `'awaiting-approval'`
|
||||
|
||||
`pending → streaming` is a one-time transition (first chunk anywhere).
|
||||
The terminal status is derived once when the last execution terminates.
|
||||
|
||||
### Stats composition — tokens + timings → MessageStats
|
||||
|
||||
**Ownership** (key invariant: manager does not peek at chunk payloads):
|
||||
|
||||
| Source field | Owner | Collected at |
|
||||
|---|---|---|
|
||||
| `TransportTimings.startedAt` | `AiStreamManager` | `createAndLaunchExecution` |
|
||||
| `TransportTimings.completedAt` | `AiStreamManager` | `pipeStreamLoop`'s `broadcastCompletedAt` |
|
||||
| `SemanticTimings.firstTextAt` | `PersistenceListener` | own `onChunk`, first `text-delta` |
|
||||
| `SemanticTimings.reasoning*` | `PersistenceListener` | own `onChunk`, observing `reasoning-*` boundaries |
|
||||
| Token metadata | `agentLoop` usage observer | `finish` chunk projects AI SDK `LanguageModelUsage` → `CherryUIMessageMetadata` |
|
||||
|
||||
The manager is chunk-shape-agnostic — multicast, reconnect, abort,
|
||||
abort-and-restart steering, persistence-triggering, never "what is text /
|
||||
what is reasoning". AI SDK chunk type changes (vNext renames) only touch
|
||||
`PersistenceListener`; the manager stays stable.
|
||||
|
||||
**Final projection.** `statsFromTerminal(finalMessage, mergedTimings)`
|
||||
is one function; the listener merges its `SemanticTimings` with
|
||||
`result.timings` (transport) before calling it:
|
||||
|
||||
```typescript
|
||||
// inside PersistenceListener
|
||||
const mergedTimings = { ...result.timings, ...this.semanticTimings }
|
||||
const stats = statsFromTerminal(finalMessage, mergedTimings)
|
||||
await this.opts.backend.persistAssistant({ finalMessage, status, modelId, stats })
|
||||
```
|
||||
|
||||
Projected `MessageStats` fields:
|
||||
|
||||
| Field | Source |
|
||||
|---|---|
|
||||
| `totalTokens / promptTokens / completionTokens / thoughtsTokens` | `finalMessage.metadata.*` |
|
||||
| `timeFirstTokenMs` | `round(firstTextAt - startedAt)` |
|
||||
| `timeCompletionMs` | `round(completedAt - startedAt)` |
|
||||
| `timeThinkingMs` | **not projected** — wall-clock `reasoningEndedAt - reasoningStartedAt` can include interleaved tool exec; see the `TODO(message-stats-redesign)` note in `PersistenceBackend.ts` |
|
||||
|
||||
Backends never derive stats themselves; they just write `input.stats`.
|
||||
One projection path, four backends, no duplication.
|
||||
|
||||
## Public API
|
||||
|
||||
```typescript
|
||||
class AiStreamManager {
|
||||
// Lifecycle container invokes with no args (DEFAULT_CONFIG); tests can
|
||||
// override `gracePeriodMs`, `backgroundMode`, `maxBufferChunks`.
|
||||
constructor(config?: Partial<AiStreamManagerConfig>)
|
||||
|
||||
readonly chatLifecycle: StreamLifecycle
|
||||
|
||||
// ── Single dispatch entry ─────────────────────────────────────────
|
||||
// Live topic → inject (agent-session only: upsert listeners onto the
|
||||
// running stream, models ignored). Otherwise → start (evict any
|
||||
// grace-period stream, launch one execution per `models` entry). A live
|
||||
// chat topic never reaches the inject branch — `dispatch` restarts it via
|
||||
// `abortAndAwait` first. Multi-model is detected from `models.length > 1`.
|
||||
send(input: SendInput): SendResult
|
||||
|
||||
// ── Ad-hoc prompt stream (translate / topic-naming / model probes)
|
||||
// Bypasses the chat dispatcher; uses promptStreamLifecycle (silent, no
|
||||
// attach, immediate eviction).
|
||||
streamPrompt(input: {
|
||||
streamId: string // doubles as topicId
|
||||
uniqueModelId: UniqueModelId
|
||||
prompt?: string
|
||||
messages?: CherryUIMessage[]
|
||||
listener: StreamListener | StreamListener[]
|
||||
}): SendResult
|
||||
|
||||
// ── Subscription management ───────────────────────────────────────
|
||||
attach(sender: WebContents, req: { topicId }): AiStreamAttachResponse
|
||||
detach(sender: WebContents, req: { topicId }): void
|
||||
addListener(topicId: string, listener: StreamListener): boolean
|
||||
removeListener(topicId: string, listenerId: string): void
|
||||
|
||||
// ── Control ───────────────────────────────────────────────────────
|
||||
abort(topicId: string, reason: string): void
|
||||
// Abort a live turn and await its executions settling (partial persists as
|
||||
// `paused`) before evicting — used by `dispatch` to restart a chat turn.
|
||||
abortAndAwait(topicId: string, reason: string): Promise<void>
|
||||
hasLiveStream(topicId: string): boolean
|
||||
|
||||
// ── Execution-loop callbacks (driven internally; public for tests) ─
|
||||
onChunk(topicId, modelId, chunk): void
|
||||
onExecutionDone(topicId, modelId): Promise<void>
|
||||
onExecutionPaused(topicId, modelId): Promise<void>
|
||||
onExecutionError(topicId, modelId, error): Promise<void>
|
||||
|
||||
// ── Inspection (read-only snapshot) ───────────────────────────────
|
||||
inspect(topicId: string): TopicSnapshot | undefined
|
||||
}
|
||||
```
|
||||
|
||||
### `send` contract
|
||||
|
||||
```typescript
|
||||
interface SendInput {
|
||||
topicId: string
|
||||
models: ReadonlyArray<{ modelId: UniqueModelId; request: AiStreamRequest; rootSpan?: Span }>
|
||||
listeners: StreamListener[]
|
||||
userMessage?: Message // persisted user row; not consumed by send() — callers' bookkeeping
|
||||
siblingsGroupId?: number
|
||||
lifecycle?: StreamLifecycle // omit → chatLifecycle; streamPrompt passes promptStreamLifecycle
|
||||
}
|
||||
|
||||
interface SendResult {
|
||||
mode: 'started' | 'injected'
|
||||
executionIds: UniqueModelId[] // started → fresh ids; injected → already running
|
||||
}
|
||||
```
|
||||
|
||||
- **injected**: topic has a live stream (`pending` or `streaming`) →
|
||||
`models` is ignored and `listeners` upsert by id; **no message is
|
||||
injected**. Only agent-session topics reach this branch (the dispatcher
|
||||
aborts+restarts a live chat topic before calling `send()`); the
|
||||
agent-session follow-up was already enqueued on the session's
|
||||
`pendingTurns` by its provider.
|
||||
- **started**: topic is idle or grace-period (terminal) → any leftover
|
||||
grace-period stream is evicted, a new `ActiveStream` is created with
|
||||
`isMultiModel = models.length > 1`, one execution launched per model.
|
||||
|
||||
`isMultiModel` is not an input — it's derived from `models.length`.
|
||||
|
||||
### Execution loop — `runExecutionLoop` + `pipeStreamLoop`
|
||||
|
||||
Each execution runs an independent loop that bridges "the single
|
||||
`ReadableStream` from AI SDK" to "what the manager has to do":
|
||||
broadcast to listeners, buffer for reconnect, and accumulate a
|
||||
persistable `finalMessage`.
|
||||
|
||||
**Step 1 — get the raw chunk stream.**
|
||||
|
||||
```typescript
|
||||
const stream: ReadableStream<UIMessageChunk> = await aiService.streamText({
|
||||
...request,
|
||||
requestOptions: { ...request.requestOptions, signal }
|
||||
})
|
||||
```
|
||||
|
||||
`streamText` returns AI SDK's raw chunk stream. `signal` comes from
|
||||
`StreamExecution.abortController`; `abort()` triggers it.
|
||||
|
||||
**Step 2 — wrap with `withIdleTimeout`.** Resets per chunk; on idle
|
||||
timeout it aborts `exec.abortController`, which the upstream request is
|
||||
already wired to.
|
||||
|
||||
**Step 3 — `pipeStreamLoop` tees the chunk stream.**
|
||||
|
||||
`pipeStreamLoop` is the shared chunk-pipe primitive (the one
|
||||
`AiStreamManager.runExecutionLoop` uses). It `tee()`s the stream into two
|
||||
independent branches:
|
||||
|
||||
| Branch | Consumer | Purpose |
|
||||
|---|---|---|
|
||||
| Broadcast | `onChunk(topicId, modelId, chunk)` per chunk | Buffer into `exec.buffer` (ring), fan out to every listener |
|
||||
| Accumulator | `readUIMessageStream` | Each yielded snapshot is written to `exec.finalMessage`; at stream end it's the final message |
|
||||
|
||||
The accumulator reader is **not** cancelled directly on abort —
|
||||
`Agent.stream` honours the same signal upstream and propagates `done`
|
||||
through `tee()`, so the accumulator drains naturally. Cancelling the
|
||||
accumulator reader directly would race AI SDK's internal
|
||||
`controller.close()` and produce an `ERR_INVALID_STATE`
|
||||
unhandledRejection.
|
||||
|
||||
**Step 4 — terminal dispatch.**
|
||||
|
||||
| Exit path | Handler | Behaviour |
|
||||
|---|---|---|
|
||||
| Normal end | `onExecutionDone` | `exec.status = 'done'`, finalMessage persisted as `success` |
|
||||
| `signal.aborted` + `exec.status === 'aborted'` | `onExecutionPaused` | (Possibly partial) finalMessage persisted as `paused` |
|
||||
| `streamErrorText` (in-stream `error` chunk) | `onExecutionError` | Error part folded into finalMessage, persisted as `error` |
|
||||
| Pre-stream or broadcast throw | `onExecutionError` | Same — error part folded, persisted |
|
||||
|
||||
## Lifecycle strategy — chat vs prompt
|
||||
|
||||
The manager stays policy-free. Behaviour that differs between chat
|
||||
streams and one-shot ad-hoc prompts (translate, topic-naming, model
|
||||
probes) lives in `StreamLifecycle`:
|
||||
|
||||
```typescript
|
||||
interface StreamLifecycle {
|
||||
readonly name: string
|
||||
onCreated(stream): void // freshly registered
|
||||
onPromotedToStreaming(stream): void // first chunk
|
||||
onTerminal(stream): void // every isTopicDone
|
||||
canAttach(stream): boolean // gate for `attach`
|
||||
cleanup(stream, evict: () => void): void // when to remove from activeStreams
|
||||
}
|
||||
```
|
||||
|
||||
| | `ChatStreamLifecycle` | `PromptStreamLifecycle` |
|
||||
|---|---|---|
|
||||
| Status broadcast | writes `topic.stream.statuses.<topicId>` on `pending → streaming → terminal` (with `awaitingApprovalAnchors` derived from `exec.awaitingApproval`) | none |
|
||||
| `canAttach` | `true` | `false` |
|
||||
| `cleanup` | sets a `setTimeout(evict, gracePeriodMs)`; chat reconnects within 30 s | calls `evict()` immediately |
|
||||
|
||||
`send()` defaults to `chatLifecycle`; `streamPrompt()` passes
|
||||
`promptStreamLifecycle`.
|
||||
|
||||
## Multi-model
|
||||
|
||||
User mentions multiple models for one turn:
|
||||
|
||||
```
|
||||
User: "Explain quantum mechanics" @gpt-4o @claude-sonnet
|
||||
↓
|
||||
PersistentChatContextProvider.prepareDispatch
|
||||
├─ persist user message (tree node)
|
||||
├─ resolveModels → [gpt-4o, claude-sonnet]
|
||||
├─ siblingsGroupId = (monotonic counter)
|
||||
├─ create one pending assistant placeholder per model (SQLite)
|
||||
├─ build listeners: subscriber + 2 PersistenceListener (one per backend)
|
||||
├─ build models: 2 × { modelId, request, rootSpan }
|
||||
└─ return PreparedDispatch
|
||||
|
||||
dispatchStreamRequest → manager.send({ models, listeners, siblingsGroupId })
|
||||
│
|
||||
├─ create ActiveStream (isMultiModel = true, 2 executions)
|
||||
├─ launch one execution loop per model, each with its own
|
||||
│ ring buffer
|
||||
└─ return { mode: 'started', executionIds: [gpt-4o, claude-sonnet] }
|
||||
```
|
||||
|
||||
## Steering
|
||||
|
||||
Steering a chat turn is **abort-and-restart**, not mid-turn injection. When a
|
||||
new `Ai_Stream_Open` arrives for a topic that is still streaming,
|
||||
`dispatchStreamRequest` calls `manager.abortAndAwait(topicId, 'steer-restart')`:
|
||||
it aborts every execution, awaits their loops settling (each partial persists
|
||||
as `paused` via the normal terminal path), and evicts the stream — then the
|
||||
following `send()` starts a fresh turn. Awaiting settlement before re-dispatch
|
||||
is what avoids an orphaned `pending` row and the same-`(topic, model)` race a
|
||||
synchronous abort+restart would create.
|
||||
|
||||
Agent-session topics are the exception: they are **not** aborted. The follow-up
|
||||
is enqueued on the session's `pendingTurns` and the running turn is interrupted
|
||||
between tool calls; `send()` only upserts the new subscriber. See
|
||||
[Agent Session Runtime → Live follow-up](./agent-session-runtime.md#live-follow-up).
|
||||
|
||||
## End-to-end flows
|
||||
|
||||
One row per flow. The two with dedicated docs are cross-linked rather than
|
||||
duplicated; the rest are stream-manager-specific.
|
||||
|
||||
| Flow | Trigger | Mechanism | Terminal / result |
|
||||
|---|---|---|---|
|
||||
| Submit (standard) | `Ai_Stream_Open` | `dispatchStreamRequest` → `prepareDispatch` (persist user msg, reserve placeholders, build listeners + models) → `manager.send` → N × `runExecutionLoop` | `Ai_StreamDone`; `PersistenceListener.persistAssistant`; chat lifecycle `scheduleCleanup(30 s)` |
|
||||
| Steering — chat resubmit | `Ai_Stream_Open` on a live chat topic | `dispatch` → `manager.abortAndAwait` (abort execs, await settle as `paused`, evict) → `manager.send` starts a fresh turn | prior partial persisted as **`paused`**; new turn streams — see [Steering](#steering) |
|
||||
| Agent-session follow-up | `Ai_Stream_Open` on a live `agent-session:*` topic | provider persists the user row, `enqueueUserMessage` → `pendingTurns`, interrupt-when-safe; `manager.send` upserts the subscriber → `{ mode: 'injected' }` | next turn starts from `pendingTurns` — see [Agent Session Runtime](./agent-session-runtime.md#live-follow-up) |
|
||||
| Tool-approval pause+resume | approval-request chunk → `awaiting-approval` | decision via `Ai_ToolApproval_Respond`; Claude-Agent unblocks `canUseTool`, MCP dispatches `continue-conversation` | card clears when the resumed stream broadcasts `pending` — see [Tool Approval](./tool-approval.md) |
|
||||
| Reconnect | `Ai_Stream_Attach` on mount | `manager.attach`: `not-found` / streaming (register listener + compact replay) / done-paused (`finalMessage(s)`) / error | live chunks resume, or the final row is returned |
|
||||
| Abort — user stop | `Ai_Stream_Abort` | per exec: `abortController.abort` → loop `signal` aborts → broadcast reader `cancel` → read loop `done` | partial persisted as **`paused`**; topic status → `aborted` (or `awaiting-approval` if an exec had it set) |
|
||||
| Abort — no subscribers | last `WebContentsListener` dies + `backgroundMode === 'abort'` | `onChunk` prunes dead listeners; `listeners.size === 0` → auto `abort(topicId, 'no-subscribers')` | partial persisted as **`paused`** — never silently `success` or leaked |
|
||||
| Multi-window | window B opens a live topic | B sends `Ai_Stream_Attach` → compact replay + its own `WebContentsListener`; each chunk fans out to A and B | both windows render the same chunks in sync |
|
||||
| Channel / Agent | `AiStreamManager.send` in-process (no IPC) | scenario differs only by listener composition (table below) | per-listener effect |
|
||||
|
||||
**Topic status needs no `attach`.** Observers that only care "is this topic
|
||||
live?" (sidebar loading indicators, topic-list status dots) don't register a
|
||||
`WebContentsListener`. Every status transition writes the SharedCache key
|
||||
`topic.stream.statuses.${topicId}`; observers read it via `useSharedCache`
|
||||
directly. `Ai_Stream_Attach` is only needed when a window wants live chunks.
|
||||
|
||||
### Channel / Agent listener composition
|
||||
|
||||
Channel adapters and the agent scheduler call `AiStreamManager.send`
|
||||
directly inside Main — no IPC. The scenario differences are entirely in the
|
||||
listener composition:
|
||||
|
||||
| Scenario | Listeners | Effect |
|
||||
|---|---|---|
|
||||
| Renderer user message | `WebContentsListener` + `PersistenceListener` | live UI + persist |
|
||||
| Channel bot reply | `ChannelAdapterListener` + agent-session persistence listener | IM send + agents DB |
|
||||
| Channel + user both watching | above + `WebContentsListener(B)` | parallel fan-out |
|
||||
| API server SSE | `SseListener` + `PersistenceListener` | SSE push + persist |
|
||||
| Translate | `WebContentsListener` + `PersistenceListener(TranslationBackend)` | live overlay + writes `data-translation` part on success |
|
||||
|
||||
## IPC contract
|
||||
|
||||
### Request channels (Renderer → Main)
|
||||
|
||||
| Channel | Payload | Response | Semantics |
|
||||
|---|---|---|---|
|
||||
| `Ai_Stream_Open` | `AiStreamOpenRequest` (`submit-message` \| `regenerate-message`) | `{ mode, executionIds?, userMessageId?, placeholderIds? }` | Open / inject; provider routes by topicId |
|
||||
| `Ai_Stream_Attach` | `{ topicId }` | `AiStreamAttachResponse` | Subscribe; returns compact replay when streaming |
|
||||
| `Ai_Stream_Detach` | `{ topicId }` | void | Unsubscribe (stream continues) |
|
||||
| `Ai_Stream_Abort` | `{ topicId }` | void | Stop current generation |
|
||||
|
||||
> Topic status snapshots need no dedicated IPC: a new window pulls every
|
||||
> `topic.stream.statuses.${topicId}` entry via `Cache_GetAllShared` on
|
||||
> mount, and `useSharedCache` subscribes by topicId.
|
||||
|
||||
### Push channels (Main → Renderer)
|
||||
|
||||
| Channel | Payload | Notes |
|
||||
|---|---|---|
|
||||
| `Ai_StreamChunk` | `{ topicId, executionId?, chunk }` | Multi-model carries `executionId`; **only sent to attached windows** |
|
||||
| `Ai_StreamDone` | `{ topicId, executionId?, status, isTopicDone }` | `status ∈ { 'success', 'paused' }` — natural completion vs user abort; **only sent to attached windows** |
|
||||
| `Ai_StreamError` | `{ topicId, executionId?, isTopicDone, error }` | `SerializedError`; **only sent to attached windows** |
|
||||
|
||||
Topic-level status transitions are NOT a bespoke IPC — they live in the
|
||||
SharedCache key `topic.stream.statuses.${topicId}` (Main `setShared` →
|
||||
built-in `Cache_Sync` broadcast). The entry shape is
|
||||
`TopicStatusSnapshotEntry`:
|
||||
|
||||
```typescript
|
||||
{
|
||||
status: 'pending' | 'streaming' | 'done' | 'aborted' | 'awaiting-approval' | 'error'
|
||||
activeExecutions: ActiveExecution[] // execs currently `streaming`
|
||||
awaitingApprovalAnchors: ActiveExecution[] // execs with awaitingApproval = true
|
||||
lastCompletedAt?: number // bumped only on `done`; the fulfilled-badge read-receipt gate
|
||||
}
|
||||
```
|
||||
|
||||
`pending` doubles as the "new stream just created" signal — the old
|
||||
`Ai_StreamStarted` IPC is gone. Grace-period cleanup does NOT clear the
|
||||
entry — terminal values (`done` / `aborted` / `error`) stay so renderer
|
||||
consumers (DB-refresh trigger, awaiting-approval indicators, sidebar
|
||||
badges) can observe them. The badge "should I show this?" gate is a
|
||||
read-receipt: `entry.lastCompletedAt` (authoritative, bumped only on
|
||||
`done`) compared against `topic.stream.last_seen_completion.${topicId}`
|
||||
(cross-window shared cache, written by the renderer when the user
|
||||
acknowledges).
|
||||
|
||||
**All traffic is keyed by topicId**; multi-model uses `executionId` to
|
||||
demux chunks per model.
|
||||
|
||||
**Topic status vs message status.** Don't conflate:
|
||||
|
||||
- **Topic stream status** (SharedCache `topic.stream.statuses.${topicId}`):
|
||||
one entry per topic, source of truth is `ActiveStream.status`, valid
|
||||
only while the `ActiveStream` exists (+ grace period).
|
||||
- **Assistant message status** (`AssistantMessageStatus`: `PENDING` /
|
||||
`PROCESSING` / `SUCCESS` / `ERROR`): one per assistant message,
|
||||
persisted in SQLite, written by `PersistenceListener.onDone/onError`.
|
||||
In multi-model, a single topic-level transition corresponds to N
|
||||
separate message rows.
|
||||
|
||||
## ChatContextProvider — per-topicId namespace dispatch
|
||||
|
||||
`Ai_Stream_Open` is handled in Main by `dispatchStreamRequest`
|
||||
(`context/dispatch.ts`):
|
||||
|
||||
```
|
||||
dispatchStreamRequest(manager, subscriber, req)
|
||||
→ provider = providers.find(p => p.canHandle(req.topicId))
|
||||
→ prepared = await provider.prepareDispatch(subscriber, req, { hasLiveStream })
|
||||
→ result = manager.send(prepared) // ← the only manager.send call
|
||||
→ return { mode, executionIds?, userMessageId?, placeholderIds? }
|
||||
```
|
||||
|
||||
Providers only "prepare" — they never call `manager.send` directly. Two
|
||||
benefits:
|
||||
|
||||
- Provider unit tests assert on `PreparedDispatch` shape without mocking
|
||||
the manager.
|
||||
- The restart / start / multi-model fan-out routing lives in exactly one
|
||||
place.
|
||||
|
||||
### Provider interface
|
||||
|
||||
```typescript
|
||||
interface ChatContextProvider {
|
||||
readonly name: string
|
||||
canHandle(topicId: string): boolean
|
||||
prepareDispatch(
|
||||
subscriber: StreamListener,
|
||||
req: MainDispatchRequest,
|
||||
ctx: { hasLiveStream: boolean }
|
||||
): Promise<PreparedDispatch>
|
||||
}
|
||||
|
||||
interface PreparedDispatch {
|
||||
topicId: string
|
||||
models: ReadonlyArray<{ modelId: UniqueModelId; request: AiStreamRequest; rootSpan?: Span }>
|
||||
listeners: StreamListener[] // subscriber + per-execution PersistenceListener(s)
|
||||
userMessage?: Message
|
||||
userMessageId?: string
|
||||
siblingsGroupId?: number
|
||||
isMultiModel: boolean
|
||||
lifecycle?: StreamLifecycle
|
||||
}
|
||||
|
||||
// dispatch.ts also accepts a Main-internal `continue-conversation`
|
||||
// variant synthesised by the tool-approval IPC handler — not exposed
|
||||
// over the renderer ↔ main contract.
|
||||
type MainDispatchRequest = AiStreamOpenRequest | MainContinueConversationRequest
|
||||
```
|
||||
|
||||
### Built-in providers
|
||||
|
||||
| Provider | `canHandle` | Data layer | User message | Assistant message |
|
||||
|---|---|---|---|---|
|
||||
| **AgentChatContextProvider** | `topicId.startsWith('agent-session:')` | `agentMessageRepository` | written upfront | runtime provides `PersistenceListener(AgentSessionMessageBackend)` |
|
||||
| **TemporaryChatContextProvider** | `temporaryChatService.hasTopic(topicId)` | `TemporaryChatService` (in-memory) | appended upfront | `PersistenceListener(TemporaryChatBackend)` appends on done |
|
||||
| **PersistentChatContextProvider** | `true` (catch-all) | `messageService` + SQLite | transactional create | `PersistenceListener(MessageServiceBackend)` updates pending on done |
|
||||
|
||||
Order: Agent → Temporary → Persistent (first `canHandle === true`
|
||||
wins).
|
||||
|
||||
### Persistence path comparison
|
||||
|
||||
| | Persistent | Temporary | Agent |
|
||||
|---|---|---|---|
|
||||
| User message timing | before stream (tree node) | before stream (append) | before stream (agents DB) |
|
||||
| Assistant placeholder | created pending before stream | none | created pending before stream (atomic with user msg) |
|
||||
| Terminal write | `update` placeholder | `append` new row | `update` placeholder (`persistAssistant`) |
|
||||
| Backend | `MessageServiceBackend` | `TemporaryChatBackend` | `AgentSessionMessageBackend` |
|
||||
| Multi-model | ✓ | ✗ (single-model) | ✗ (single-model) |
|
||||
| Regenerate | ✓ | ✗ | ✗ |
|
||||
|
||||
### One PersistenceListener across all topic kinds
|
||||
|
||||
Persistent / Temporary / Agent / Translation all share the same
|
||||
`PersistenceListener` class — only the injected `PersistenceBackend`
|
||||
differs. The observer protocol (`modelId` filter, error part folding,
|
||||
skip-when-no-finalMessage, swallow errors) is implemented once.
|
||||
|
||||
## AiService integration
|
||||
|
||||
`AiService` is a lifecycle service:
|
||||
|
||||
- **Streaming.** `streamText(request)` returns
|
||||
`Promise<ReadableStream<UIMessageChunk>>`, consumed by
|
||||
`AiStreamManager.runExecutionLoop`.
|
||||
- **Non-streaming IPC gateway.** `generateText` / `checkModel` /
|
||||
`embedMany` / `generateImage` / `listModels`, registered as IPC
|
||||
handlers in `onInit`.
|
||||
|
||||
`AiStreamManager` calls `await application.get('AiService').streamText(...)`.
|
||||
Pre-stream errors (provider / model resolution, agent param build)
|
||||
reject the returned Promise; mid-stream errors come through the returned
|
||||
stream's error path — the two error paths never overlap.
|
||||
|
||||
## Grace period & reconnect
|
||||
|
||||
After a stream terminates, `ActiveStream` stays in memory for 30 s
|
||||
(`config.gracePeriodMs`). During that window a returning user can
|
||||
`attach` and pull `finalMessage` without a DB read. After expiry the
|
||||
entry is evicted; subsequent `attach` returns `not-found` and the
|
||||
renderer reads from the DB through `useQuery` (PersistenceListener has
|
||||
already written by then).
|
||||
|
||||
If the user stops and immediately retries on the same topic, `send`
|
||||
takes the start branch: `evictStream` first clears the grace-period
|
||||
remnant (cancels the cleanup timer and drops the entry from
|
||||
`activeStreams`), then the new stream is created — the old never blocks
|
||||
the new.
|
||||
|
||||
## Edge case cheat sheet
|
||||
|
||||
| Case | Handling |
|
||||
|---|---|
|
||||
| User sends again on the same topic mid-stream (chat) | `dispatch` calls `abortAndAwait` (prior turn persists as `paused`), then `send` starts a fresh turn |
|
||||
| Retry immediately after stream ends | `send` takes start; `evictStream` clears the grace-period entry first |
|
||||
| Window closes mid-stream | Next broadcast sees `WebContentsListener.isAlive() === false` and removes it; `PersistenceListener` doesn't depend on a window |
|
||||
| All windows closed + `backgroundMode='continue'` | Stream continues; `PersistenceListener` persists when done |
|
||||
| All windows closed + `backgroundMode='abort'` | `onChunk` finds `stream.listeners.size === 0` → `abort(topicId, 'no-subscribers')`; partial persisted as `paused` |
|
||||
| Multi-window on same topic | Each window has its own `WebContentsListener`; chunks fan out to all alive listeners |
|
||||
| Same window re-attaches | Listener id is stable (`wc:${wc.id}:${topicId}`); `addListener` upserts by id |
|
||||
| Attach mid-stream | `attach` returns compact replay per execution (each buffer compacted independently); observer fills in the gap |
|
||||
| Ring buffer overflow | At `maxBufferChunks` the oldest chunk drops and `droppedChunks++`; subsequent attach logs the total dropped — replay is no longer lossless |
|
||||
| Multi-model + resubmit | `abortAndAwait` settles all executions before restart; no orphaned `pending` row |
|
||||
| Stream emits `tool-approval-request` | `exec.awaitingApproval = true`; on stream end the topic surfaces `awaiting-approval` via the shared cache |
|
||||
| Main process restart | `activeStreams` clears; in-flight streams are lost; the renderer re-reads from the DB |
|
||||
|
||||
## Design notes
|
||||
|
||||
### Testing strategy
|
||||
|
||||
- **Manager tests.** `new AiStreamManager({ maxBufferChunks: 3 })` via
|
||||
the optional config arg; state assertions go through `mgr.inspect(topicId)`;
|
||||
listener upsert / abort / backgroundMode are tested via behaviour
|
||||
(drive a chunk, assert which listeners received it).
|
||||
- **Provider tests.** Assert on the returned `PreparedDispatch` shape
|
||||
directly — no manager mock.
|
||||
- **PersistenceListener tests.** `TemporaryChatBackend` as the test
|
||||
vehicle covers the observer protocol once for every backend.
|
||||
- All internal state has a public inspection API; production and tests
|
||||
share the same contract.
|
||||
75
docs/references/ai/tool-approval.md
Normal file
75
docs/references/ai/tool-approval.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# Tool Approval
|
||||
|
||||
## Model
|
||||
|
||||
Main is the single writer of approval state. The renderer surfaces an
|
||||
`approval-requested` ToolUIPart, takes the user's decision, and posts it
|
||||
to Main. Main applies the decision to the DB-authoritative anchor parts,
|
||||
persists, and resumes the stream.
|
||||
|
||||
## End-to-end flow
|
||||
|
||||
1. **Tool needs approval** — at `execute` time, the wrapper checks
|
||||
`tool.needsApproval` and the assistant's auto-approve policy. If
|
||||
approval is required, the wrapper writes an `approval-requested` part
|
||||
and resolves the tool's promise into a held state (Claude-Agent: holds
|
||||
`canUseTool`; MCP: stream pauses on the approval part).
|
||||
|
||||
2. **Stream pauses** — `AiStreamManager` transitions the topic to
|
||||
`awaiting-approval`. The `topic.stream.statuses.<topicId>` shared-cache
|
||||
entry carries the status; every renderer window reading that key sees
|
||||
the pause atomically.
|
||||
|
||||
3. **User decides** — the approval card renders from the part. On click,
|
||||
`useToolApprovalBridge` (`src/renderer/hooks/useToolApprovalBridge.ts`)
|
||||
calls `window.api.ai.toolApproval.respond(...)` with `approvalId`,
|
||||
`approved`, optional `reason` / `updatedInput`, `topicId`, `anchorId`.
|
||||
|
||||
4. **Main applies** — `AiService`'s `Ai_ToolApproval_Respond` handler
|
||||
branches on transport **before** touching the DB:
|
||||
- **Claude-Agent fast-path** (`AiService.ts:191-197`): hands the
|
||||
decision to `AgentSessionRuntimeService.respondToolApproval`, which
|
||||
resolves the live `canUseTool` promise so the existing stream
|
||||
proceeds. When a live registry entry handles it, the handler
|
||||
**early-returns — no DB read happens** (and `topicId` / `anchorId`
|
||||
are not required).
|
||||
- **MCP path** (reached only when no live entry matched; requires
|
||||
`topicId` + `anchorId`): reads the anchor message's current `parts`
|
||||
from DB, applies the decision, and **writes only when the target
|
||||
`approval-requested` part is present on the DB row** — guarding the
|
||||
overlay-only case (approval received before the part has persisted).
|
||||
When all approvals on the turn are decided it dispatches a synthetic
|
||||
`continue-conversation` request through `dispatchStreamRequest`; the
|
||||
provider applies the decision when it reads parts.
|
||||
|
||||
5. **Awaiting-approval clears** — the moment the continue stream
|
||||
broadcasts `pending`, the shared-cache entry flips back. Every window
|
||||
sees the approval card disappear in the same tick.
|
||||
|
||||
## Persistent decisions
|
||||
|
||||
`useToolApproval` (`src/renderer/pages/home/Messages/Tools/hooks/useToolApproval.ts`)
|
||||
exposes an `autoApprove` action **only for MCP tools** — when an `mcpTool`
|
||||
descriptor is passed. It persists the opt-out by PATCHing the server's
|
||||
`disabledAutoApproveTools`, so the MCP settings page reflects it and
|
||||
subsequent calls of that tool skip the approval card. There is no generic
|
||||
per-tool default for non-MCP (e.g. Claude-Agent) tools.
|
||||
|
||||
## Why this design
|
||||
|
||||
- **No renderer writes** — the renderer cannot PATCH approval state. If
|
||||
it did, it would race Main's authoritative re-read and cause the
|
||||
approval card to reappear on every click.
|
||||
- **Cross-window consistency** — the shared-cache `awaiting-approval`
|
||||
status is the single source of truth for "this topic is paused".
|
||||
- **Overlay/persist gap** — the renderer sometimes sees the
|
||||
`approval-requested` part via overlay before it lands in the DB row.
|
||||
Writing unconditionally would clobber the (concurrent) Main-side
|
||||
persistence; the conditional write + continue-dispatch covers that case.
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Main IPC handler: `src/main/ai/AiService.ts` (`Ai_ToolApproval_Respond`)
|
||||
- Renderer bridge: `src/renderer/hooks/useToolApprovalBridge.ts`
|
||||
- Persistent decisions: `src/renderer/pages/home/Messages/Tools/hooks/useToolApproval.ts`
|
||||
- Status broadcast: [Stream Manager](./stream-manager.md)
|
||||
142
docs/references/ai/tool-registry.md
Normal file
142
docs/references/ai/tool-registry.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# Tool Registry
|
||||
|
||||
## Model
|
||||
|
||||
```ts
|
||||
interface ToolEntry {
|
||||
name: string // wire-name, what the LLM emits in tool_calls
|
||||
namespace: string // grouping for `tool_search` (web, kb, mcp:<id>, meta)
|
||||
description: string // one-line summary for `tool_search`
|
||||
defer: 'never' | 'always' | 'auto'
|
||||
tool: Tool // AI SDK Tool (schema + execute + needsApproval + toModelOutput)
|
||||
applies?(scope): boolean
|
||||
}
|
||||
```
|
||||
|
||||
`registry` (`src/main/ai/tools/adapters/aiSdk/registry.ts`) is a
|
||||
process-wide singleton. Tool files register at module-import time; the
|
||||
registry is read at request time by `buildAgentParams`. The Claude Code
|
||||
runtime has a *separate* tool system — `tools/adapters/claudeCode/agentTools.ts`
|
||||
builds its descriptors from MCP servers and built-in descriptors directly;
|
||||
it does not consume this aiSdk `ToolRegistry`.
|
||||
|
||||
Tests construct their own `new ToolRegistry()` to avoid singleton pollution.
|
||||
|
||||
## Wire-name convention
|
||||
|
||||
Double underscore is the segment separator (so internal single `_` stays
|
||||
unambiguous):
|
||||
|
||||
| Source | Name pattern | Example |
|
||||
|---|---|---|
|
||||
| Built-in | `<namespace>__<verb>` | `web__search`, `kb__search` |
|
||||
| MCP | `mcp__<camelCase(server)>__<camelCase(tool)>` | `mcp__gmail__sendMessage` |
|
||||
| Meta | `tool_<verb>` | `tool_search`, `tool_invoke`, `tool_inspect` (`tool_exec` is defined but not injected — see below) |
|
||||
|
||||
## Built-in tools
|
||||
|
||||
`src/main/ai/tools/adapters/aiSdk/builtin/` registers **four** entries:
|
||||
|
||||
- `web__search` (`WebSearchTool.ts` → `createWebSearchToolEntry`) — namespace
|
||||
`web`. Talks to the configured web-search provider via the
|
||||
renderer-shared search service.
|
||||
- `web__fetch` (`WebSearchTool.ts` → `createWebFetchToolEntry`) — namespace
|
||||
`web`. Fetches a URL's content.
|
||||
- `kb__search` (`KnowledgeSearchTool.ts`) — semantic search over the active
|
||||
knowledge base.
|
||||
- `kb__list` (`KnowledgeListTool.ts`) — enumerate available knowledge bases /
|
||||
documents.
|
||||
|
||||
Registration happens in `builtin/index.ts` (`registerBuiltinTools`). Each
|
||||
tool's `applies` gates on the relevant `assistant.settings.*` flag (e.g.
|
||||
`enableWebSearch`).
|
||||
|
||||
## MCP tools
|
||||
|
||||
`src/main/ai/tools/adapters/aiSdk/mcp/`:
|
||||
|
||||
- `resolveAssistantMcpToolIds` — assistant's enabled MCP servers + per-tool
|
||||
disable list → set of tool ids.
|
||||
- `mcpTools.syncMcpToolsToRegistry({ selectedToolIds })` — calls
|
||||
`listTools` on each MCP server that owns at least one selected tool,
|
||||
registers each as a `ToolEntry` whose `tool.execute` proxies through
|
||||
the MCP transport. **Scope:** only servers owning a selected tool are
|
||||
hit — avoids paying the per-server round-trip when only one MCP tool
|
||||
is in use for this request.
|
||||
|
||||
The sync is idempotent; a stale entry is overwritten on the next sync.
|
||||
|
||||
## Meta-tools
|
||||
|
||||
`src/main/ai/tools/adapters/aiSdk/meta/` defines four tools that turn the
|
||||
registry into a search-then-call interface for the model. Only the first
|
||||
three are injected:
|
||||
|
||||
| Tool | Injected? | Use |
|
||||
|---|---|---|
|
||||
| `tool_search` | yes | Browse the deferred pool by namespace + query, returns brief descriptions |
|
||||
| `tool_inspect` | yes | Emit a JSDoc stub for one tool — enough to call it correctly |
|
||||
| `tool_invoke` | yes | Invoke any registry tool by name with a JSON arg blob |
|
||||
| `tool_exec` | **no** | Sandboxed JS exec with the full registry as a global API (`meta/exec/runtime.ts`, `meta/exec/worker.ts`) — defined but intentionally not injected |
|
||||
|
||||
The injected three are added to the tool set by `applyDeferExposition` when
|
||||
(and only when) the request actually defers tools. See below.
|
||||
|
||||
## Defer exposition
|
||||
|
||||
`src/main/ai/tools/adapters/aiSdk/exposition/`:
|
||||
|
||||
- `shouldDefer(entries, contextWindow)` — returns the set of names to
|
||||
defer. Two gates above the simple threshold:
|
||||
- **MIN_AUTO_DEFER_COUNT** — the auto pool must be large enough that
|
||||
search-then-invoke beats inlining.
|
||||
- **META_TOOLS_OVERHEAD_TOKENS** — estimated savings must exceed the
|
||||
meta-tools' static prompt cost. Without these gates, small tool sets
|
||||
+ small-context models trigger defer and pay net-negative tokens.
|
||||
|
||||
- `applyDeferExposition(tools, registry, contextWindow)` — strips the
|
||||
deferred names out of `tools`, injects `tool_search` / `tool_inspect` /
|
||||
`tool_invoke`, and returns the entries the system-prompt's
|
||||
`<DEFERRED_TOOLS>` section needs to enumerate (so the model knows what
|
||||
namespaces exist).
|
||||
|
||||
**Approval-gated tools are never deferred.** A force-prompt MCP tool is registered
|
||||
with `defer: 'never'` — `mcp/mcpTools.ts` reads `isMcpToolForcePromptBySource` once
|
||||
to drive both `defer` and `needsApproval` — so it stays inline and the SDK's native
|
||||
approval gate fires on it. Deferring it would drop it from the SDK tool-set, so the
|
||||
gate would never fire and it would be reachable only through `tool_invoke` with no
|
||||
approval card. As a runtime backstop the `tool_invoke` / `tool_exec` meta-tools also
|
||||
call `isApprovalGated` at execution time and refuse a gated tool (covering the
|
||||
`registry.getByName(any-name)` vector), steering the model to call it inline. See
|
||||
[Tool Approval](./tool-approval.md).
|
||||
|
||||
`tool_exec` is **not injected** by `applyDeferExposition` — there is no
|
||||
`metaTools.exec` flag. The injection site (`applyDeferExposition.ts:50-53`)
|
||||
deliberately leaves it out: its `worker_threads` + `new Function` sandbox
|
||||
runs model-authored code with full Node privileges, a privilege-escalation
|
||||
surface vs the renderer's prior restrictions. It is meant to be re-enabled
|
||||
behind an explicit Preference key once there is a concrete need.
|
||||
|
||||
## `applies` and tool-call repair
|
||||
|
||||
- `applies(scope: ToolApplyScope)` — per-entry predicate consulted at
|
||||
`registry.selectActive`. Throws are caught and treated as "inactive"
|
||||
with a warning log.
|
||||
- `createAiRepair(...)` (`tools/adapters/aiSdk/repair.ts`) — passed to AI SDK as
|
||||
`experimental_repairToolCall`. When the model emits **malformed args**
|
||||
(`InvalidToolInputError`), the repair function gets one chance to fix it via a
|
||||
follow-up LLM call. Other failures (e.g. an unknown tool name) are
|
||||
returned unrepaired.
|
||||
|
||||
## Where to read more
|
||||
|
||||
- Code: `src/main/ai/tools/adapters/aiSdk/` (Claude Code adapter:
|
||||
`src/main/ai/tools/adapters/claudeCode/`)
|
||||
- Tests: `tools/adapters/aiSdk/__tests__/`,
|
||||
`tools/adapters/aiSdk/builtin/__tests__/`,
|
||||
`tools/adapters/aiSdk/exposition/__tests__/`,
|
||||
`tools/adapters/aiSdk/mcp/__tests__/`,
|
||||
`tools/adapters/aiSdk/meta/__tests__/`
|
||||
- Defer rationale, gate thresholds:
|
||||
`tools/adapters/aiSdk/exposition/shouldDefer.ts` (header doc + tests)
|
||||
- Approval flow: [Tool Approval](./tool-approval.md)
|
||||
@@ -130,7 +130,7 @@ User Message
|
||||
└── Response Pipeline ──→ Message blocks (text, code, image, tool-call)
|
||||
```
|
||||
|
||||
See [AI Core Architecture](./ai-core-architecture.md) for the complete data flow.
|
||||
See [AI Reference](./ai/README.md) for the complete data flow.
|
||||
|
||||
## Monorepo Structure
|
||||
|
||||
@@ -142,7 +142,7 @@ cherry-studio
|
||||
│ │ ├── data/ # Data layer (DB, Cache, Preference, DataApi)
|
||||
│ │ ├── services/ # 27 lifecycle-managed services
|
||||
│ │ ├── knowledge/ # RAG / knowledge base
|
||||
│ │ ├── mcpServers/ # Built-in MCP servers
|
||||
│ │ ├── ai/mcp/servers/ # Built-in MCP servers
|
||||
│ │ ├── apiServer/ # Local REST API (Express)
|
||||
│ │ └── integration/ # External integrations
|
||||
│ │
|
||||
@@ -179,8 +179,8 @@ cherry-studio
|
||||
|-----------|----------|---------------|
|
||||
| Service Lifecycle | `src/main/core/lifecycle/` | [Lifecycle Reference](./lifecycle/README.md) |
|
||||
| Data Layer | `src/main/data/` | [Data Reference](./data/README.md) |
|
||||
| AI Core | `src/renderer/aiCore/` | [AI Core Architecture](./ai-core-architecture.md) |
|
||||
| MCP (Tool Use) | `src/main/services/mcp/` | — |
|
||||
| AI Core | `src/main/ai/` | [AI Reference](./ai/README.md) |
|
||||
| MCP (Tool Use) | `src/main/ai/mcp/` | — |
|
||||
| Knowledge (RAG) | `src/main/knowledge/` | [KnowledgeService](./knowledge/knowledge-service.md) |
|
||||
| Message System | `src/renderer/store/` | [Message System](./messaging/message-system.md) |
|
||||
| CherryClaw (Agent) | `src/main/services/agents/` | [CherryClaw Overview](./cherryclaw/overview.md) |
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Claw MCP Server
|
||||
|
||||
The Claw MCP server is a built-in MCP (Model Context Protocol) server automatically injected into every CherryClaw session. It provides four self-management tools for the agent: `cron` (task scheduling), `notify` (notifications), `skills` (skill management), and `memory` (memory management).
|
||||
The Claw MCP server is a built-in MCP (Model Context Protocol) server automatically injected into every CherryClaw (Soul Mode) session. It provides three self-management tools for the agent: `cron` (task scheduling), `notify` (notifications), and `config` (agent/channel self-configuration). Skill and memory management used to live here too but were extracted into their own standalone MCP servers (see [Related servers](#related-servers-formerly-claw-tools)).
|
||||
|
||||
## Architecture
|
||||
|
||||
@@ -10,7 +10,7 @@ CherryClawService.invoke()
|
||||
→ Inject as in-memory MCP server:
|
||||
_internalMcpServers = { claw: { type: 'inmem', instance: clawServer.mcpServer } }
|
||||
→ ClaudeCodeService merges into SDK options.mcpServers
|
||||
→ SDK auto-discovers tools: mcp__claw__cron, mcp__claw__notify, mcp__claw__skills, mcp__claw__memory
|
||||
→ SDK auto-discovers tools: mcp__claw__cron, mcp__claw__notify, mcp__claw__config
|
||||
```
|
||||
|
||||
ClawServer uses the `@modelcontextprotocol/sdk` `McpServer` class, running in memory mode (no HTTP transport). A new instance is created per CherryClaw session invocation, bound to the current agent's ID.
|
||||
@@ -83,90 +83,44 @@ Returns an informational message (not an error) if no notification channels are
|
||||
|
||||
---
|
||||
|
||||
## skills Tool
|
||||
## config Tool
|
||||
|
||||
Manage Claude skills in the agent workspace. Supports searching from the marketplace, installing, uninstalling, and listing installed skills.
|
||||
Inspect and manage the agent's own configuration — identity, model, and IM channel connections — and drive the onboarding ("bootstrap") ritual.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `action` | string | Yes | One of the actions below |
|
||||
| `type` | string | For `add_channel` | Channel adapter type: `telegram` / `feishu` / `qq` / `wechat` / `discord` / `slack` |
|
||||
| `name` | string | For `rename` / `add_channel` | New display name (`rename`) or human-readable channel name (`add_channel`) |
|
||||
| `channel_id` | string | For `update_channel` / `remove_channel` / `reconnect_channel` | Target channel id |
|
||||
| `config` | object | For `add_channel` | Adapter-specific configuration (optional for `update_channel`) |
|
||||
| `enabled` | boolean | No | Enable/disable the channel (defaults to true) |
|
||||
|
||||
### Actions
|
||||
|
||||
#### `search` — Search Skills
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `query` | string | Yes | Search keywords |
|
||||
|
||||
Queries the public marketplace API (`claude-plugins.dev/api/skills`), returns matching skills with `name`, `description`, `author`, `identifier` (for installation), and `installs` count. Hyphens and underscores in search terms are replaced with spaces to improve matching.
|
||||
|
||||
#### `install` — Install Skill
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `identifier` | string | Yes | Marketplace skill identifier, format `owner/repo/skill-name` |
|
||||
|
||||
Constructs `marketplace:skill:{identifier}` path internally, delegates to `PluginService.install()`.
|
||||
|
||||
#### `remove` — Uninstall Skill
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `name` | string | Yes | Skill folder name (from list results) |
|
||||
|
||||
Delegates to `PluginService.uninstall()`.
|
||||
|
||||
#### `list` — List Installed Skills
|
||||
|
||||
No parameters. Returns all installed skills for the current agent, including `name`, `folder`, and `description`.
|
||||
| Action | Description |
|
||||
|---|---|
|
||||
| `status` | Current channels, model, and supported adapter types |
|
||||
| `rename` | Change the agent's display name |
|
||||
| `add_channel` / `update_channel` / `remove_channel` | Manage IM channel connections |
|
||||
| `reconnect_channel` | Re-scan a QR code for a WeChat/Feishu channel (e.g. expired session or failed initial setup) |
|
||||
| `complete_bootstrap` | Mark the onboarding ritual as done |
|
||||
| `reset_bootstrap` | Re-run onboarding in the next session |
|
||||
|
||||
---
|
||||
|
||||
## memory Tool
|
||||
## Related servers (formerly claw tools)
|
||||
|
||||
Manage persistent cross-session memory. This is the write interface for CherryClaw's memory system (reading is done via inline content in the system prompt).
|
||||
Skill and memory management were extracted out of claw into their own standalone MCP servers:
|
||||
|
||||
### Design Principle
|
||||
|
||||
The tool description encodes the memory decision logic:
|
||||
|
||||
> Before writing to FACT.md, ask yourself: will this information still matter in 6 months? If not, use append instead.
|
||||
|
||||
### Actions
|
||||
|
||||
#### `update` — Update FACT.md
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| Capability | Server | Tool | File |
|
||||
|---|---|---|---|
|
||||
| `content` | string | Yes | Complete markdown content of FACT.md |
|
||||
| Skills (search / install / uninstall / list) | `skills` | `mcp__skills__skills` | `src/main/ai/mcp/servers/skills.ts` |
|
||||
| Persistent memory (update / append / search) | `agent-memory` | `mcp__agent-memory__memory` | `src/main/ai/mcp/servers/workspaceMemory.ts` |
|
||||
|
||||
Atomic write: writes to a temp file first, then replaces via `rename`. Ensures no file corruption from mid-write crashes.
|
||||
|
||||
File path supports case-insensitive matching. The `memory/` directory is auto-created if it doesn't exist.
|
||||
|
||||
**Note**: This is a full overwrite, not an incremental edit. The agent needs to read existing content first, modify it, then write back the complete content.
|
||||
|
||||
#### `append` — Append Log Entry
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `text` | string | Yes | Log entry text |
|
||||
| `tags` | string[] | No | Tag list |
|
||||
|
||||
Appends a JSON line to `memory/JOURNAL.jsonl`:
|
||||
|
||||
```json
|
||||
{"ts":"2026-03-10T12:00:00.000Z","tags":["deploy","production"],"text":"Deployed v2.1 to production"}
|
||||
```
|
||||
|
||||
Timestamp is auto-generated. Suitable for one-off events, completed tasks, session summaries, and other short-term information.
|
||||
|
||||
#### `search` — Search Logs
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|---|---|---|---|
|
||||
| `query` | string | No | Case-insensitive substring match |
|
||||
| `tag` | string | No | Filter by tag |
|
||||
| `limit` | integer | No | Max results (default 20) |
|
||||
|
||||
Returns matching log entries in reverse chronological order. `query` and `tag` can be combined.
|
||||
> The CherryClaw system prompt and the workspace bootstrap reference memory as `mcp__agent-memory__memory` — **not** `mcp__claw__memory`.
|
||||
|
||||
---
|
||||
|
||||
@@ -178,6 +132,6 @@ All tool calls execute within an internal try-catch. On error, returns an `{ isE
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| `src/main/mcpServers/claw.ts` | ClawServer complete implementation (4 tools + helpers) |
|
||||
| `src/main/mcpServers/__tests__/claw.test.ts` | 37 unit tests |
|
||||
| `src/main/services/agents/services/cherryclaw/index.ts` | MCP server injection logic |
|
||||
| `src/main/ai/mcp/servers/claw.ts` | ClawServer implementation (`cron` / `notify` / `config` + helpers) |
|
||||
| `src/main/ai/mcp/servers/__tests__/claw.test.ts` | Unit tests |
|
||||
| `src/main/ai/runtime/claudeCode/settingsBuilder.ts` | `buildMcpServers` — injects the claw server in Soul Mode |
|
||||
|
||||
@@ -121,6 +121,6 @@ Both tables are associated with the agents table via foreign key cascades.
|
||||
| `src/main/services/agents/services/AgentServiceRegistry.ts` | Agent service registry |
|
||||
| `src/main/services/agents/services/TaskService.ts` | Task CRUD + scheduling calculation |
|
||||
| `src/main/services/agents/services/SchedulerService.ts` | Polling scheduler |
|
||||
| `src/main/mcpServers/claw.ts` | Claw MCP server |
|
||||
| `src/main/ai/mcp/servers/claw.ts` | Claw MCP server |
|
||||
| `src/main/services/agents/services/channels/` | Channel abstraction layer |
|
||||
| `src/main/services/agents/database/schema/tasks.schema.ts` | Task table schema |
|
||||
|
||||
@@ -18,7 +18,7 @@ Value type is inferred from the schema. Hooks pin the cache entry (refcounted)
|
||||
import { useCache, useSharedCache, usePersistCache } from '@data/hooks/useCache'
|
||||
|
||||
// Memory — single renderer
|
||||
const [generating, setGenerating] = useCache('chat.generating', false)
|
||||
const [generating, setGenerating] = useCache('chat.web_search.searching', false)
|
||||
|
||||
// Shared — all windows
|
||||
const [activeSearches, setActive] = useSharedCache('chat.web_search.active_searches')
|
||||
@@ -42,12 +42,12 @@ import { cacheService } from '@data/CacheService'
|
||||
|
||||
```typescript
|
||||
// Schema keys (Fixed or Template) — type-inferred
|
||||
cacheService.set('chat.generating', true)
|
||||
cacheService.set('chat.generating', true, 30_000) // with TTL (ms)
|
||||
cacheService.get('chat.generating') // boolean
|
||||
cacheService.has('chat.generating')
|
||||
cacheService.hasTTL('chat.generating')
|
||||
cacheService.delete('chat.generating')
|
||||
cacheService.set('chat.web_search.searching', true)
|
||||
cacheService.set('chat.web_search.searching', true, 30_000) // with TTL (ms)
|
||||
cacheService.get('chat.web_search.searching') // boolean
|
||||
cacheService.has('chat.web_search.searching')
|
||||
cacheService.hasTTL('chat.web_search.searching')
|
||||
cacheService.delete('chat.web_search.searching')
|
||||
|
||||
// Casual (Memory tier only, no schema match allowed)
|
||||
cacheService.setCasual<TopicCache>(`topic:${id}`, data, 30_000)
|
||||
|
||||
@@ -70,10 +70,10 @@ files:
|
||||
- "!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters,Makefile,*.Makefile}" # filter .node build files
|
||||
- "resources/**/*" # include all files in resources, and unpack them from asar for runtime access
|
||||
asarUnpack:
|
||||
- out/proxy/**
|
||||
- resources/**
|
||||
- "**/*.{metal,exp,lib}"
|
||||
- "node_modules/@img/sharp-*/**"
|
||||
- "node_modules/@anthropic-ai/claude-agent-sdk-*/**" # native CLI binary must live on disk to be spawned
|
||||
extraResources:
|
||||
- from: "migrations/sqlite-drizzle"
|
||||
to: "migrations/sqlite-drizzle"
|
||||
|
||||
@@ -8,7 +8,6 @@ import { visualizer } from 'rollup-plugin-visualizer'
|
||||
// assert not supported by biome
|
||||
// import pkg from './package.json' assert { type: 'json' }
|
||||
import pkg from './package.json'
|
||||
import { buildProxyBootstrapPlugin } from './scripts/buildProxyBootstrapPlugin'
|
||||
|
||||
const visualizerPlugin = (type: 'renderer' | 'main') => {
|
||||
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
|
||||
@@ -23,14 +22,7 @@ const mainExternalDependencies = Object.keys(pkg.dependencies).filter(
|
||||
|
||||
export default defineConfig({
|
||||
main: {
|
||||
plugins: [
|
||||
...visualizerPlugin('main'),
|
||||
buildProxyBootstrapPlugin({
|
||||
dependencies: Object.keys(pkg.dependencies),
|
||||
isProd,
|
||||
rootDir: __dirname
|
||||
})
|
||||
],
|
||||
plugins: [...visualizerPlugin('main')],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@main': resolve('src/main'),
|
||||
@@ -41,6 +33,10 @@ export default defineConfig({
|
||||
'@logger': resolve('src/main/core/logger/LoggerService'),
|
||||
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
|
||||
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
|
||||
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
|
||||
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
|
||||
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
|
||||
'@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src'),
|
||||
'@vectorstores/libsql': resolve('packages/vectorstores/libsql/src/index.ts'),
|
||||
'@cherrystudio/provider-registry/node': resolve('packages/provider-registry/src/registry-loader'),
|
||||
'@cherrystudio/provider-registry': resolve('packages/provider-registry/src'),
|
||||
|
||||
95
migrations/sqlite-drizzle/0003_burly_tomorrow_man.sql
Normal file
95
migrations/sqlite-drizzle/0003_burly_tomorrow_man.sql
Normal file
@@ -0,0 +1,95 @@
|
||||
CREATE TABLE `agent_workspace` (
|
||||
`id` text PRIMARY KEY NOT NULL,
|
||||
`name` text NOT NULL,
|
||||
`path` text NOT NULL,
|
||||
`order_key` text NOT NULL,
|
||||
`created_at` integer NOT NULL,
|
||||
`updated_at` integer NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE UNIQUE INDEX `agent_workspace_path_unique_idx` ON `agent_workspace` (`path`);--> statement-breakpoint
|
||||
CREATE INDEX `agent_workspace_order_key_idx` ON `agent_workspace` (`order_key`);--> statement-breakpoint
|
||||
DROP TABLE `agent_task_run_log`;--> statement-breakpoint
|
||||
DROP TABLE `agent_task`;--> statement-breakpoint
|
||||
PRAGMA foreign_keys=OFF;--> statement-breakpoint
|
||||
CREATE TABLE `__new_agent_channel_task` (
|
||||
`channel_id` text NOT NULL,
|
||||
`task_id` text NOT NULL,
|
||||
PRIMARY KEY(`channel_id`, `task_id`),
|
||||
FOREIGN KEY (`channel_id`) REFERENCES `agent_channel`(`id`) ON UPDATE no action ON DELETE cascade,
|
||||
FOREIGN KEY (`task_id`) REFERENCES `job_schedule`(`id`) ON UPDATE no action ON DELETE cascade
|
||||
);
|
||||
--> statement-breakpoint
|
||||
INSERT INTO `__new_agent_channel_task`("channel_id", "task_id") SELECT "channel_id", "task_id" FROM `agent_channel_task`;--> statement-breakpoint
|
||||
DROP TABLE `agent_channel_task`;--> statement-breakpoint
|
||||
ALTER TABLE `__new_agent_channel_task` RENAME TO `agent_channel_task`;--> statement-breakpoint
|
||||
PRAGMA foreign_keys=ON;--> statement-breakpoint
|
||||
CREATE INDEX `agent_channel_task_channel_id_idx` ON `agent_channel_task` (`channel_id`);--> statement-breakpoint
|
||||
CREATE INDEX `agent_channel_task_task_id_idx` ON `agent_channel_task` (`task_id`);--> statement-breakpoint
|
||||
CREATE TABLE `__new_agent_session` (
|
||||
`id` text PRIMARY KEY NOT NULL,
|
||||
`agent_id` text,
|
||||
`name` text NOT NULL,
|
||||
`description` text DEFAULT '' NOT NULL,
|
||||
`workspace_id` text,
|
||||
`order_key` text NOT NULL,
|
||||
`created_at` integer NOT NULL,
|
||||
`updated_at` integer NOT NULL,
|
||||
FOREIGN KEY (`agent_id`) REFERENCES `agent`(`id`) ON UPDATE no action ON DELETE set null,
|
||||
FOREIGN KEY (`workspace_id`) REFERENCES `agent_workspace`(`id`) ON UPDATE no action ON DELETE set null
|
||||
);
|
||||
--> statement-breakpoint
|
||||
DROP TABLE `agent_session`;--> statement-breakpoint
|
||||
ALTER TABLE `__new_agent_session` RENAME TO `agent_session`;--> statement-breakpoint
|
||||
CREATE INDEX `agent_session_order_key_idx` ON `agent_session` (`order_key`);--> statement-breakpoint
|
||||
CREATE TABLE `__new_agent` (
|
||||
`id` text PRIMARY KEY NOT NULL,
|
||||
`type` text NOT NULL,
|
||||
`name` text NOT NULL,
|
||||
`description` text DEFAULT '' NOT NULL,
|
||||
`instructions` text NOT NULL,
|
||||
`model` text,
|
||||
`plan_model` text,
|
||||
`small_model` text,
|
||||
`mcps` text DEFAULT '[]' NOT NULL,
|
||||
`allowed_tools` text DEFAULT '[]' NOT NULL,
|
||||
`configuration` text DEFAULT '{}' NOT NULL,
|
||||
`order_key` text NOT NULL,
|
||||
`created_at` integer NOT NULL,
|
||||
`updated_at` integer NOT NULL,
|
||||
`deleted_at` integer,
|
||||
FOREIGN KEY (`model`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null,
|
||||
FOREIGN KEY (`plan_model`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null,
|
||||
FOREIGN KEY (`small_model`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null
|
||||
);
|
||||
--> statement-breakpoint
|
||||
DROP TABLE `agent`;--> statement-breakpoint
|
||||
ALTER TABLE `__new_agent` RENAME TO `agent`;--> statement-breakpoint
|
||||
CREATE INDEX `agent_name_idx` ON `agent` (`name`);--> statement-breakpoint
|
||||
CREATE INDEX `agent_type_idx` ON `agent` (`type`);--> statement-breakpoint
|
||||
CREATE INDEX `agent_order_key_idx` ON `agent` (`order_key`);--> statement-breakpoint
|
||||
CREATE TABLE `__new_agent_session_message` (
|
||||
`id` text PRIMARY KEY NOT NULL,
|
||||
`session_id` text NOT NULL,
|
||||
`role` text NOT NULL,
|
||||
`data` text NOT NULL,
|
||||
`searchable_text` text DEFAULT '' NOT NULL,
|
||||
`status` text NOT NULL,
|
||||
`model_id` text,
|
||||
`model_snapshot` text,
|
||||
`trace_id` text,
|
||||
`stats` text,
|
||||
`runtime_resume_token` text,
|
||||
`created_at` integer NOT NULL,
|
||||
`updated_at` integer NOT NULL,
|
||||
FOREIGN KEY (`session_id`) REFERENCES `agent_session`(`id`) ON UPDATE no action ON DELETE cascade,
|
||||
FOREIGN KEY (`model_id`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null,
|
||||
CONSTRAINT "agent_session_message_role_check" CHECK("__new_agent_session_message"."role" IN ('user', 'assistant', 'system')),
|
||||
CONSTRAINT "agent_session_message_status_check" CHECK("__new_agent_session_message"."status" IN ('pending', 'success', 'error', 'paused'))
|
||||
);
|
||||
--> statement-breakpoint
|
||||
DROP TABLE `agent_session_message`;--> statement-breakpoint
|
||||
ALTER TABLE `__new_agent_session_message` RENAME TO `agent_session_message`;--> statement-breakpoint
|
||||
CREATE INDEX `agent_session_message_session_created_id_idx` ON `agent_session_message` (`session_id`,`created_at`,`id`);--> statement-breakpoint
|
||||
ALTER TABLE `assistant` ADD `order_key` text DEFAULT '' NOT NULL;--> statement-breakpoint
|
||||
CREATE INDEX `assistant_order_key_idx` ON `assistant` (`order_key`);
|
||||
3626
migrations/sqlite-drizzle/meta/0003_snapshot.json
Normal file
3626
migrations/sqlite-drizzle/meta/0003_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,13 @@
|
||||
"when": 1780386366672,
|
||||
"tag": "0002_clever_skin",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 3,
|
||||
"version": "6",
|
||||
"when": 1780586439674,
|
||||
"tag": "0003_burly_tomorrow_man",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
19
package.json
19
package.json
@@ -90,7 +90,8 @@
|
||||
"ci": "pnpm ci:basic-check && pnpm ci:test-check"
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/claude-agent-sdk": "0.2.112",
|
||||
"@anthropic-ai/claude-agent-sdk": "0.3.145",
|
||||
"@cherrystudio/ripgrep": "^1.0.0",
|
||||
"@expo/sudo-prompt": "^9.3.2",
|
||||
"@larksuiteoapi/node-sdk": "^1.59.0",
|
||||
"@libsql/client": "^0.15.15",
|
||||
@@ -100,7 +101,6 @@
|
||||
"@vectorstores/core": "^0.1.8",
|
||||
"@vectorstores/libsql": "workspace:*",
|
||||
"@vectorstores/readers": "^0.1.8",
|
||||
"cron-parser": "^5.0.8",
|
||||
"express": "5.1.0",
|
||||
"font-list": "2.0.0",
|
||||
"graceful-fs": "4.2.11",
|
||||
@@ -128,6 +128,7 @@
|
||||
"@ai-sdk/azure": "^3.0.54",
|
||||
"@ai-sdk/cerebras": "^2.0.45",
|
||||
"@ai-sdk/cohere": "^3.0.30",
|
||||
"@ai-sdk/devtools": "^0.0.17",
|
||||
"@ai-sdk/gateway": "^3.0.104",
|
||||
"@ai-sdk/google": "3.0.64",
|
||||
"@ai-sdk/google-vertex": "^4.0.112",
|
||||
@@ -139,13 +140,14 @@
|
||||
"@ai-sdk/perplexity": "^3.0.29",
|
||||
"@ai-sdk/provider": "^3.0.8",
|
||||
"@ai-sdk/provider-utils": "^4.0.23",
|
||||
"@ai-sdk/react": "^3.0.147",
|
||||
"@ai-sdk/test-server": "^1.0.3",
|
||||
"@ai-sdk/togetherai": "^2.0.45",
|
||||
"@ai-sdk/xai": "^3.0.83",
|
||||
"@ant-design/cssinjs": "1.23.0",
|
||||
"@ant-design/icons": "5.6.1",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/sdk": "^0.81.0",
|
||||
"@aws-sdk/client-s3": "^3.998.0",
|
||||
"@biomejs/biome": "2.2.4",
|
||||
"@changesets/changelog-github": "^0.5.2",
|
||||
@@ -221,7 +223,7 @@
|
||||
"@tailwindcss/vite": "^4.1.13",
|
||||
"@tanstack/react-query": "^5.85.5",
|
||||
"@tanstack/react-router": "^1.139.3",
|
||||
"@tanstack/react-virtual": "^3.13.12",
|
||||
"@tanstack/react-virtual": "^3.13.24",
|
||||
"@tanstack/router-plugin": "^1.139.3",
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
"@testing-library/jest-dom": "^6.6.3",
|
||||
@@ -409,7 +411,6 @@
|
||||
"react-hook-form": "^7.55.0",
|
||||
"react-hotkeys-hook": "^4.6.1",
|
||||
"react-i18next": "^14.1.2",
|
||||
"react-infinite-scroll-component": "^6.1.0",
|
||||
"react-json-view": "^1.21.3",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-player": "^3.3.1",
|
||||
@@ -544,6 +545,14 @@
|
||||
},
|
||||
"packageManager": "pnpm@10.27.0",
|
||||
"optionalDependencies": {
|
||||
"@anthropic-ai/claude-agent-sdk-darwin-arm64": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-darwin-x64": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-linux-arm64": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-linux-arm64-musl": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-linux-x64": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-linux-x64-musl": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-win32-arm64": "0.3.145",
|
||||
"@anthropic-ai/claude-agent-sdk-win32-x64": "0.3.145",
|
||||
"@img/sharp-darwin-arm64": "0.34.5",
|
||||
"@img/sharp-darwin-x64": "0.34.5",
|
||||
"@img/sharp-libvips-darwin-arm64": "1.2.4",
|
||||
|
||||
@@ -227,32 +227,6 @@ const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
])
|
||||
```
|
||||
|
||||
#### promptToolUsePlugin - 提示工具使用插件
|
||||
|
||||
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
|
||||
|
||||
```typescript
|
||||
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
// 对于不支持 function call 的模型
|
||||
const executor = AiCore.create(
|
||||
'providerId',
|
||||
{
|
||||
apiKey: 'your-key',
|
||||
baseURL: 'https://your-model-endpoint'
|
||||
},
|
||||
[
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
// 可选:自定义系统提示符构建
|
||||
buildSystemPrompt: (userPrompt, tools) => {
|
||||
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
|
||||
}
|
||||
})
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### 自定义插件
|
||||
|
||||
创建自定义插件非常简单:
|
||||
|
||||
@@ -45,7 +45,7 @@
|
||||
"@ai-sdk/xai": "^3.0.83",
|
||||
"@cherrystudio/ai-sdk-provider": "workspace:*",
|
||||
"@openrouter/ai-sdk-provider": "^2.3.3",
|
||||
"lru-cache": "^11.2.4",
|
||||
"quick-lru": "^5.1.1",
|
||||
"zod": "^4.1.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -6,6 +6,5 @@
|
||||
export type { ModelConfig as ModelConfigType } from './models/types'
|
||||
|
||||
// 执行管理
|
||||
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
||||
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||
export type { RuntimeConfig } from './runtime/types'
|
||||
|
||||
@@ -1,4 +1,2 @@
|
||||
export * from './providerToolPlugin'
|
||||
export * from './toolUsePlugin/promptToolUsePlugin'
|
||||
export * from './toolUsePlugin/type'
|
||||
export * from './webSearchPlugin'
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
/**
|
||||
* 流事件管理器
|
||||
*
|
||||
* 负责处理 AI SDK 流事件的发送和管理
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
|
||||
import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai'
|
||||
|
||||
import type { AiSdkUsage } from '../../../providers/types'
|
||||
import type { AiRequestContext, StreamTextParams, StreamTextResult } from '../../types'
|
||||
import type { StreamController } from './ToolExecutor'
|
||||
|
||||
/**
|
||||
* 类型守卫:检查对象是否是有效的流结果(包含 ReadableStream 类型的 fullStream)
|
||||
*/
|
||||
function hasFullStream(obj: unknown): obj is StreamTextResult & { fullStream: ReadableStream } {
|
||||
return typeof obj === 'object' && obj !== null && 'fullStream' in obj && obj.fullStream instanceof ReadableStream
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 usage 是否是 LanguageModelUsage
|
||||
* LanguageModelUsage 包含 totalTokens, inputTokens, outputTokens 等字段
|
||||
*/
|
||||
function isLanguageModelUsage(usage: unknown): usage is LanguageModelUsage {
|
||||
return (
|
||||
typeof usage === 'object' &&
|
||||
usage !== null &&
|
||||
('totalTokens' in usage || 'inputTokens' in usage || 'outputTokens' in usage)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 usage 是否是 ImageModelUsage
|
||||
* ImageModelUsage 包含 inputTokens, outputTokens, totalTokens 字段
|
||||
* but lacks inputTokenDetails/outputTokenDetails which are present in LanguageModelUsage
|
||||
*/
|
||||
function isImageModelUsage(usage: unknown): usage is ImageModelUsage {
|
||||
return (
|
||||
typeof usage === 'object' &&
|
||||
usage !== null &&
|
||||
'inputTokens' in usage &&
|
||||
'outputTokens' in usage &&
|
||||
!('inputTokenDetails' in usage) &&
|
||||
!('outputTokenDetails' in usage)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 usage 是否是 EmbeddingModelUsage
|
||||
* EmbeddingModelUsage 只包含 tokens 字段
|
||||
*/
|
||||
function isEmbeddingModelUsage(usage: unknown): usage is EmbeddingModelUsage {
|
||||
return (
|
||||
typeof usage === 'object' &&
|
||||
usage !== null &&
|
||||
'tokens' in usage &&
|
||||
// 确保只有 tokens 字段(没有 inputTokens, outputTokens 等)
|
||||
!('inputTokens' in usage) &&
|
||||
!('outputTokens' in usage)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 流事件管理器类
|
||||
*/
|
||||
export class StreamEventManager {
|
||||
/**
|
||||
* 发送工具调用步骤开始事件
|
||||
*/
|
||||
sendStepStartEvent(controller: StreamController): void {
|
||||
controller.enqueue({
|
||||
type: 'start-step',
|
||||
request: {},
|
||||
warnings: []
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送步骤完成事件
|
||||
*/
|
||||
sendStepFinishEvent(
|
||||
controller: StreamController,
|
||||
chunk: {
|
||||
usage?: Partial<AiSdkUsage>
|
||||
response?: { id: string; [key: string]: unknown }
|
||||
providerMetadata?: SharedV3ProviderMetadata
|
||||
},
|
||||
context: AiRequestContext,
|
||||
finishReason: string = 'stop'
|
||||
): void {
|
||||
// 累加当前步骤的 usage
|
||||
if (chunk.usage && context.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: 'finish-step',
|
||||
finishReason,
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理递归调用并将结果流接入当前流
|
||||
*/
|
||||
async handleRecursiveCall<TParams extends StreamTextParams>(
|
||||
controller: StreamController,
|
||||
recursiveParams: Partial<TParams>,
|
||||
context: AiRequestContext<TParams, StreamTextResult>
|
||||
): Promise<void> {
|
||||
// try {
|
||||
// 重置工具执行状态,准备处理新的步骤
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
if (hasFullStream(recursiveResult)) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
// } catch (error) {
|
||||
// this.handleRecursiveCallError(controller, error, stepId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 将递归流的数据传递到当前流
|
||||
*/
|
||||
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
||||
const reader = recursiveStream.getReader()
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
if (value.type === 'start') {
|
||||
continue
|
||||
}
|
||||
|
||||
if (value.type === 'finish') {
|
||||
break
|
||||
}
|
||||
|
||||
controller.enqueue(value)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建递归调用的参数
|
||||
*/
|
||||
buildRecursiveParams<TParams extends StreamTextParams>(
|
||||
context: AiRequestContext<TParams, StreamTextResult>,
|
||||
textBuffer: string,
|
||||
toolResultsText: string,
|
||||
tools: Record<string, unknown>
|
||||
): Partial<TParams> {
|
||||
const params = context.originalParams
|
||||
|
||||
// 构建新的对话消息
|
||||
const newMessages: ModelMessage[] = [
|
||||
...(params.messages || []),
|
||||
// 只有当 textBuffer 有内容时才添加 assistant 消息,避免空消息导致 API 错误
|
||||
...(textBuffer ? [{ role: 'assistant' as const, content: textBuffer }] : []),
|
||||
{
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
}
|
||||
]
|
||||
|
||||
// 递归调用,继续对话,重新传递 tools
|
||||
const recursiveParams = {
|
||||
...params,
|
||||
messages: newMessages,
|
||||
tools: tools
|
||||
} as Partial<TParams>
|
||||
|
||||
return recursiveParams
|
||||
}
|
||||
|
||||
/**
|
||||
* 累加 usage 数据
|
||||
*
|
||||
* 使用类型守卫来处理不同类型的 usage(LanguageModelUsage, ImageModelUsage, EmbeddingModelUsage)
|
||||
* - LanguageModelUsage: inputTokens, outputTokens, totalTokens
|
||||
* - ImageModelUsage: inputTokens, outputTokens, totalTokens
|
||||
* - EmbeddingModelUsage: tokens
|
||||
*/
|
||||
accumulateUsage(target: Partial<AiSdkUsage>, source: Partial<AiSdkUsage>): void {
|
||||
if (!target || !source) return
|
||||
|
||||
if (isLanguageModelUsage(target) && isLanguageModelUsage(source)) {
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
|
||||
// Accumulate inputTokenDetails (cacheReadTokens, cacheWriteTokens, noCacheTokens)
|
||||
if (source.inputTokenDetails) {
|
||||
if (!target.inputTokenDetails) {
|
||||
target.inputTokenDetails = {
|
||||
noCacheTokens: undefined,
|
||||
cacheReadTokens: undefined,
|
||||
cacheWriteTokens: undefined
|
||||
}
|
||||
}
|
||||
target.inputTokenDetails.cacheReadTokens =
|
||||
(target.inputTokenDetails.cacheReadTokens || 0) + (source.inputTokenDetails.cacheReadTokens || 0)
|
||||
target.inputTokenDetails.cacheWriteTokens =
|
||||
(target.inputTokenDetails.cacheWriteTokens || 0) + (source.inputTokenDetails.cacheWriteTokens || 0)
|
||||
target.inputTokenDetails.noCacheTokens =
|
||||
(target.inputTokenDetails.noCacheTokens || 0) + (source.inputTokenDetails.noCacheTokens || 0)
|
||||
}
|
||||
|
||||
// Accumulate outputTokenDetails (reasoningTokens, textTokens)
|
||||
if (source.outputTokenDetails) {
|
||||
if (!target.outputTokenDetails) {
|
||||
target.outputTokenDetails = { textTokens: undefined, reasoningTokens: undefined }
|
||||
}
|
||||
target.outputTokenDetails.reasoningTokens =
|
||||
(target.outputTokenDetails.reasoningTokens || 0) + (source.outputTokenDetails.reasoningTokens || 0)
|
||||
target.outputTokenDetails.textTokens =
|
||||
(target.outputTokenDetails.textTokens || 0) + (source.outputTokenDetails.textTokens || 0)
|
||||
}
|
||||
return
|
||||
}
|
||||
if (isImageModelUsage(target) && isImageModelUsage(source)) {
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
return
|
||||
}
|
||||
|
||||
if (isEmbeddingModelUsage(target) && isEmbeddingModelUsage(source)) {
|
||||
target.tokens = (target.tokens || 0) + (source.tokens || 0)
|
||||
return
|
||||
}
|
||||
|
||||
// ⚠️ 未知类型或类型不匹配,不进行累加
|
||||
console.warn('[StreamEventManager] Unable to accumulate usage - type mismatch or unknown type', {
|
||||
target,
|
||||
source
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
/**
|
||||
* 工具执行器
|
||||
*
|
||||
* 负责工具的执行、结果格式化和相关事件发送
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ToolSet, TypedToolError } from 'ai'
|
||||
|
||||
import type { ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 工具执行结果
|
||||
*/
|
||||
export interface ExecutedResult {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
result: unknown
|
||||
isError?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 流控制器类型(从 AI SDK 提取)
|
||||
* Generic type parameter allows for type-safe chunk enqueuing
|
||||
*/
|
||||
export interface StreamController<TChunk = unknown> {
|
||||
enqueue(chunk: TChunk): void
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具执行器类
|
||||
*/
|
||||
export class ToolExecutor {
|
||||
/**
|
||||
* 执行多个工具调用
|
||||
*/
|
||||
async executeTools(
|
||||
toolUses: ToolUseResult[],
|
||||
tools: ToolSet,
|
||||
controller: StreamController
|
||||
): Promise<ExecutedResult[]> {
|
||||
const executedResults: ExecutedResult[] = []
|
||||
for (const toolUse of toolUses) {
|
||||
try {
|
||||
const tool = tools[toolUse.toolName]
|
||||
if (!tool || typeof tool.execute !== 'function') {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
toolCallId: toolUse.id,
|
||||
messages: [],
|
||||
abortSignal: new AbortController().signal
|
||||
})
|
||||
|
||||
// 发送 tool-result 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
output: result
|
||||
})
|
||||
|
||||
executedResults.push({
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result,
|
||||
isError: false
|
||||
})
|
||||
} catch (error) {
|
||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||
|
||||
// 处理错误情况
|
||||
const errorResult = this.handleToolError(toolUse, error, controller)
|
||||
executedResults.push(errorResult)
|
||||
}
|
||||
}
|
||||
|
||||
return executedResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化工具结果为 Cherry Studio 标准格式
|
||||
*/
|
||||
formatToolResults(executedResults: ExecutedResult[]): string {
|
||||
return executedResults
|
||||
.map((tr) => {
|
||||
if (!tr.isError) {
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
||||
} else {
|
||||
const error = tr.result || 'Unknown error'
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
|
||||
}
|
||||
})
|
||||
.join('\n\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送工具调用开始相关事件
|
||||
*/
|
||||
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// // 发送 tool-input-start 事件
|
||||
// controller.enqueue({
|
||||
// type: 'tool-input-start',
|
||||
// id: toolUse.id,
|
||||
// toolName: toolUse.toolName
|
||||
// })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 处理工具执行错误
|
||||
*/
|
||||
private handleToolError<T extends ToolSet>(
|
||||
toolUse: ToolUseResult,
|
||||
error: unknown,
|
||||
controller: StreamController
|
||||
): ExecutedResult {
|
||||
// 使用 AI SDK 标准错误格式
|
||||
const toolError: TypedToolError<T> = {
|
||||
type: 'tool-error',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
error
|
||||
}
|
||||
|
||||
controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
// controller.enqueue({
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// error: error instanceof Error ? error.message : String(error),
|
||||
// input: toolUse.arguments
|
||||
// })
|
||||
|
||||
return {
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error,
|
||||
isError: true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,566 +0,0 @@
|
||||
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
|
||||
import { createMockContext, createMockTool } from '@test-utils'
|
||||
import type {
|
||||
EmbeddingModelUsage,
|
||||
ImageModelUsage,
|
||||
LanguageModelUsage as AiSdkUsage,
|
||||
LanguageModelUsage,
|
||||
TextStreamPart,
|
||||
ToolSet
|
||||
} from 'ai'
|
||||
import { simulateReadableStream } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { StreamEventManager } from '../StreamEventManager'
|
||||
import type { StreamController } from '../ToolExecutor'
|
||||
|
||||
/**
|
||||
* Type alias for empty toolset (no tools)
|
||||
* Using Record<string, never> ensures type safety for tests without tools
|
||||
*/
|
||||
type EmptyToolSet = Record<string, never>
|
||||
|
||||
/**
|
||||
* Mock StreamController for testing
|
||||
* Provides type-safe enqueue function that accepts TextStreamPart chunks
|
||||
*/
|
||||
interface MockStreamController<TOOLS extends ToolSet = EmptyToolSet> extends StreamController {
|
||||
enqueue: ReturnType<typeof vi.fn<(chunk: TextStreamPart<TOOLS>) => void>>
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a type-safe mock stream controller
|
||||
*/
|
||||
function createMockStreamController<TOOLS extends ToolSet = EmptyToolSet>(): MockStreamController<TOOLS> {
|
||||
return {
|
||||
enqueue: vi.fn()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Type for chunk data in finish-step events
|
||||
*/
|
||||
interface FinishStepChunk {
|
||||
usage?: Partial<AiSdkUsage>
|
||||
response?: { id: string; [key: string]: unknown }
|
||||
providerMetadata?: SharedV3ProviderMetadata
|
||||
}
|
||||
|
||||
describe('StreamEventManager', () => {
|
||||
let manager: StreamEventManager
|
||||
|
||||
beforeEach(() => {
|
||||
manager = new StreamEventManager()
|
||||
})
|
||||
|
||||
describe('accumulateUsage', () => {
|
||||
describe('LanguageModelUsage', () => {
|
||||
it('should accumulate language model usage correctly', () => {
|
||||
const target: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
const source: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(15)
|
||||
expect(target.outputTokens).toBe(30)
|
||||
expect(target.totalTokens).toBe(45)
|
||||
})
|
||||
|
||||
it('should handle undefined values in target', () => {
|
||||
const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
|
||||
const source: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(15)
|
||||
expect(target.outputTokens).toBe(10)
|
||||
expect(target.totalTokens).toBe(15)
|
||||
})
|
||||
|
||||
it('should handle undefined values in source', () => {
|
||||
const target: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
const source: Partial<LanguageModelUsage> = { inputTokens: 5 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(15)
|
||||
expect(target.outputTokens).toBe(20)
|
||||
expect(target.totalTokens).toBe(30)
|
||||
})
|
||||
|
||||
it('should handle zero values correctly', () => {
|
||||
const target: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0
|
||||
}
|
||||
const source: Partial<LanguageModelUsage> = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(5)
|
||||
expect(target.outputTokens).toBe(10)
|
||||
expect(target.totalTokens).toBe(15)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ImageModelUsage', () => {
|
||||
it('should accumulate image model usage correctly', () => {
|
||||
const target: Partial<ImageModelUsage> = {
|
||||
inputTokens: 100,
|
||||
outputTokens: 50,
|
||||
totalTokens: 150
|
||||
}
|
||||
const source: Partial<ImageModelUsage> = {
|
||||
inputTokens: 50,
|
||||
outputTokens: 25,
|
||||
totalTokens: 75
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(150)
|
||||
expect(target.outputTokens).toBe(75)
|
||||
expect(target.totalTokens).toBe(225)
|
||||
})
|
||||
|
||||
it('should handle undefined values', () => {
|
||||
const target: Partial<ImageModelUsage> = { inputTokens: 100 }
|
||||
const source: Partial<ImageModelUsage> = {
|
||||
outputTokens: 50,
|
||||
totalTokens: 50
|
||||
}
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.inputTokens).toBe(100)
|
||||
expect(target.outputTokens).toBe(50)
|
||||
expect(target.totalTokens).toBe(50)
|
||||
})
|
||||
})
|
||||
|
||||
describe('EmbeddingModelUsage', () => {
|
||||
it('should accumulate embedding model usage correctly', () => {
|
||||
const target: Partial<EmbeddingModelUsage> = { tokens: 100 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.tokens).toBe(150)
|
||||
})
|
||||
|
||||
it('should handle zero to non-zero accumulation', () => {
|
||||
const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.tokens).toBe(50)
|
||||
})
|
||||
|
||||
it('should handle zero values', () => {
|
||||
const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 100 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(target.tokens).toBe(100)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Guard Validation', () => {
|
||||
it('should warn on type mismatch between LanguageModelUsage and EmbeddingModelUsage', () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn')
|
||||
const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 5 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Unable to accumulate usage'),
|
||||
expect.objectContaining({
|
||||
target,
|
||||
source
|
||||
})
|
||||
)
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should warn on type mismatch between ImageModelUsage and EmbeddingModelUsage', () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn')
|
||||
const target: Partial<ImageModelUsage> = { inputTokens: 100 }
|
||||
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
|
||||
|
||||
manager.accumulateUsage(target, source)
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Unable to accumulate usage'), expect.any(Object))
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildRecursiveParams', () => {
|
||||
it('should include textBuffer in assistant message when not empty', () => {
|
||||
const context = createMockContext()
|
||||
const textBuffer = 'test response'
|
||||
const toolResultsText = '<tool_use_result>...</tool_use_result>'
|
||||
const tools = {
|
||||
test_tool: createMockTool('test_tool')
|
||||
}
|
||||
|
||||
const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
|
||||
|
||||
expect(params.messages).toHaveLength(3)
|
||||
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
|
||||
expect(params.messages?.[1]).toEqual({ role: 'assistant', content: textBuffer })
|
||||
expect(params.messages?.[2]).toEqual({
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
})
|
||||
expect(params.tools).toBe(tools)
|
||||
})
|
||||
|
||||
it('should skip empty textBuffer in messages', () => {
|
||||
const context = createMockContext()
|
||||
const textBuffer = ''
|
||||
const toolResultsText = '<tool_use_result>...</tool_use_result>'
|
||||
const tools = {}
|
||||
|
||||
const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
|
||||
|
||||
// Should only have original user message and new user message with tool results
|
||||
expect(params.messages).toHaveLength(2)
|
||||
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
|
||||
expect(params.messages?.[1]).toEqual({
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
})
|
||||
|
||||
const assistantMessages = params.messages?.filter((m) => m.role === 'assistant')
|
||||
expect(assistantMessages).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should preserve all original messages', () => {
|
||||
const context = createMockContext({
|
||||
originalParams: {
|
||||
messages: [
|
||||
{ role: 'user', content: 'First message' },
|
||||
{ role: 'assistant', content: 'First response' },
|
||||
{ role: 'user', content: 'Second message' }
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
const params = manager.buildRecursiveParams(context, 'New response', 'Tool results', {})
|
||||
|
||||
expect(params.messages).toHaveLength(5)
|
||||
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'First message' })
|
||||
expect(params.messages?.[1]).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'First response'
|
||||
})
|
||||
expect(params.messages?.[2]).toEqual({ role: 'user', content: 'Second message' })
|
||||
expect(params.messages?.[3]).toEqual({ role: 'assistant', content: 'New response' })
|
||||
expect(params.messages?.[4]).toEqual({ role: 'user', content: 'Tool results' })
|
||||
})
|
||||
|
||||
it('should pass through tools parameter', () => {
|
||||
const context = createMockContext()
|
||||
const tools = {
|
||||
tool1: createMockTool('tool1'),
|
||||
tool2: createMockTool('tool2')
|
||||
}
|
||||
|
||||
const params = manager.buildRecursiveParams(context, 'response', 'results', tools)
|
||||
|
||||
expect(params.tools).toBe(tools)
|
||||
expect(Object.keys(params.tools!)).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('sendStepStartEvent', () => {
|
||||
it('should enqueue start-step event with correct structure', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
manager.sendStepStartEvent(controller)
|
||||
|
||||
expect(controller.enqueue).toHaveBeenCalledWith({
|
||||
type: 'start-step',
|
||||
request: {},
|
||||
warnings: []
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('sendStepFinishEvent', () => {
|
||||
it('should enqueue finish-step event with provided finishReason', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
},
|
||||
response: { id: 'test-response' },
|
||||
providerMetadata: { 'test-provider': {} }
|
||||
}
|
||||
|
||||
const context = createMockContext({
|
||||
accumulatedUsage: {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0
|
||||
}
|
||||
})
|
||||
|
||||
manager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
|
||||
|
||||
expect(controller.enqueue).toHaveBeenCalledWith({
|
||||
type: 'finish-step',
|
||||
finishReason: 'tool-calls',
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
})
|
||||
})
|
||||
|
||||
it('should accumulate usage when provided', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
}
|
||||
|
||||
const context = createMockContext({
|
||||
accumulatedUsage: {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
})
|
||||
|
||||
manager.sendStepFinishEvent(controller, chunk, context)
|
||||
|
||||
// Verify accumulation happened
|
||||
expect(context.accumulatedUsage.inputTokens).toBe(15)
|
||||
expect(context.accumulatedUsage.outputTokens).toBe(30)
|
||||
expect(context.accumulatedUsage.totalTokens).toBe(45)
|
||||
})
|
||||
|
||||
it('should handle missing usage gracefully', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {}
|
||||
const context = createMockContext({
|
||||
accumulatedUsage: {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15
|
||||
}
|
||||
})
|
||||
|
||||
expect(() => manager.sendStepFinishEvent(controller, chunk, context)).not.toThrow()
|
||||
|
||||
// Verify accumulation did not change
|
||||
expect(context.accumulatedUsage.inputTokens).toBe(5)
|
||||
expect(context.accumulatedUsage.outputTokens).toBe(10)
|
||||
expect(context.accumulatedUsage.totalTokens).toBe(15)
|
||||
})
|
||||
|
||||
it('should use default finishReason of "stop" when not provided', () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const chunk: FinishStepChunk = {}
|
||||
const context = createMockContext()
|
||||
|
||||
manager.sendStepFinishEvent(controller, chunk, context)
|
||||
|
||||
expect(controller.enqueue).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
finishReason: 'stop'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleRecursiveCall', () => {
|
||||
it('should reset hasExecutedToolsInCurrentStep flag', async () => {
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
|
||||
chunks: [
|
||||
{
|
||||
type: 'text-delta',
|
||||
id: 'test-id',
|
||||
text: 'test'
|
||||
} as TextStreamPart<EmptyToolSet>
|
||||
],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const context = createMockContext({
|
||||
hasExecutedToolsInCurrentStep: true,
|
||||
recursiveCall: vi.fn().mockResolvedValue({
|
||||
fullStream: mockStream
|
||||
})
|
||||
})
|
||||
|
||||
const params = { messages: [] }
|
||||
|
||||
await manager.handleRecursiveCall(controller, params, context)
|
||||
|
||||
expect(context.hasExecutedToolsInCurrentStep).toBe(false)
|
||||
expect(context.recursiveCall).toHaveBeenCalledWith(params)
|
||||
})
|
||||
|
||||
it('should pipe recursive stream to controller', async () => {
|
||||
const enqueuedChunks: TextStreamPart<EmptyToolSet>[] = []
|
||||
const controller = createMockStreamController()
|
||||
controller.enqueue.mockImplementation((chunk: TextStreamPart<EmptyToolSet>) => {
|
||||
enqueuedChunks.push(chunk)
|
||||
})
|
||||
|
||||
const mockChunks: TextStreamPart<EmptyToolSet>[] = [
|
||||
{ type: 'start' as const },
|
||||
{ type: 'start-step' as const, request: {}, warnings: [] },
|
||||
{ type: 'text-delta' as const, id: 'chunk-1', text: 'recursive' },
|
||||
{ type: 'text-delta' as const, id: 'chunk-2', text: ' response' },
|
||||
{
|
||||
type: 'finish-step' as const,
|
||||
finishReason: 'stop',
|
||||
rawFinishReason: 'stop',
|
||||
response: {
|
||||
id: 'test-response-id',
|
||||
timestamp: new Date(),
|
||||
modelId: 'test-model'
|
||||
},
|
||||
usage: {
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
inputTokenDetails: {
|
||||
noCacheTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
cacheWriteTokens: 0
|
||||
},
|
||||
outputTokenDetails: {
|
||||
textTokens: 0,
|
||||
reasoningTokens: 0
|
||||
}
|
||||
},
|
||||
providerMetadata: undefined
|
||||
},
|
||||
{
|
||||
type: 'finish' as const,
|
||||
finishReason: 'stop',
|
||||
rawFinishReason: 'stop',
|
||||
totalUsage: {
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
inputTokenDetails: {
|
||||
noCacheTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
cacheWriteTokens: 0
|
||||
},
|
||||
outputTokenDetails: {
|
||||
textTokens: 0,
|
||||
reasoningTokens: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
|
||||
chunks: mockChunks,
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const context = createMockContext({
|
||||
hasExecutedToolsInCurrentStep: true,
|
||||
recursiveCall: vi.fn().mockResolvedValue({
|
||||
fullStream: mockStream
|
||||
})
|
||||
})
|
||||
|
||||
await manager.handleRecursiveCall(controller, {}, context)
|
||||
|
||||
// Should skip 'start' type and stop at 'finish' type
|
||||
expect(enqueuedChunks).toHaveLength(4)
|
||||
expect(enqueuedChunks[0]).toEqual({ type: 'start-step', request: {}, warnings: [] })
|
||||
expect(enqueuedChunks[1]).toEqual({ type: 'text-delta', id: 'chunk-1', text: 'recursive' })
|
||||
expect(enqueuedChunks[2]).toEqual({ type: 'text-delta', id: 'chunk-2', text: ' response' })
|
||||
expect(enqueuedChunks[3]).toMatchObject({
|
||||
type: 'finish-step',
|
||||
finishReason: 'stop',
|
||||
rawFinishReason: 'stop',
|
||||
providerMetadata: undefined,
|
||||
usage: {
|
||||
totalTokens: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
inputTokenDetails: {
|
||||
noCacheTokens: 0,
|
||||
cacheReadTokens: 0,
|
||||
cacheWriteTokens: 0
|
||||
},
|
||||
outputTokenDetails: {
|
||||
textTokens: 0,
|
||||
reasoningTokens: 0
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should warn when no fullStream is found', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn')
|
||||
const controller = createMockStreamController()
|
||||
|
||||
const context = createMockContext({
|
||||
hasExecutedToolsInCurrentStep: true,
|
||||
recursiveCall: vi.fn().mockResolvedValue({
|
||||
// No fullStream property
|
||||
someOtherProperty: 'value'
|
||||
})
|
||||
})
|
||||
|
||||
await manager.handleRecursiveCall(controller, {}, context)
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP Prompt] No fullstream found'),
|
||||
expect.any(Object)
|
||||
)
|
||||
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,559 +0,0 @@
|
||||
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '@test-utils'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import { simulateReadableStream } from 'ai'
|
||||
import { convertReadableStreamToArray } from 'ai/test'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
|
||||
|
||||
describe('promptToolUsePlugin', () => {
|
||||
describe('Factory Function', () => {
|
||||
it('should return AiPlugin with correct name', () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
|
||||
expect(plugin.name).toBe('built-in:prompt-tool-use')
|
||||
expect(plugin.transformParams).toBeDefined()
|
||||
expect(plugin.transformStream).toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept empty configuration', () => {
|
||||
const plugin = createPromptToolUsePlugin({})
|
||||
|
||||
expect(plugin).toBeDefined()
|
||||
expect(plugin.name).toBe('built-in:prompt-tool-use')
|
||||
})
|
||||
|
||||
it('should accept custom buildSystemPrompt', () => {
|
||||
const customBuildSystemPrompt = vi.fn((userSystemPrompt: string) => userSystemPrompt)
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
buildSystemPrompt: customBuildSystemPrompt
|
||||
})
|
||||
|
||||
expect(plugin).toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept custom parseToolUse', () => {
|
||||
const customParseToolUse = vi.fn(() => ({ results: [], content: '' }))
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
parseToolUse: customParseToolUse
|
||||
})
|
||||
|
||||
expect(plugin).toBeDefined()
|
||||
})
|
||||
|
||||
it('should accept enabled flag', () => {
|
||||
const pluginDisabled = createPromptToolUsePlugin({ enabled: false })
|
||||
const pluginEnabled = createPromptToolUsePlugin({ enabled: true })
|
||||
|
||||
expect(pluginDisabled).toBeDefined()
|
||||
expect(pluginEnabled).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('transformParams', () => {
|
||||
it('should separate provider and prompt tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: createMockToolSet({
|
||||
provider_tool: 'provider',
|
||||
prompt_tool: 'function'
|
||||
})
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
// Provider tools should remain in tools
|
||||
expect(result.tools).toBeDefined()
|
||||
expect(result.tools).toHaveProperty('provider_tool')
|
||||
expect(result.tools).not.toHaveProperty('prompt_tool')
|
||||
|
||||
// Prompt tools should be moved to context.mcpTools
|
||||
expect(context.mcpTools).toBeDefined()
|
||||
expect(context.mcpTools).toHaveProperty('prompt_tool')
|
||||
expect(context.mcpTools).not.toHaveProperty('provider_tool')
|
||||
})
|
||||
|
||||
it('should handle only provider tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: createMockToolSet({
|
||||
provider_tool1: 'provider',
|
||||
provider_tool2: 'provider'
|
||||
})
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.tools).toEqual(params.tools)
|
||||
expect(context.mcpTools).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle only prompt tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: createMockToolSet({
|
||||
prompt_tool1: 'function',
|
||||
prompt_tool2: 'function'
|
||||
})
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.tools).toBeUndefined()
|
||||
expect(context.mcpTools).toEqual(params.tools)
|
||||
})
|
||||
|
||||
it('should build system prompt for prompt tools', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original system prompt',
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool', 'Test tool description')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.system).toBeDefined()
|
||||
expect(typeof result.system).toBe('string')
|
||||
expect(result.system).toContain('In this environment you have access to a set of tools')
|
||||
expect(result.system).toContain('test_tool')
|
||||
expect(result.system).toContain('Test tool description')
|
||||
expect(result.system).toContain('Original system prompt')
|
||||
})
|
||||
|
||||
it('should handle empty user system prompt', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result.system).toBeDefined()
|
||||
expect(result.system).toContain('In this environment you have access to a set of tools')
|
||||
})
|
||||
|
||||
it('should skip system prompt when disabled', async () => {
|
||||
const plugin = createPromptToolUsePlugin({ enabled: false })
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result).toEqual(params)
|
||||
expect(context.mcpTools).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should skip when no tools provided', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original'
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result).toEqual(params)
|
||||
})
|
||||
|
||||
it('should skip when tools is not an object', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: 'invalid' as any
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(result).toEqual(params)
|
||||
})
|
||||
|
||||
it('should use custom buildSystemPrompt when provided', async () => {
|
||||
const customBuildSystemPrompt = vi.fn((userSystemPrompt: string, tools: ToolSet) => {
|
||||
return `Custom prompt with ${Object.keys(tools).length} tools and user prompt: ${userSystemPrompt}`
|
||||
})
|
||||
|
||||
const plugin = createPromptToolUsePlugin({
|
||||
buildSystemPrompt: customBuildSystemPrompt
|
||||
})
|
||||
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'User prompt',
|
||||
tools: {
|
||||
tool1: createMockTool('tool1')
|
||||
}
|
||||
})
|
||||
|
||||
const result = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(customBuildSystemPrompt).toHaveBeenCalled()
|
||||
expect(result.system).toBe('Custom prompt with 1 tools and user prompt: User prompt')
|
||||
})
|
||||
|
||||
it('should save originalParams to context', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'Original',
|
||||
tools: {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
})
|
||||
|
||||
await Promise.resolve(plugin.transformParams!(params, context))
|
||||
|
||||
expect(context.originalParams).toBeDefined()
|
||||
expect(context.originalParams.system).toBeDefined()
|
||||
})
|
||||
|
||||
it('should NOT rebuild system prompt on recursive call', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
const params = createMockStreamParams({
|
||||
system: 'User system prompt',
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool', 'A test tool')
|
||||
}
|
||||
})
|
||||
|
||||
// First call: build the system prompt with tools
|
||||
const firstResult = await Promise.resolve(plugin.transformParams!(params, context))
|
||||
const firstSystemPrompt = firstResult.system as string
|
||||
|
||||
// Verify first call includes tool definitions
|
||||
expect(firstSystemPrompt).toContain('test_tool')
|
||||
|
||||
// Simulate recursive call: isRecursiveCall is true
|
||||
context.isRecursiveCall = true
|
||||
|
||||
const recursiveParams = createMockStreamParams({
|
||||
system: firstSystemPrompt,
|
||||
tools: {
|
||||
test_tool: createMockTool('test_tool', 'A test tool')
|
||||
}
|
||||
})
|
||||
|
||||
const recursiveResult = await Promise.resolve(plugin.transformParams!(recursiveParams, context))
|
||||
|
||||
// System prompt should NOT be rebuilt - it should remain the same
|
||||
expect(recursiveResult.system).toBe(firstSystemPrompt)
|
||||
|
||||
// Count occurrences of tool definition to ensure no duplication
|
||||
const toolOccurrences = (recursiveResult.system as string).split('test_tool').length - 1
|
||||
const firstToolOccurrences = firstSystemPrompt.split('test_tool').length - 1
|
||||
expect(toolOccurrences).toBe(firstToolOccurrences)
|
||||
})
|
||||
})
|
||||
|
||||
describe('transformStream', () => {
|
||||
it('should return identity transform when disabled', async () => {
|
||||
const plugin = createPromptToolUsePlugin({ enabled: false })
|
||||
const context = createMockContext()
|
||||
|
||||
const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
|
||||
{ type: 'text-delta', text: 'Hello' },
|
||||
{ type: 'text-delta', text: ' World' }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
expect(result).toEqual(inputChunks)
|
||||
})
|
||||
|
||||
it('should return identity transform when no mcpTools in context', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
// Don't set context.mcpTools
|
||||
|
||||
const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
|
||||
{ type: 'text-delta', text: 'Hello' },
|
||||
{ type: 'text-delta', text: ' World' }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
expect(result).toEqual(inputChunks)
|
||||
})
|
||||
|
||||
it('should initialize accumulatedUsage in context', () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
plugin.transformStream!(createMockStreamParams(), context)()
|
||||
|
||||
expect(context.accumulatedUsage).toBeDefined()
|
||||
expect(context.accumulatedUsage).toEqual({
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
})
|
||||
})
|
||||
|
||||
it('should filter tool tags from text-delta chunks', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{ type: 'text-start' as const },
|
||||
{ type: 'text-delta' as const, text: 'Before ' },
|
||||
{ type: 'text-delta' as const, text: '<tool_use>' },
|
||||
{ type: 'text-delta' as const, text: '<name>test</name>' },
|
||||
{ type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
|
||||
{ type: 'text-delta' as const, text: '</tool_use>' },
|
||||
{ type: 'text-delta' as const, text: ' After' },
|
||||
{ type: 'text-end' as const }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Extract text from text-delta chunks
|
||||
const textChunks = result.filter((chunk) => chunk.type === 'text-delta')
|
||||
const fullText = textChunks.map((chunk) => 'text' in chunk && chunk.text).join('')
|
||||
|
||||
// Tool tags should be filtered out
|
||||
expect(fullText).not.toContain('<tool_use>')
|
||||
expect(fullText).not.toContain('</tool_use>')
|
||||
expect(fullText).toContain('Before')
|
||||
expect(fullText).toContain('After')
|
||||
})
|
||||
|
||||
it('should hold text-start until non-tag content appears', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
// Only tool tags, no actual content
|
||||
const inputChunks = [
|
||||
{ type: 'text-start' as const },
|
||||
{ type: 'text-delta' as const, text: '<tool_use>' },
|
||||
{ type: 'text-delta' as const, text: '<name>test</name>' },
|
||||
{ type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
|
||||
{ type: 'text-delta' as const, text: '</tool_use>' },
|
||||
{ type: 'text-end' as const }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Should not have text-start or text-end since all content was tool tags
|
||||
expect(result.some((chunk) => chunk.type === 'text-start')).toBe(false)
|
||||
expect(result.some((chunk) => chunk.type === 'text-end')).toBe(false)
|
||||
})
|
||||
|
||||
it('should send text-start when non-tag content appears', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{ type: 'text-start' as const },
|
||||
{ type: 'text-delta' as const, text: 'Actual content' },
|
||||
{ type: 'text-end' as const }
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Should have text-start, text-delta, and text-end
|
||||
expect(result.some((chunk) => chunk.type === 'text-start')).toBe(true)
|
||||
expect(result.some((chunk) => chunk.type === 'text-delta')).toBe(true)
|
||||
expect(result.some((chunk) => chunk.type === 'text-end')).toBe(true)
|
||||
})
|
||||
|
||||
it('should pass through non-text events', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const stepStartEvent = { type: 'start-step' as const, request: {}, warnings: [] }
|
||||
|
||||
const inputChunks = [stepStartEvent]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
expect(result[0]).toEqual(stepStartEvent)
|
||||
})
|
||||
|
||||
it('should accumulate usage from finish-step events', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{
|
||||
type: 'finish-step' as const,
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
inputTokens: 10,
|
||||
outputTokens: 20,
|
||||
totalTokens: 30
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
// Verify usage was accumulated
|
||||
expect(context.accumulatedUsage).toBeDefined()
|
||||
expect(context.accumulatedUsage!.inputTokens).toBe(10)
|
||||
expect(context.accumulatedUsage!.outputTokens).toBe(20)
|
||||
expect(context.accumulatedUsage!.totalTokens).toBe(30)
|
||||
})
|
||||
|
||||
it('should include accumulated usage in finish event', async () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
const context = createMockContext()
|
||||
context.mcpTools = {
|
||||
test: createMockTool('test')
|
||||
}
|
||||
|
||||
// Pre-populate accumulated usage
|
||||
context.accumulatedUsage = {
|
||||
inputTokens: 5,
|
||||
outputTokens: 10,
|
||||
totalTokens: 15,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
}
|
||||
|
||||
const inputChunks = [
|
||||
{
|
||||
type: 'finish' as const,
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
inputTokens: 100,
|
||||
outputTokens: 200,
|
||||
totalTokens: 300
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
|
||||
chunks: inputChunks as unknown as TextStreamPart<ToolSet>[],
|
||||
initialDelayInMs: 0,
|
||||
chunkDelayInMs: 0
|
||||
})
|
||||
|
||||
const transform = plugin.transformStream!(createMockStreamParams(), context)()
|
||||
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
|
||||
|
||||
const finishEvent = result.find((chunk) => chunk.type === 'finish')
|
||||
expect(finishEvent).toBeDefined()
|
||||
if (finishEvent && 'totalUsage' in finishEvent) {
|
||||
expect(finishEvent.totalUsage).toEqual(context.accumulatedUsage)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should have correct generic parameters for StreamTextParams and StreamTextResult', () => {
|
||||
const plugin = createPromptToolUsePlugin()
|
||||
|
||||
// Type assertion to verify the plugin has the correct type
|
||||
type PluginType = typeof plugin
|
||||
const typeTest: PluginType = plugin
|
||||
|
||||
expect(typeTest.name).toBe('built-in:prompt-tool-use')
|
||||
})
|
||||
})
|
||||
|
||||
describe('DEFAULT_SYSTEM_PROMPT', () => {
|
||||
it('should contain required sections', () => {
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Formatting')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Rules')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('Response rules')
|
||||
})
|
||||
|
||||
it('should have placeholders for dynamic content', () => {
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ TOOLS_INFO }}')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ USER_SYSTEM_PROMPT }}')
|
||||
})
|
||||
|
||||
it('should contain XML tag examples', () => {
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('<tool_use>')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('</tool_use>')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('<name>')
|
||||
expect(DEFAULT_SYSTEM_PROMPT).toContain('<arguments>')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,487 +0,0 @@
|
||||
/**
|
||||
* 内置插件:MCP Prompt 模式
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiPlugin, StreamTextParams, StreamTextResult } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { type TagConfig, TagExtractor } from './tagExtraction'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
import type { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 工具使用标签配置
|
||||
*/
|
||||
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
openingTag: '<tool_use>',
|
||||
closingTag: '</tool_use>',
|
||||
separator: '\n'
|
||||
}
|
||||
|
||||
export const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
|
||||
You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
## Tool Use Formatting
|
||||
|
||||
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
||||
|
||||
<tool_use>
|
||||
<name>{tool_name}</name>
|
||||
<arguments>{json_arguments}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. IMPORTANT: When writing JSON inside the <arguments> tag, any double quotes inside string values must be escaped with a backslash ("). For example:
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{ "query": "browser,fetch" }</arguments>
|
||||
</tool_use>
|
||||
|
||||
<tool_use>
|
||||
<name>exec</name>
|
||||
<arguments>{ "code": "const page = await CherryBrowser_fetch({ url: \\"https://example.com\\" })\nreturn page" }</arguments>
|
||||
</tool_use>
|
||||
|
||||
|
||||
The user will respond with the result of the tool use, which should be formatted as follows:
|
||||
|
||||
<tool_use_result>
|
||||
<name>{tool_name}</name>
|
||||
<result>{result}</result>
|
||||
</tool_use_result>
|
||||
|
||||
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
|
||||
For example, if the result of the tool use is an image file, you can use it in the next action like this:
|
||||
|
||||
<tool_use>
|
||||
<name>image_transformer</name>
|
||||
<arguments>{"image": "image_1.jpg"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
||||
|
||||
## Tool Use Rules
|
||||
Here are the rules you should always follow to solve your task:
|
||||
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
|
||||
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
||||
3. If no tool call is needed, just answer the question directly.
|
||||
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
|
||||
|
||||
{{ TOOLS_INFO }}
|
||||
|
||||
## Response rules
|
||||
|
||||
Respond in the language of the user's query, unless the user instructions specify additional requirements for the language to be used.
|
||||
|
||||
# User Instructions
|
||||
{{ USER_SYSTEM_PROMPT }}
|
||||
`
|
||||
|
||||
/**
|
||||
* 默认工具使用示例(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_TOOL_USE_EXAMPLES = `
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
User: Generate an image of the oldest person in this document.
|
||||
|
||||
A: I can use the document_qa tool to find out who the oldest person is in the document.
|
||||
<tool_use>
|
||||
<name>document_qa</name>
|
||||
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>document_qa</name>
|
||||
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the image_generator tool to create a portrait of John Doe.
|
||||
<tool_use>
|
||||
<name>image_generator</name>
|
||||
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>image_generator</name>
|
||||
<result>image.png</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: the image is generated as image.png
|
||||
|
||||
---
|
||||
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||
|
||||
A: I can use the python_interpreter tool to calculate the result of the operation.
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>python_interpreter</name>
|
||||
<result>1302.678</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: The result of the operation is 1302.678.
|
||||
|
||||
---
|
||||
User: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||
|
||||
A: I can use the search tool to find the population of Guangzhou.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Guangzhou"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the search tool to find the population of Shanghai.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Shanghai"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>26 million (2019)</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
|
||||
|
||||
/**
|
||||
* 构建可用工具部分(提取自 Cherry Studio)
|
||||
*/
|
||||
function buildAvailableTools(tools: ToolSet): string | null {
|
||||
const availableTools = Object.keys(tools)
|
||||
if (availableTools.length === 0) return null
|
||||
const result = availableTools
|
||||
.map((toolName: string) => {
|
||||
const tool = tools[toolName]
|
||||
return `
|
||||
<tool>
|
||||
<name>${toolName}</name>
|
||||
<description>${tool.description || ''}</description>
|
||||
<arguments>
|
||||
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
|
||||
</arguments>
|
||||
</tool>
|
||||
`
|
||||
})
|
||||
.join('\n')
|
||||
return `<tools>
|
||||
${result}
|
||||
</tools>`
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认的系统提示符构建函数(提取自 Cherry Studio)
|
||||
*/
|
||||
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet, mcpMode?: string): string {
|
||||
const availableTools = buildAvailableTools(tools)
|
||||
if (availableTools === null) return userSystemPrompt
|
||||
|
||||
if (mcpMode == 'auto') {
|
||||
return DEFAULT_SYSTEM_PROMPT.replace('{{ TOOLS_INFO }}', '').replace(
|
||||
'{{ USER_SYSTEM_PROMPT }}',
|
||||
userSystemPrompt || ''
|
||||
)
|
||||
}
|
||||
const toolsInfo = `
|
||||
## Tool Use Examples
|
||||
{{ TOOL_USE_EXAMPLES }}
|
||||
|
||||
## Tool Use Available Tools
|
||||
Above example were using notional tools that might not exist for you. You only have access to these tools:
|
||||
{{ AVAILABLE_TOOLS }}`
|
||||
.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
|
||||
|
||||
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOLS_INFO }}', toolsInfo).replace(
|
||||
'{{ USER_SYSTEM_PROMPT }}',
|
||||
userSystemPrompt || ''
|
||||
)
|
||||
|
||||
return fullPrompt
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认工具解析函数(提取自 Cherry Studio)
|
||||
* 解析 XML 格式的工具调用
|
||||
*/
|
||||
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
|
||||
if (!content || !tools || Object.keys(tools).length === 0) {
|
||||
return { results: [], content: content }
|
||||
}
|
||||
|
||||
// 支持两种格式:
|
||||
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
|
||||
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||
|
||||
let contentToProcess = content
|
||||
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||
if (!content.includes('<tool_use>')) {
|
||||
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||
}
|
||||
|
||||
const toolUsePattern =
|
||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||
const results: ToolUseResult[] = []
|
||||
let match
|
||||
let idx = 0
|
||||
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
const fullMatch = match[0]
|
||||
let toolName = match[2].trim()
|
||||
switch (toolName.toLowerCase()) {
|
||||
case 'search':
|
||||
toolName = 'mcp__CherryHub__search'
|
||||
break
|
||||
case 'exec':
|
||||
toolName = 'mcp__CherryHub__exec'
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
const toolArgs = match[4].trim()
|
||||
|
||||
// Try to parse the arguments as JSON
|
||||
let parsedArgs
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolArgs)
|
||||
} catch (error) {
|
||||
// If parsing fails, use the string as is
|
||||
parsedArgs = toolArgs
|
||||
}
|
||||
|
||||
// Find the corresponding tool
|
||||
const tool = tools[toolName]
|
||||
if (!tool) {
|
||||
console.warn(`Tool "${toolName}" not found in available tools`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to results array
|
||||
results.push({
|
||||
id: `${toolName}-${idx++}`, // Unique ID for each tool use
|
||||
toolName: toolName,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
})
|
||||
contentToProcess = contentToProcess.replace(fullMatch, '')
|
||||
}
|
||||
return { results, content: contentToProcess }
|
||||
}
|
||||
|
||||
export const createPromptToolUsePlugin = (
|
||||
config: PromptToolUseConfig = {}
|
||||
): AiPlugin<StreamTextParams, StreamTextResult> => {
|
||||
const {
|
||||
enabled = true,
|
||||
buildSystemPrompt = defaultBuildSystemPrompt,
|
||||
parseToolUse = defaultParseToolUse,
|
||||
mcpMode
|
||||
} = config
|
||||
|
||||
return definePlugin<StreamTextParams, StreamTextResult>({
|
||||
name: 'built-in:prompt-tool-use',
|
||||
transformParams: (params, context) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
}
|
||||
|
||||
// 分离 provider 和其他类型的工具
|
||||
const providerDefinedTools: ToolSet = {}
|
||||
const promptTools: ToolSet = {}
|
||||
|
||||
for (const [toolName, tool] of Object.entries(params.tools)) {
|
||||
if (tool.type === 'provider') {
|
||||
// provider 类型的工具保留在 tools 参数中
|
||||
providerDefinedTools[toolName] = tool
|
||||
} else {
|
||||
// 其他工具转换为 prompt 模式
|
||||
promptTools[toolName] = tool
|
||||
}
|
||||
}
|
||||
|
||||
// 只有当有非 provider 工具时才保存到 context
|
||||
if (Object.keys(promptTools).length > 0) {
|
||||
context.mcpTools = promptTools
|
||||
}
|
||||
|
||||
// 递归调用时,不重新构建 system prompt,避免重复追加工具定义
|
||||
if (context.isRecursiveCall) {
|
||||
const transformedParams = {
|
||||
...params,
|
||||
tools: Object.keys(providerDefinedTools).length > 0 ? providerDefinedTools : undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
return transformedParams
|
||||
}
|
||||
|
||||
// 构建系统提示符(只包含非 provider 工具)
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, promptTools, mcpMode)
|
||||
|
||||
// 保留 provide tools,移除其他 tools
|
||||
const transformedParams = {
|
||||
...params,
|
||||
...(systemPrompt ? { system: systemPrompt } : {}),
|
||||
tools: Object.keys(providerDefinedTools).length > 0 ? providerDefinedTools : undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
return transformedParams
|
||||
},
|
||||
transformStream: (_, context) => () => {
|
||||
let textBuffer = ''
|
||||
// let stepId = ''
|
||||
|
||||
// 如果没有需要 prompt 模式处理的工具,直接返回原始流
|
||||
if (!context.mcpTools) {
|
||||
return new TransformStream()
|
||||
}
|
||||
|
||||
// 初始化 usage 累加器和工具执行状态
|
||||
if (!context.accumulatedUsage) {
|
||||
context.accumulatedUsage = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
}
|
||||
}
|
||||
if (context.hasExecutedToolsInCurrentStep === undefined) {
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
}
|
||||
|
||||
// 创建工具执行器、流事件管理器和标签提取器
|
||||
const toolExecutor = new ToolExecutor()
|
||||
const streamEventManager = new StreamEventManager()
|
||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
|
||||
// 用于hold text-start事件,直到确认有非工具标签内容
|
||||
let pendingTextStart: TextStreamPart<TOOLS> | null = null
|
||||
let hasStartedText = false
|
||||
|
||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||
async transform(
|
||||
chunk: TextStreamPart<TOOLS>,
|
||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||
) {
|
||||
// Hold住text-start事件,直到确认有非工具标签内容
|
||||
if ((chunk as any).type === 'text-start') {
|
||||
pendingTextStart = chunk
|
||||
return
|
||||
}
|
||||
|
||||
// text-delta阶段:收集文本内容并过滤工具标签
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
// stepId = chunk.id || ''
|
||||
|
||||
// 使用TagExtractor过滤工具标签,只传递非标签内容到UI层
|
||||
const extractionResults = tagExtractor.processText(chunk.text || '')
|
||||
|
||||
for (const result of extractionResults) {
|
||||
// 只传递非标签内容到UI层
|
||||
if (!result.isTagContent && result.content) {
|
||||
// 如果还没有发送text-start且有pending的text-start,先发送它
|
||||
if (!hasStartedText && pendingTextStart) {
|
||||
controller.enqueue(pendingTextStart)
|
||||
hasStartedText = true
|
||||
pendingTextStart = null
|
||||
}
|
||||
|
||||
const filteredChunk = {
|
||||
...chunk,
|
||||
text: result.content
|
||||
}
|
||||
controller.enqueue(filteredChunk)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
// 只有当已经发送了text-start时才发送text-end
|
||||
if (hasStartedText) {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'finish-step') {
|
||||
// 统一在finish-step阶段检查并执行工具调用
|
||||
const tools = context.mcpTools
|
||||
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
|
||||
// 解析完整的textBuffer来检测工具调用
|
||||
const { results: parsedTools } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
|
||||
if (validToolUses.length > 0) {
|
||||
context.hasExecutedToolsInCurrentStep = true
|
||||
|
||||
// 执行工具调用(不需要手动发送 start-step,外部流已经处理)
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件,使用 tool-calls 作为 finishReason
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
|
||||
|
||||
// 处理递归调用
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有执行工具调用,累加 usage 后透传 finish-step 事件
|
||||
if (chunk.usage && context.accumulatedUsage) {
|
||||
streamEventManager.accumulateUsage(context.accumulatedUsage, chunk.usage)
|
||||
}
|
||||
controller.enqueue(chunk)
|
||||
|
||||
// 清理状态
|
||||
textBuffer = ''
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 finish 类型,使用累加后的 totalUsage
|
||||
if (chunk.type === 'finish') {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
totalUsage: context.accumulatedUsage
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递(不包括text-start,已在上面处理)
|
||||
if ((chunk as any).type !== 'text-start') {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
},
|
||||
|
||||
flush() {
|
||||
// 清理pending状态
|
||||
pendingTextStart = null
|
||||
hasStartedText = false
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
|
||||
|
||||
/**
|
||||
* Returns the index of the start of the searchedText in the text, or null if it
|
||||
* is not found.
|
||||
*/
|
||||
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
|
||||
// Return null immediately if searchedText is empty.
|
||||
if (searchedText.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the searchedText exists as a direct substring of text.
|
||||
const directIndex = text.indexOf(searchedText)
|
||||
if (directIndex !== -1) {
|
||||
return directIndex
|
||||
}
|
||||
|
||||
// Otherwise, look for the largest suffix of "text" that matches
|
||||
// a prefix of "searchedText". We go from the end of text inward.
|
||||
for (let i = text.length - 1; i >= 0; i--) {
|
||||
const suffix = text.substring(i)
|
||||
if (searchedText.startsWith(suffix)) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export interface TagConfig {
|
||||
openingTag: string
|
||||
closingTag: string
|
||||
separator?: string
|
||||
}
|
||||
|
||||
export interface TagExtractionState {
|
||||
textBuffer: string
|
||||
isInsideTag: boolean
|
||||
isFirstTag: boolean
|
||||
isFirstText: boolean
|
||||
afterSwitch: boolean
|
||||
accumulatedTagContent: string
|
||||
hasTagContent: boolean
|
||||
}
|
||||
|
||||
export interface TagExtractionResult {
|
||||
content: string
|
||||
isTagContent: boolean
|
||||
complete: boolean
|
||||
tagContentExtracted?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用标签提取处理器
|
||||
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
|
||||
*/
|
||||
export class TagExtractor {
|
||||
private config: TagConfig
|
||||
private state: TagExtractionState
|
||||
|
||||
constructor(config: TagConfig) {
|
||||
this.config = config
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本块,返回处理结果
|
||||
*/
|
||||
processText(newText: string): TagExtractionResult[] {
|
||||
this.state.textBuffer += newText
|
||||
const results: TagExtractionResult[] = []
|
||||
|
||||
// 处理标签提取逻辑
|
||||
while (true) {
|
||||
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
|
||||
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
|
||||
|
||||
if (startIndex == null) {
|
||||
const content = this.state.textBuffer
|
||||
if (content.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(content),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(content)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
this.state.textBuffer = ''
|
||||
break
|
||||
}
|
||||
|
||||
// 处理标签前的内容
|
||||
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
|
||||
if (contentBeforeTag.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(contentBeforeTag),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
|
||||
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
|
||||
|
||||
if (foundFullMatch) {
|
||||
// 如果找到完整的标签
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
|
||||
|
||||
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
|
||||
if (this.state.isInsideTag && this.state.hasTagContent) {
|
||||
results.push({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
})
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
}
|
||||
|
||||
this.state.isInsideTag = !this.state.isInsideTag
|
||||
this.state.afterSwitch = true
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.isFirstTag = false
|
||||
} else {
|
||||
this.state.isFirstText = false
|
||||
}
|
||||
} else {
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* 完成处理,返回任何剩余的标签内容
|
||||
*/
|
||||
finalize(): TagExtractionResult | null {
|
||||
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
|
||||
const result = {
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
}
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
return result
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private addPrefix(text: string): string {
|
||||
const needsPrefix =
|
||||
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
|
||||
|
||||
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
|
||||
this.state.afterSwitch = false
|
||||
return prefix + text
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置状态
|
||||
*/
|
||||
reset(): void {
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
import type { ToolSet } from 'ai'
|
||||
|
||||
import type { AiRequestContext } from '../..'
|
||||
|
||||
/**
|
||||
* 解析结果类型
|
||||
* 表示从AI响应中解析出的工具使用意图
|
||||
*/
|
||||
export interface ToolUseResult {
|
||||
id: string
|
||||
toolName: string
|
||||
arguments: any
|
||||
status: 'pending' | 'invoking' | 'done' | 'error'
|
||||
}
|
||||
|
||||
export interface BaseToolUsePluginConfig {
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
|
||||
mcpMode?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
export interface ToolUseRequestContext extends AiRequestContext {
|
||||
mcpTools: ToolSet
|
||||
}
|
||||
@@ -21,7 +21,6 @@ export interface AiRequestMetadata {
|
||||
enableReasoning?: boolean
|
||||
enableWebSearch?: boolean
|
||||
enableGenerateImage?: boolean
|
||||
isPromptToolUse?: boolean
|
||||
isSupportedToolUse?: boolean
|
||||
// 自定义元数据,使用 JSONValue 确保类型安全
|
||||
custom?: JSONObject
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { ProviderV3 } from '@ai-sdk/provider'
|
||||
import { LRUCache } from 'lru-cache'
|
||||
import QuickLRU from 'quick-lru'
|
||||
|
||||
import { deepMergeObjects } from '../../utils'
|
||||
import type { ProviderVariant, ToolFactoryMap } from '../types'
|
||||
@@ -116,7 +116,7 @@ export class ProviderExtension<
|
||||
>
|
||||
> {
|
||||
/** Provider 实例缓存 - 按 settings hash 存储,LRU 自动清理 */
|
||||
private instances: LRUCache<string, TProvider>
|
||||
private instances: QuickLRU<string, TProvider>
|
||||
|
||||
/** In-flight promise map - 防止并发创建相同 settings 的 provider */
|
||||
private pendingCreations: Map<string, Promise<TProvider>> = new Map()
|
||||
@@ -126,9 +126,8 @@ export class ProviderExtension<
|
||||
throw new Error('ProviderExtension: name is required')
|
||||
}
|
||||
|
||||
this.instances = new LRUCache<string, TProvider>({
|
||||
max: 10,
|
||||
updateAgeOnGet: true
|
||||
this.instances = new QuickLRU<string, TProvider>({
|
||||
maxSize: 10
|
||||
})
|
||||
}
|
||||
|
||||
@@ -161,10 +160,13 @@ export class ProviderExtension<
|
||||
return 'default'
|
||||
}
|
||||
|
||||
const seen = new WeakSet()
|
||||
const stableStringify = (obj: any): string => {
|
||||
if (obj === null || obj === undefined) return 'null'
|
||||
if (typeof obj === 'function') return '"[function]"'
|
||||
if (typeof obj !== 'object') return JSON.stringify(obj)
|
||||
if (seen.has(obj)) return '"[circular]"'
|
||||
seen.add(obj)
|
||||
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
|
||||
|
||||
const keys = Object.keys(obj).sort()
|
||||
@@ -324,10 +326,10 @@ export class ProviderExtension<
|
||||
* 获取已缓存的 provider 实例(如果存在)
|
||||
*/
|
||||
getCachedProvider(): TProvider | undefined {
|
||||
for (const [key, value] of this.instances.entries()) {
|
||||
for (const [key, value] of this.instances) {
|
||||
if (!key.includes(':')) return value
|
||||
}
|
||||
for (const [, value] of this.instances.entries()) {
|
||||
for (const [, value] of this.instances) {
|
||||
return value
|
||||
}
|
||||
return undefined
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* Runtime 层类型定义
|
||||
*/
|
||||
import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
|
||||
import type { embedMany, generateImage, generateText, streamText } from 'ai'
|
||||
import type { embedMany, Experimental_DownloadFunction, generateImage, generateText, streamText } from 'ai'
|
||||
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
|
||||
@@ -31,6 +31,7 @@ export interface RuntimeConfig<
|
||||
|
||||
export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'model'> & {
|
||||
model: string | ImageModelV3
|
||||
experimental_download?: Experimental_DownloadFunction
|
||||
}
|
||||
export type generateImageResult = Awaited<ReturnType<typeof generateImage>>
|
||||
export type generateTextParams = Parameters<typeof generateText>[0]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { W3CTraceContextPropagator } from '@opentelemetry/core'
|
||||
import { OTLPTraceExporter } from '@opentelemetry/exporter-trace-otlp-http'
|
||||
import type { SpanProcessor } from '@opentelemetry/sdk-trace-base'
|
||||
import { BatchSpanProcessor, ConsoleSpanExporter } from '@opentelemetry/sdk-trace-base'
|
||||
import { ConsoleSpanExporter, SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base'
|
||||
import { WebTracerProvider } from '@opentelemetry/sdk-trace-web'
|
||||
|
||||
import type { TraceConfig } from '../trace-core/types/config'
|
||||
@@ -21,7 +20,10 @@ export class WebTracer {
|
||||
defaultConfig.headers = config.headers || defaultConfig.headers
|
||||
defaultConfig.defaultTracerName = config.defaultTracerName || defaultConfig.defaultTracerName
|
||||
}
|
||||
this.processor = spanProcessor || new BatchSpanProcessor(this.getExporter())
|
||||
// Callers are expected to pass a processor. The dev-only fallback logs
|
||||
// spans to the console so that a misconfigured caller doesn't silently
|
||||
// lose data when callers forget to inject a processor.
|
||||
this.processor = spanProcessor || new SimpleSpanProcessor(new ConsoleSpanExporter())
|
||||
this.provider = new WebTracerProvider({
|
||||
spanProcessors: [this.processor]
|
||||
})
|
||||
@@ -30,16 +32,6 @@ export class WebTracer {
|
||||
contextManager: contextManager
|
||||
})
|
||||
}
|
||||
|
||||
private static getExporter() {
|
||||
if (defaultConfig.endpoint) {
|
||||
return new OTLPTraceExporter({
|
||||
url: `${defaultConfig.endpoint}/v1/traces`,
|
||||
headers: defaultConfig.headers
|
||||
})
|
||||
}
|
||||
return new ConsoleSpanExporter()
|
||||
}
|
||||
}
|
||||
|
||||
export const startContext = contextManager.startContextForTopic.bind(contextManager)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,8 @@
|
||||
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { buildRuntimeEndpointConfigs, lookupRegistryModel } from '../registry-utils'
|
||||
import { buildRuntimeEndpointConfigs, inferAdapterFamily, lookupRegistryModel } from '../registry-utils'
|
||||
import { ENDPOINT_TYPE } from '../schemas/enums'
|
||||
import type { ModelConfig } from '../schemas/model'
|
||||
import type { RegistryEndpointConfig } from '../schemas/provider'
|
||||
import type { ProviderModelOverride } from '../schemas/provider-models'
|
||||
@@ -187,4 +188,58 @@ describe('buildRuntimeEndpointConfigs', () => {
|
||||
} as Record<string, RegistryEndpointConfig>)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('copies adapterFamily through to runtime config', () => {
|
||||
const result = buildRuntimeEndpointConfigs({
|
||||
'openai-chat-completions': { baseUrl: 'https://x', adapterFamily: 'openai-compatible' },
|
||||
'anthropic-messages': { baseUrl: 'https://y', adapterFamily: 'anthropic' }
|
||||
} as Record<string, RegistryEndpointConfig>)
|
||||
|
||||
expect(result!['openai-chat-completions'].adapterFamily).toBe('openai-compatible')
|
||||
expect(result!['anthropic-messages'].adapterFamily).toBe('anthropic')
|
||||
})
|
||||
|
||||
it('adapterFamily alone is enough to retain an endpoint config', () => {
|
||||
const result = buildRuntimeEndpointConfigs({
|
||||
'openai-chat-completions': { adapterFamily: 'openai-compatible' }
|
||||
} as Record<string, RegistryEndpointConfig>)
|
||||
|
||||
expect(result!['openai-chat-completions'].adapterFamily).toBe('openai-compatible')
|
||||
expect(result!['openai-chat-completions'].baseUrl).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// inferAdapterFamily
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe('inferAdapterFamily', () => {
|
||||
it('catalog adapterFamily wins over endpoint default', () => {
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.ANTHROPIC_MESSAGES, { adapterFamily: 'aihubmix' })).toBe('aihubmix')
|
||||
})
|
||||
|
||||
it('falls back to endpoint default when catalog has no adapterFamily', () => {
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.ANTHROPIC_MESSAGES, {})).toBe('anthropic')
|
||||
})
|
||||
|
||||
it('falls back to endpoint default when catalog is absent', () => {
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.ANTHROPIC_MESSAGES)).toBe('anthropic')
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.GOOGLE_GENERATE_CONTENT)).toBe('google')
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.OLLAMA_CHAT)).toBe('ollama')
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.OLLAMA_GENERATE)).toBe('ollama')
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.JINA_RERANK)).toBe('jina-rerank')
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_RESPONSES)).toBe('openai')
|
||||
})
|
||||
|
||||
it('falls back to openai-compatible for endpoints with no specific default', () => {
|
||||
// openai-chat-completions is intentionally generic — many vendors speak it
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_CHAT_COMPLETIONS)).toBe('openai-compatible')
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_IMAGE_GENERATION)).toBe('openai-compatible')
|
||||
})
|
||||
|
||||
it('accepts both RegistryEndpointConfig and RuntimeEndpointConfig shapes', () => {
|
||||
// Both schemas have adapterFamily — the function only needs to peek that
|
||||
// one field so the input type is structural.
|
||||
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_CHAT_COMPLETIONS, { adapterFamily: 'groq' })).toBe('groq')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -63,7 +63,15 @@ export { normalizeModelId } from './utils/normalize'
|
||||
|
||||
// Pure lookup and transformation utilities (no fs dependency)
|
||||
export type { ModelLookupResult, RuntimeEndpointConfig } from './registry-utils'
|
||||
export { buildRuntimeEndpointConfigs, lookupRegistryModel, lookupRegistryProvider } from './registry-utils'
|
||||
export {
|
||||
buildRuntimeEndpointConfigs,
|
||||
inferAdapterFamily,
|
||||
lookupRegistryModel,
|
||||
lookupRegistryProvider
|
||||
} from './registry-utils'
|
||||
|
||||
// Shared vendor identity regex, used by shared model helpers.
|
||||
export { VENDOR_PATTERNS } from './patterns/vendor-patterns'
|
||||
// Shared vendor identity regex — consumed by @shared capability inference
|
||||
// and @cherrystudio/ui icon routing. Single source of truth for "which
|
||||
// vendor does this raw model ID belong to".
|
||||
export type { VendorKey } from './patterns/vendor-patterns'
|
||||
export { isVendor, matchVendor, VENDOR_PATTERNS } from './patterns/vendor-patterns'
|
||||
|
||||
@@ -2,13 +2,17 @@
|
||||
* Vendor identity regex patterns — the single source of truth for
|
||||
* "which vendor does this raw model ID belong to".
|
||||
*
|
||||
* Shared by:
|
||||
* Shared across three call sites:
|
||||
* - `@shared/utils/model` — vendor check functions (`isAnthropicModel`
|
||||
* etc.) and capability inference (e.g. deciding which IDs to mark
|
||||
* `REASONING` in the schema).
|
||||
* - `@cherrystudio/ui` icon registry — vendor-level icon routing for
|
||||
* models whose ID doesn't have a dedicated SKU icon.
|
||||
* - Future callers doing vendor dispatch.
|
||||
*
|
||||
* Keeping these regex in the registry layer lets model capability
|
||||
* inference use provider-owned vendor taxonomy instead of renderer config.
|
||||
* Keeping these regex in the registry layer means both capability
|
||||
* inference and icon lookup stay in lockstep when a new vendor /
|
||||
* naming convention lands.
|
||||
*
|
||||
* Scope: **vendor identity only**. SKU-level patterns (`gpt-5.1-codex-mini`,
|
||||
* `claude-sonnet-4-6`, etc.) stay in their specific consumer modules —
|
||||
@@ -80,3 +84,27 @@ export const VENDOR_PATTERNS = {
|
||||
/** Mistral family */
|
||||
mistral: /mistral|pixtral|codestral|ministral|voxtral|devstral|mixtral|magistral/i
|
||||
} as const satisfies Record<string, RegExp>
|
||||
|
||||
export type VendorKey = keyof typeof VENDOR_PATTERNS
|
||||
|
||||
/**
|
||||
* Return the vendor slug for a normalized model ID, or `undefined` if
|
||||
* no vendor pattern matches. Iteration order is stable (key insertion
|
||||
* order) but not semantically important — patterns don't overlap.
|
||||
*/
|
||||
export function matchVendor(normalizedId: string): VendorKey | undefined {
|
||||
for (const [vendor, pattern] of Object.entries(VENDOR_PATTERNS) as [VendorKey, RegExp][]) {
|
||||
if (pattern.test(normalizedId)) return vendor
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Lightweight vendor predicate factory. Exported primarily so consumers
|
||||
* can spell the check as `isVendor('anthropic')(id)` when composing
|
||||
* higher-level logic.
|
||||
*/
|
||||
export function isVendor(vendor: VendorKey): (normalizedId: string) => boolean {
|
||||
const pattern = VENDOR_PATTERNS[vendor]
|
||||
return (id: string) => pattern.test(id)
|
||||
}
|
||||
|
||||
@@ -184,6 +184,11 @@ export class RegistryLoader {
|
||||
return this.modelById!.get(modelId) ?? this.modelByNormId!.get(normalizeModelId(modelId)) ?? null
|
||||
}
|
||||
|
||||
findProvider(providerId: string): ProviderConfig | null {
|
||||
const providers = this.loadProviders()
|
||||
return providers.find((p) => p.id === providerId) ?? null
|
||||
}
|
||||
|
||||
findOverride(providerId: string, modelId: string): ProviderModelOverride | null {
|
||||
this.loadProviderModels()
|
||||
const key = `${providerId}::${modelId}`
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
* Safe to import from browser/renderer contexts.
|
||||
*/
|
||||
|
||||
import { ENDPOINT_TYPE, type EndpointType } from './schemas/enums'
|
||||
import type { ModelConfig } from './schemas/model'
|
||||
import type { ProviderConfig, RegistryEndpointConfig } from './schemas/provider'
|
||||
import type { ProviderModelOverride } from './schemas/provider-models'
|
||||
@@ -51,6 +52,7 @@ export interface RuntimeEndpointConfig {
|
||||
baseUrl?: string
|
||||
modelsApiUrls?: { default?: string; embedding?: string; reranker?: string }
|
||||
reasoningFormatType?: string
|
||||
adapterFamily?: string
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -70,9 +72,45 @@ export function buildRuntimeEndpointConfigs(
|
||||
if (regConfig.baseUrl) config.baseUrl = regConfig.baseUrl
|
||||
if (regConfig.modelsApiUrls) config.modelsApiUrls = regConfig.modelsApiUrls
|
||||
if (regConfig.reasoningFormat?.type) config.reasoningFormatType = regConfig.reasoningFormat.type
|
||||
if (regConfig.adapterFamily) config.adapterFamily = regConfig.adapterFamily
|
||||
|
||||
if (Object.keys(config).length > 0) configs[k] = config
|
||||
}
|
||||
|
||||
return Object.keys(configs).length > 0 ? configs : null
|
||||
}
|
||||
|
||||
/**
|
||||
* Default AI SDK adapter family per endpoint type. Used when the catalog
|
||||
* doesn't specify one and no more-specific signal (e.g. legacy provider type)
|
||||
* is available. The mapping is purely protocol-derived — any endpoint that
|
||||
* speaks anthropic-messages format needs the `anthropic` adapter, etc.
|
||||
*/
|
||||
const ENDPOINT_TYPE_TO_DEFAULT_ADAPTER_FAMILY: Partial<Record<EndpointType, string>> = {
|
||||
[ENDPOINT_TYPE.ANTHROPIC_MESSAGES]: 'anthropic',
|
||||
[ENDPOINT_TYPE.GOOGLE_GENERATE_CONTENT]: 'google',
|
||||
[ENDPOINT_TYPE.OLLAMA_CHAT]: 'ollama',
|
||||
[ENDPOINT_TYPE.OLLAMA_GENERATE]: 'ollama',
|
||||
[ENDPOINT_TYPE.JINA_RERANK]: 'jina-rerank',
|
||||
[ENDPOINT_TYPE.OPENAI_RESPONSES]: 'openai'
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the AI SDK adapter family for an endpoint. Single source of truth
|
||||
* for seeder / migrator / UI creation paths — `adapterFamily` is a derived,
|
||||
* write-time value; the runtime resolver only reads it.
|
||||
*
|
||||
* 1. Catalog `adapterFamily` wins when present (encodes vendor-specific
|
||||
* relay routing like `aihubmix` for anthropic-messages on AiHubMix).
|
||||
* 2. Otherwise, fall back to the endpoint-type default
|
||||
* (`anthropic-messages` → `anthropic`, etc.).
|
||||
* 3. Final fallback `openai-compatible` covers `openai-chat-completions`
|
||||
* and any future openai-protocol endpoint without a more specific match.
|
||||
*/
|
||||
export function inferAdapterFamily(
|
||||
endpointType: EndpointType,
|
||||
catalogConfig?: Pick<RegistryEndpointConfig, 'adapterFamily'> | Pick<RuntimeEndpointConfig, 'adapterFamily'> | null
|
||||
): string {
|
||||
if (catalogConfig?.adapterFamily) return catalogConfig.adapterFamily
|
||||
return ENDPOINT_TYPE_TO_DEFAULT_ADAPTER_FAMILY[endpointType] ?? 'openai-compatible'
|
||||
}
|
||||
|
||||
@@ -31,10 +31,8 @@ export const ApiFeaturesSchema = z.object({
|
||||
developerRole: z.boolean().default(false),
|
||||
/** Whether the provider supports service tier selection (OpenAI/Groq-specific) */
|
||||
serviceTier: z.boolean().default(false),
|
||||
/** Whether the provider supports verbosity settings (Gemini-specific) */
|
||||
verbosity: z.boolean().default(false),
|
||||
/** Whether the provider supports enable_thinking parameter */
|
||||
enableThinking: z.boolean().default(true)
|
||||
/** Whether the provider supports verbosity settings (OpenAI-specific) */
|
||||
verbosity: z.boolean().default(false)
|
||||
})
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -197,7 +195,13 @@ export const RegistryEndpointConfigSchema = z.object({
|
||||
})
|
||||
.optional(),
|
||||
/** How this endpoint type expects reasoning parameters to be formatted */
|
||||
reasoningFormat: ProviderReasoningFormatSchema.optional()
|
||||
reasoningFormat: ProviderReasoningFormatSchema.optional(),
|
||||
/**
|
||||
* AI SDK adapter family that handles this endpoint. Aligns with the IDs
|
||||
* registered in `appProviderIds`. Resolvers should prefer this over
|
||||
* heuristic id/baseUrl inference when present.
|
||||
*/
|
||||
adapterFamily: z.string().optional()
|
||||
})
|
||||
|
||||
export const ProviderConfigSchema = z
|
||||
|
||||
@@ -300,7 +300,7 @@ export function EntitySelector<T extends EntityItemBase>(props: EntitySelectorPr
|
||||
userOnOpenAutoFocus?.(event)
|
||||
}}
|
||||
className={cn(
|
||||
'flex max-h-[var(--radix-popover-content-available-height)] flex-col overflow-hidden rounded-2xs border-border/60 bg-popover p-0 shadow-lg',
|
||||
'flex max-h-[var(--radix-popover-content-available-height)] flex-col overflow-hidden rounded-lg border-border/60 bg-popover p-0 shadow-lg',
|
||||
userPopoverClassName,
|
||||
className
|
||||
)}>
|
||||
|
||||
@@ -2,6 +2,13 @@ import { MODEL_ICON_CATALOG, type ModelIconKey } from './models/catalog'
|
||||
import { PROVIDER_ICON_CATALOG, type ProviderIconKey } from './providers/catalog'
|
||||
import type { CompoundIcon } from './types'
|
||||
|
||||
// NOTE: the vendor-level regex below duplicate `@cherrystudio/provider-registry`'s
|
||||
// `VENDOR_PATTERNS` (anthropic, gemini, gemma, grok, doubao, hunyuan, kimi, zhipu,
|
||||
// mimo, ling, qwen). Kept in sync manually until UI's build surface lets us
|
||||
// import from `@cherrystudio/provider-registry` directly. When adding / tweaking
|
||||
// a vendor pattern, update BOTH places — or, better, fix the UI → registry import
|
||||
// story and swap these inline regex for `VENDOR_PATTERNS.<vendor>`.
|
||||
|
||||
/**
|
||||
* Model ID regex patterns mapped to MODEL_ICON_CATALOG keys.
|
||||
* Order matters: more specific patterns must come before general ones.
|
||||
|
||||
1238
pnpm-lock.yaml
generated
1238
pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,14 @@ const workspaceConfigPath = path.join(__dirname, '..', 'pnpm-workspace.yaml')
|
||||
// if you want to add new prebuild binaries packages with different architectures, you can add them here
|
||||
// please add to allX64 and allArm64 from pnpm-lock.yaml
|
||||
const packages = [
|
||||
'@anthropic-ai/claude-agent-sdk-darwin-arm64',
|
||||
'@anthropic-ai/claude-agent-sdk-darwin-x64',
|
||||
'@anthropic-ai/claude-agent-sdk-linux-arm64',
|
||||
'@anthropic-ai/claude-agent-sdk-linux-arm64-musl',
|
||||
'@anthropic-ai/claude-agent-sdk-linux-x64',
|
||||
'@anthropic-ai/claude-agent-sdk-linux-x64-musl',
|
||||
'@anthropic-ai/claude-agent-sdk-win32-arm64',
|
||||
'@anthropic-ai/claude-agent-sdk-win32-x64',
|
||||
'@img/sharp-darwin-arm64',
|
||||
'@img/sharp-darwin-x64',
|
||||
'@img/sharp-libvips-darwin-arm64',
|
||||
@@ -132,7 +140,7 @@ exports.default = async function (context) {
|
||||
}
|
||||
return f !== `${arch}-${platform}`
|
||||
})
|
||||
.map((f) => '!node_modules/@anthropic-ai/claude-agent-sdk/vendor/ripgrep/' + f + '/**')
|
||||
.map((f) => '!node_modules/@cherrystudio/ripgrep/vendor/ripgrep/' + f + '/**')
|
||||
|
||||
// Exclude rtk binaries for other platform-arch combinations
|
||||
const currentPlatformKey = `${platform}-${arch}`
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
import { builtinModules } from 'node:module'
|
||||
|
||||
import { resolve } from 'path'
|
||||
import { build as viteBuild, type Plugin } from 'vite'
|
||||
|
||||
interface BuildProxyBootstrapPluginOptions {
|
||||
dependencies: string[]
|
||||
isProd: boolean
|
||||
rootDir: string
|
||||
}
|
||||
|
||||
export const buildProxyBootstrapPlugin = ({
|
||||
dependencies,
|
||||
isProd,
|
||||
rootDir
|
||||
}: BuildProxyBootstrapPluginOptions): Plugin => {
|
||||
return {
|
||||
name: 'cherry-build-proxy-bootstrap',
|
||||
apply: 'build',
|
||||
async closeBundle() {
|
||||
await viteBuild({
|
||||
configFile: false,
|
||||
publicDir: false,
|
||||
resolve: {
|
||||
mainFields: ['module', 'jsnext:main', 'jsnext'],
|
||||
conditions: ['node']
|
||||
},
|
||||
build: {
|
||||
outDir: resolve(rootDir, 'out/proxy'),
|
||||
target: 'node22',
|
||||
minify: false,
|
||||
reportCompressedSize: false,
|
||||
copyPublicDir: false,
|
||||
lib: {
|
||||
entry: resolve(rootDir, 'src/main/services/proxy/bootstrap.ts'),
|
||||
formats: ['cjs'],
|
||||
fileName: () => 'index.js'
|
||||
},
|
||||
rollupOptions: {
|
||||
external: [
|
||||
'electron',
|
||||
/^electron\/.+/,
|
||||
...builtinModules.flatMap((moduleName) => [moduleName, `node:${moduleName}`]),
|
||||
...dependencies
|
||||
]
|
||||
}
|
||||
},
|
||||
esbuild: isProd ? { legalComments: 'none' } : {}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -35,10 +35,8 @@ const { mockSyncBuiltinSkill } = vi.hoisted(() => ({
|
||||
mockSyncBuiltinSkill: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../services/agents/skills/SkillService', () => ({
|
||||
skillService: {
|
||||
syncBuiltinSkill: mockSyncBuiltinSkill
|
||||
}
|
||||
vi.mock('@main/ai/skills/SkillService', () => ({
|
||||
skillService: { syncBuiltinSkill: mockSyncBuiltinSkill }
|
||||
}))
|
||||
|
||||
// Matches the stub in tests/main.setup.ts → mockApplicationFactory().getPath
|
||||
|
||||
591
src/main/ai/AiService.ts
Normal file
591
src/main/ai/AiService.ts
Normal file
@@ -0,0 +1,591 @@
|
||||
import { embedMany as aiCoreEmbedMany, generateImage as aiCoreGenerateImage } from '@cherrystudio/ai-core'
|
||||
import { assistantDataService } from '@data/services/AssistantService'
|
||||
import type { PersonGeneration } from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { application } from '@main/core/application'
|
||||
import { BaseService, DependsOn, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
|
||||
import { messageService } from '@main/data/services/MessageService'
|
||||
import { modelService } from '@main/data/services/ModelService'
|
||||
import { providerService } from '@main/data/services/ProviderService'
|
||||
import { type TranslateOpenRequest, translateService } from '@main/services/translate/translateService'
|
||||
import { downloadImageAsBase64 } from '@main/utils/downloadAsBase64'
|
||||
import { applyApprovalDecisions } from '@shared/ai/transport'
|
||||
import { type Assistant } from '@shared/data/types/assistant'
|
||||
import type { FileEntry } from '@shared/data/types/file/fileEntry'
|
||||
import { type Model, parseUniqueModelId } from '@shared/data/types/model'
|
||||
import type { Base64String } from '@shared/file/types/common'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { isEmbeddingModel } from '@shared/utils/model'
|
||||
import {
|
||||
type EmbeddingModelUsage,
|
||||
isToolUIPart,
|
||||
type LanguageModelUsage,
|
||||
type ModelMessage,
|
||||
type UIMessageChunk
|
||||
} from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
import { isAgentSessionTopic } from './agentSession/topic'
|
||||
import { resolveUIMessageFileUrls } from './messages/messageConverter'
|
||||
import { listModels as listModelsFromProvider } from './provider/listModels'
|
||||
import { Agent } from './runtime/aiSdk/Agent'
|
||||
import type { AgentLoopHooks } from './runtime/aiSdk/loop'
|
||||
import { mergeUsage, ZERO_USAGE } from './runtime/aiSdk/observers/usage'
|
||||
import { buildAgentParams } from './runtime/aiSdk/params/buildAgentParams'
|
||||
import type { RequestFeature } from './runtime/aiSdk/params/feature'
|
||||
import { WebContentsListener } from './streamManager/listeners/WebContentsListener'
|
||||
import { registerBuiltinTools } from './tools/adapters/aiSdk/builtin'
|
||||
import type { AppProviderSettingsMap } from './types'
|
||||
import type { AiBaseRequest, AiStreamRequest, AiTransportOptions, ListModelsRequest } from './types/requests'
|
||||
import { buildImageProviderOptions, normalizeAspectRatio } from './utils/imageOptions'
|
||||
|
||||
const logger = loggerService.withContext('AiService')
|
||||
|
||||
// ── Request types ──────────────────────────────────────────────────
|
||||
|
||||
/** In-process variant of `AiTransportOptions` — adds `signal`, which is not IPC-serialisable. */
|
||||
export interface AiRequestOptions extends AiTransportOptions {
|
||||
/** In-process only. Renderer payloads use `AiTransportOptions` (no signal). */
|
||||
signal?: AbortSignal
|
||||
}
|
||||
|
||||
/** Widens `requestOptions` to accept the in-process shape on `AiService.*` method signatures. */
|
||||
export type AsInProcess<T extends AiBaseRequest> = Omit<T, 'requestOptions'> & {
|
||||
requestOptions?: AiRequestOptions
|
||||
}
|
||||
|
||||
/** Non-streaming text generation request — pure transport data. */
|
||||
export interface AiGenerateRequest extends AiBaseRequest {
|
||||
system?: string
|
||||
prompt?: string
|
||||
messages?: ModelMessage[]
|
||||
}
|
||||
|
||||
// ── SDK extensions ─────────────────────────────────────────────────
|
||||
|
||||
/** Result of non-streaming text generation. */
|
||||
export interface AiGenerateResult {
|
||||
text: string
|
||||
usage?: LanguageModelUsage
|
||||
}
|
||||
|
||||
/** Image generation request. */
|
||||
export interface AiImageRequest extends AiBaseRequest {
|
||||
prompt: string
|
||||
/** Input images for editing (base64 data URLs or URLs). If provided, uses edit mode. */
|
||||
inputImages?: string[]
|
||||
/** Mask for inpainting (only with inputImages). */
|
||||
mask?: string
|
||||
n?: number
|
||||
size?: string
|
||||
negativePrompt?: string
|
||||
seed?: number
|
||||
quality?: string
|
||||
numInferenceSteps?: number
|
||||
guidanceScale?: number
|
||||
promptEnhancement?: boolean
|
||||
personGeneration?: PersonGeneration
|
||||
aspectRatio?: string
|
||||
background?: string
|
||||
moderation?: string
|
||||
style?: string
|
||||
/** Vendor-specific image params keyed by provider id; mapped to AI SDK provider options in main. */
|
||||
providerOptions?: Record<string, Record<string, unknown>>
|
||||
}
|
||||
|
||||
/** Image generation result — persisted file entries (main writes the bytes). */
|
||||
export interface AiImageResult {
|
||||
files: FileEntry[]
|
||||
}
|
||||
|
||||
/** Embedding request. */
|
||||
export interface AiEmbedRequest extends AiBaseRequest {
|
||||
values: string[]
|
||||
}
|
||||
|
||||
/** Embedding result. */
|
||||
export interface AiEmbedResult {
|
||||
embeddings: number[][]
|
||||
usage?: EmbeddingModelUsage
|
||||
}
|
||||
|
||||
/** Validates the `Ai_ToolApproval_Respond` IPC payload at the renderer boundary. */
|
||||
const ToolApprovalRespondSchema = z.object({
|
||||
approvalId: z.string().min(1),
|
||||
approved: z.boolean(),
|
||||
reason: z.string().optional(),
|
||||
updatedInput: z.record(z.string(), z.unknown()).optional(),
|
||||
topicId: z.string().optional(),
|
||||
anchorId: z.string().optional()
|
||||
})
|
||||
|
||||
// ── Service ────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Lifecycle AI service. See `docs/references/ai/core-architecture.md`.
|
||||
*
|
||||
* DO NOT mirror `@DependsOn(['AiService'])` on AiStreamManager —
|
||||
* `runExecutionLoop` looks AiService up at runtime, and every `send()`
|
||||
* caller routes through AiService first.
|
||||
*/
|
||||
@Injectable('AiService')
|
||||
@ServicePhase(Phase.WhenReady)
|
||||
@DependsOn(['McpRuntimeService', 'McpCatalogService', 'AiStreamManager'])
|
||||
export class AiService extends BaseService {
|
||||
// Per-request AbortControllers for `Ai_GenerateImage`, paired with the
|
||||
// `Ai_AbortImage` channel. Key is the renderer-generated requestId
|
||||
// (see `src/preload/index.ts`). Entries are self-cleaning via the
|
||||
// handler's `finally` block; abort on an unknown id is a no-op.
|
||||
// TODO(abort-registry): collapse with MCP/stream/LAN registries once
|
||||
// the shared `ipcHandleWithAbort` helper lands.
|
||||
private readonly imageRequests = new Map<string, AbortController>()
|
||||
|
||||
protected async onInit(): Promise<void> {
|
||||
registerBuiltinTools()
|
||||
this.registerIpcHandlers()
|
||||
logger.info('AiService initialized')
|
||||
}
|
||||
|
||||
private registerIpcHandlers(): void {
|
||||
this.ipcHandle(IpcChannel.Ai_GenerateText, async (_, request: AiGenerateRequest) => {
|
||||
return this.generateText(request)
|
||||
})
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_CheckModel, async (_, request: AiBaseRequest & { timeout?: number }) => {
|
||||
return this.checkModel(request)
|
||||
})
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_EmbedMany, async (_, request: AiEmbedRequest) => {
|
||||
return this.embedMany(request)
|
||||
})
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_GenerateImage, async (_, request: { requestId: string; payload: AiImageRequest }) => {
|
||||
const { requestId, payload } = request
|
||||
const controller = new AbortController()
|
||||
this.imageRequests.set(requestId, controller)
|
||||
try {
|
||||
return await this.generateImage({
|
||||
...payload,
|
||||
requestOptions: { ...payload.requestOptions, signal: controller.signal }
|
||||
})
|
||||
} finally {
|
||||
this.imageRequests.delete(requestId)
|
||||
}
|
||||
})
|
||||
|
||||
this.ipcOn(IpcChannel.Ai_AbortImage, (_, request: { requestId: string }) => {
|
||||
this.imageRequests.get(request.requestId)?.abort()
|
||||
})
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_ListModels, async (_, request: ListModelsRequest) => {
|
||||
return this.listModels(request)
|
||||
})
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_Translate_Open, async (event, request: TranslateOpenRequest) => {
|
||||
return translateService.open(event.sender, request)
|
||||
})
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_ToolApproval_Respond, async (event, rawPayload: unknown): Promise<{ ok: boolean }> => {
|
||||
// Validate the renderer payload at the IPC boundary before any registry dispatch or DB read.
|
||||
const parsed = ToolApprovalRespondSchema.safeParse(rawPayload)
|
||||
if (!parsed.success) {
|
||||
logger.warn('Tool-approval response rejected: invalid payload', { issues: parsed.error.issues })
|
||||
return { ok: false }
|
||||
}
|
||||
const payload = parsed.data
|
||||
|
||||
// Claude-Agent fast-path: live registry entry unblocks `canUseTool`.
|
||||
const dispatched = application.get('AgentSessionRuntimeService').respondToolApproval(payload.approvalId, {
|
||||
approved: payload.approved,
|
||||
reason: payload.reason,
|
||||
updatedInput: payload.updatedInput
|
||||
})
|
||||
if (dispatched) return { ok: true }
|
||||
|
||||
// MCP path: write decisions to DB, then dispatch continue-conversation when nothing is pending.
|
||||
if (!payload.topicId || !payload.anchorId) {
|
||||
logger.warn('Tool-approval response had no live registry entry and no anchor context', {
|
||||
approvalId: payload.approvalId
|
||||
})
|
||||
return { ok: false }
|
||||
}
|
||||
|
||||
// Main is the single authority for the approval mutation: the
|
||||
// renderer no longer PATCHes (it sourced parts from a DB projection
|
||||
// that didn't carry the overlay-only `approval-requested` part and
|
||||
// raced/overwrote the persisted row). The decision is carried
|
||||
// explicitly in the IPC payload; apply it here to the DB-authoritative
|
||||
// parts (the original stream's terminal persistence wrote the
|
||||
// `approval-requested` part onto this row) and persist.
|
||||
const decision = {
|
||||
approvalId: payload.approvalId,
|
||||
approved: payload.approved,
|
||||
...(payload.reason !== undefined && { reason: payload.reason })
|
||||
}
|
||||
// A stale click on a deleted message must resolve through the documented
|
||||
// result shape, not throw out of the handler (getById rejects when the
|
||||
// anchor is missing), consistent with the no-context branch above.
|
||||
let anchor: Awaited<ReturnType<typeof messageService.getById>>
|
||||
try {
|
||||
anchor = await messageService.getById(payload.anchorId)
|
||||
} catch {
|
||||
logger.warn('Tool-approval response anchor is missing or deleted', {
|
||||
approvalId: payload.approvalId,
|
||||
anchorId: payload.anchorId
|
||||
})
|
||||
return { ok: false }
|
||||
}
|
||||
const beforeParts = anchor.data.parts ?? []
|
||||
const targetPresent = beforeParts.some(
|
||||
(p) => isToolUIPart(p) && p.state === 'approval-requested' && p.approval?.id === decision.approvalId
|
||||
)
|
||||
const afterParts = applyApprovalDecisions(beforeParts, [decision])
|
||||
// Only write parts when this approval is present on the DB row.
|
||||
// `applyApprovalDecisions` always returns a fresh array, so writing
|
||||
// unconditionally would overwrite real (or not-yet-persisted) parts
|
||||
// with an unchanged set. When the part is overlay-only (persist not
|
||||
// landed yet), the continue dispatch below carries the decision and
|
||||
// the continue provider applies it authoritatively where it reads parts.
|
||||
if (targetPresent) {
|
||||
await messageService.update(payload.anchorId, { data: { parts: afterParts } })
|
||||
}
|
||||
|
||||
// Only resume once every approval on this turn is decided — a turn
|
||||
// can request several tools at once; the not-yet-decided ones keep
|
||||
// their cards.
|
||||
const anyStillPending = afterParts.some((p) => isToolUIPart(p) && p.state === 'approval-requested')
|
||||
if (anyStillPending) {
|
||||
return { ok: true }
|
||||
}
|
||||
|
||||
const aiStreamManager = application.get('AiStreamManager')
|
||||
const subscriber = new WebContentsListener(event.sender, payload.topicId)
|
||||
await aiStreamManager.dispatch(subscriber, {
|
||||
trigger: 'continue-conversation',
|
||||
topicId: payload.topicId,
|
||||
parentAnchorId: payload.anchorId,
|
||||
// Idempotent against the conditional write above; safety net when the part wasn't on the row.
|
||||
approvalDecisions: [decision]
|
||||
})
|
||||
return { ok: true }
|
||||
})
|
||||
}
|
||||
|
||||
// ── Streaming chat (agent.stream) ──
|
||||
|
||||
/**
|
||||
* Raw `UIMessageChunk` stream from `Agent.stream`. Caller (usually
|
||||
* `AiStreamManager`) owns read/multicast/accumulation/terminal dispatch.
|
||||
* Pre-stream errors reject the Promise; mid-stream errors come through
|
||||
* the stream itself.
|
||||
*/
|
||||
async streamText(
|
||||
request: AsInProcess<AiStreamRequest>,
|
||||
extraFeatures: readonly RequestFeature[] = []
|
||||
): Promise<ReadableStream<UIMessageChunk>> {
|
||||
logger.info('streamText started', { chatId: request.chatId })
|
||||
const signal = request.requestOptions?.signal
|
||||
if (!signal) {
|
||||
throw new Error('streamText requires requestOptions.signal — no AbortController was attached by the caller')
|
||||
}
|
||||
|
||||
if (request.runtime?.kind === 'agent-session') {
|
||||
return application.get('AgentSessionRuntimeService').openTurnStream({
|
||||
sessionId: request.runtime.sessionId,
|
||||
turnId: request.runtime.turnId,
|
||||
signal
|
||||
})
|
||||
}
|
||||
|
||||
if (isAgentSessionTopic(request.chatId)) {
|
||||
throw new Error(`Agent session stream ${request.chatId} requires an agent-session runtime request`)
|
||||
}
|
||||
|
||||
const { sdkConfig, tools, plugins, system, options, model, hookParts } = await this.buildAgentParamsFor(
|
||||
request,
|
||||
signal,
|
||||
extraFeatures
|
||||
)
|
||||
|
||||
const preparedMessages = await resolveUIMessageFileUrls(request.messages ?? [])
|
||||
|
||||
const agent = new Agent({
|
||||
providerId: sdkConfig.providerId,
|
||||
providerSettings: sdkConfig.providerSettings,
|
||||
modelId: sdkConfig.modelId,
|
||||
messageId: request.messageId,
|
||||
plugins,
|
||||
tools,
|
||||
system,
|
||||
options,
|
||||
hookParts: [this.analyticsHookPart(model), ...hookParts]
|
||||
})
|
||||
|
||||
return agent.stream(preparedMessages, signal)
|
||||
}
|
||||
|
||||
private analyticsHookPart(model: Model): Partial<AgentLoopHooks> {
|
||||
let total: LanguageModelUsage = ZERO_USAGE
|
||||
return {
|
||||
onStepFinish: (step) => {
|
||||
if (step.usage) total = mergeUsage(total, step.usage)
|
||||
},
|
||||
onFinish: () => this.trackUsage(model, total)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Non-streaming text generation (agent.generate) ──
|
||||
|
||||
async generateText(
|
||||
request: AsInProcess<AiGenerateRequest>,
|
||||
extraFeatures: readonly RequestFeature[] = []
|
||||
): Promise<AiGenerateResult> {
|
||||
logger.info('generateText started', { assistantId: request.assistantId })
|
||||
const signal = request.requestOptions?.signal
|
||||
|
||||
const { sdkConfig, tools, plugins, system, options, model, hookParts } = await this.buildAgentParamsFor(
|
||||
request,
|
||||
signal,
|
||||
extraFeatures
|
||||
)
|
||||
|
||||
const agent = new Agent({
|
||||
providerId: sdkConfig.providerId,
|
||||
providerSettings: sdkConfig.providerSettings,
|
||||
modelId: sdkConfig.modelId,
|
||||
plugins,
|
||||
tools,
|
||||
system: request.system ?? system,
|
||||
options,
|
||||
hookParts: [this.analyticsHookPart(model), ...hookParts]
|
||||
})
|
||||
|
||||
// prompt and messages are mutually exclusive in AI SDK; preserve that.
|
||||
return agent.generate(request.prompt ? { prompt: request.prompt } : { messages: request.messages ?? [] }, signal)
|
||||
}
|
||||
|
||||
// ── Image generation ──
|
||||
|
||||
async generateImage(request: AsInProcess<AiImageRequest>): Promise<AiImageResult> {
|
||||
logger.info('generateImage started', { assistantId: request.assistantId, uniqueModelId: request.uniqueModelId })
|
||||
const signal = request.requestOptions?.signal
|
||||
|
||||
const { sdkConfig } = await this.buildAgentParamsFor(request, signal)
|
||||
|
||||
const promptParam = request.inputImages
|
||||
? { text: request.prompt, images: request.inputImages, ...(request.mask && { mask: request.mask }) }
|
||||
: request.prompt
|
||||
|
||||
// Map the canonical painting params onto each vendor's real image-API field
|
||||
// names (negative_prompt / seed / imageConfig / …). AI SDK image models
|
||||
// spread `providerOptions[<providerId>]` into the request body, so this is
|
||||
// how negativePrompt/seed/steps/guidance/aspectRatio actually reach vendors.
|
||||
const imageProviderOptions = buildImageProviderOptions(sdkConfig.providerId, {
|
||||
negativePrompt: request.negativePrompt,
|
||||
seed: request.seed !== undefined ? String(request.seed) : undefined,
|
||||
numInferenceSteps: request.numInferenceSteps,
|
||||
guidanceScale: request.guidanceScale,
|
||||
promptEnhancement: request.promptEnhancement,
|
||||
personGeneration: request.personGeneration,
|
||||
quality: request.quality,
|
||||
aspectRatio: request.aspectRatio,
|
||||
imageSize: request.size,
|
||||
providerOptions: request.providerOptions,
|
||||
background: request.background,
|
||||
moderation: request.moderation,
|
||||
style: request.style
|
||||
})
|
||||
const aspectRatio = normalizeAspectRatio(request.aspectRatio)
|
||||
|
||||
const imageParams = {
|
||||
model: sdkConfig.modelId,
|
||||
prompt: promptParam,
|
||||
n: request.n ?? 1,
|
||||
// Client-side default: when the caller omits `size`, fall back to 1024x1024
|
||||
// rather than letting the server pick its own default. Dropping this fallback
|
||||
// (to truly let the server choose) is a behavior decision, not done here.
|
||||
size: (request.size ?? '1024x1024') as `${number}x${number}`,
|
||||
...(request.negativePrompt ? { negativePrompt: request.negativePrompt } : {}),
|
||||
...(request.seed !== undefined ? { seed: request.seed } : {}),
|
||||
...(request.quality ? { quality: request.quality } : {}),
|
||||
...(request.numInferenceSteps !== undefined ? { numInferenceSteps: request.numInferenceSteps } : {}),
|
||||
...(request.guidanceScale !== undefined ? { guidanceScale: request.guidanceScale } : {}),
|
||||
...(request.promptEnhancement !== undefined ? { promptEnhancement: request.promptEnhancement } : {}),
|
||||
...(aspectRatio ? { aspectRatio: aspectRatio as `${number}:${number}` } : {}),
|
||||
...(Object.keys(imageProviderOptions).length > 0 ? { providerOptions: imageProviderOptions } : {}),
|
||||
...(signal ? { abortSignal: signal } : {}),
|
||||
experimental_download: async (downloads) => {
|
||||
return Promise.all(
|
||||
downloads.map(async ({ url }) => {
|
||||
if (signal?.aborted) return null
|
||||
const downloaded = await downloadImageAsBase64(url.toString())
|
||||
if (signal?.aborted) return null
|
||||
if (!downloaded) return null
|
||||
return {
|
||||
data: Buffer.from(downloaded.data, 'base64'),
|
||||
mediaType: downloaded.media_type
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const result = await aiCoreGenerateImage<AppProviderSettingsMap>(
|
||||
sdkConfig.providerId,
|
||||
sdkConfig.providerSettings,
|
||||
imageParams
|
||||
)
|
||||
|
||||
const dataUrls: Base64String[] = []
|
||||
let filteredCount = 0
|
||||
for (const image of result.images ?? []) {
|
||||
if (image.base64) {
|
||||
dataUrls.push(`data:${image.mediaType || 'image/png'};base64,${image.base64}`)
|
||||
continue
|
||||
}
|
||||
|
||||
filteredCount += 1
|
||||
}
|
||||
|
||||
if (filteredCount > 0) {
|
||||
logger.warn('Filtered invalid generated images', {
|
||||
uniqueModelId: request.uniqueModelId,
|
||||
providerId: sdkConfig.providerId,
|
||||
modelId: sdkConfig.modelId,
|
||||
filteredCount
|
||||
})
|
||||
}
|
||||
const fileManager = application.get('FileManager')
|
||||
const files = await Promise.all(dataUrls.map((data) => fileManager.createInternalEntry({ source: 'base64', data })))
|
||||
|
||||
return { files }
|
||||
}
|
||||
|
||||
// ── Embedding ──
|
||||
|
||||
async embedMany(request: AsInProcess<AiEmbedRequest>): Promise<AiEmbedResult> {
|
||||
logger.info('embedMany started', { assistantId: request.assistantId, count: request.values.length })
|
||||
const signal = request.requestOptions?.signal
|
||||
|
||||
const { sdkConfig, model } = await this.buildAgentParamsFor(request, signal)
|
||||
|
||||
const result = await aiCoreEmbedMany<AppProviderSettingsMap>(sdkConfig.providerId, sdkConfig.providerSettings, {
|
||||
model: sdkConfig.modelId,
|
||||
values: request.values,
|
||||
...(signal ? { abortSignal: signal } : {})
|
||||
})
|
||||
|
||||
this.trackUsage(model, { inputTokens: result.usage?.tokens ?? 0, outputTokens: 0 })
|
||||
return { embeddings: result.embeddings, usage: result.usage }
|
||||
}
|
||||
|
||||
// ── Model listing ──
|
||||
async listModels(request: ListModelsRequest): Promise<Partial<Model>[]> {
|
||||
let providerId = request.providerId
|
||||
if (!providerId && request.assistantId) {
|
||||
const assistant = await assistantDataService.getById(request.assistantId).catch(() => undefined)
|
||||
if (assistant?.modelId) {
|
||||
providerId = parseUniqueModelId(assistant.modelId).providerId
|
||||
}
|
||||
}
|
||||
if (!providerId) {
|
||||
throw new Error('Cannot resolve providerId: not in request and assistant has no model')
|
||||
}
|
||||
const provider = await providerService.getByProviderId(providerId)
|
||||
return listModelsFromProvider(provider, undefined, { throwOnError: request.throwOnError })
|
||||
}
|
||||
|
||||
// ── API validation ──
|
||||
|
||||
/** Dispatches to `embedMany` for embedding models, `generateText` otherwise. */
|
||||
async checkModel(request: AiBaseRequest & { timeout?: number }): Promise<{ latency: number }> {
|
||||
const { model } = await this.getProviderAndModel(request)
|
||||
const start = performance.now()
|
||||
const timeout = request.timeout ?? 15000
|
||||
|
||||
// AbortController on timeout so the HTTP work cancels too (otherwise tokens keep burning).
|
||||
const controller = new AbortController()
|
||||
let timeoutHandle: ReturnType<typeof setTimeout> | undefined
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timeoutHandle = setTimeout(() => {
|
||||
controller.abort(new Error('Check model timeout'))
|
||||
reject(new Error('Check model timeout'))
|
||||
}, timeout)
|
||||
})
|
||||
|
||||
const probeRequest = {
|
||||
...request,
|
||||
requestOptions: { ...request.requestOptions, signal: controller.signal }
|
||||
}
|
||||
const probe = isEmbeddingModel(model)
|
||||
? this.embedMany({ ...probeRequest, values: ['test'] })
|
||||
: this.generateText({ ...probeRequest, system: 'test', prompt: 'hi' })
|
||||
|
||||
try {
|
||||
await Promise.race([probe, timeoutPromise])
|
||||
return { latency: performance.now() - start }
|
||||
} finally {
|
||||
if (timeoutHandle) clearTimeout(timeoutHandle)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared agent parameter resolution ──
|
||||
|
||||
private async buildAgentParamsFor(
|
||||
request: AsInProcess<AiBaseRequest> & { chatId?: string },
|
||||
signal: AbortSignal | undefined,
|
||||
extraFeatures: readonly RequestFeature[] = []
|
||||
) {
|
||||
const { provider, model, assistant } = await this.getProviderAndModel(request)
|
||||
const built = await buildAgentParams({ request, signal, provider, model, assistant, extraFeatures })
|
||||
return { ...built, provider, model, assistant }
|
||||
}
|
||||
|
||||
// ── Token usage tracking ──
|
||||
|
||||
private trackUsage(model: Model, usage?: { inputTokens?: number; outputTokens?: number }): void {
|
||||
if (!usage || !model.providerId || !model.apiModelId) return
|
||||
const inputTokens = usage.inputTokens ?? 0
|
||||
const outputTokens = usage.outputTokens ?? 0
|
||||
if (inputTokens === 0 && outputTokens === 0) return
|
||||
|
||||
try {
|
||||
const analyticsService = application.get('AnalyticsService')
|
||||
analyticsService.trackTokenUsage({
|
||||
provider: model.providerId,
|
||||
model: model.apiModelId ?? model.id,
|
||||
input_tokens: inputTokens,
|
||||
output_tokens: outputTokens
|
||||
})
|
||||
} catch {
|
||||
// AnalyticsService may not be activated (data collection disabled)
|
||||
}
|
||||
}
|
||||
|
||||
/** Priority: explicit `uniqueModelId` > `assistant.modelId`. */
|
||||
private async getProviderAndModel(request: AiBaseRequest & { chatId?: string }) {
|
||||
let assistant: Assistant | undefined
|
||||
if (request.assistantId) {
|
||||
assistant = await assistantDataService.getById(request.assistantId).catch(() => undefined)
|
||||
}
|
||||
|
||||
let providerId: string | undefined
|
||||
let modelId: string | undefined
|
||||
if (request.uniqueModelId) {
|
||||
const parsed = parseUniqueModelId(request.uniqueModelId)
|
||||
providerId = parsed.providerId
|
||||
modelId = parsed.modelId
|
||||
} else if (assistant?.modelId) {
|
||||
const parsed = parseUniqueModelId(assistant.modelId)
|
||||
providerId = parsed.providerId
|
||||
modelId = parsed.modelId
|
||||
}
|
||||
if (!providerId) throw new Error('Cannot resolve providerId: not in request and assistant has no model')
|
||||
if (!modelId) throw new Error('Cannot resolve modelId: not in request and assistant has no model')
|
||||
|
||||
const provider = await providerService.getByProviderId(providerId)
|
||||
const model = await modelService.getByKey(providerId, modelId)
|
||||
|
||||
return { provider, model, assistant }
|
||||
}
|
||||
}
|
||||
418
src/main/ai/__tests__/AiService.test.ts
Normal file
418
src/main/ai/__tests__/AiService.test.ts
Normal file
@@ -0,0 +1,418 @@
|
||||
import { BaseService } from '@main/core/lifecycle/BaseService'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { ipcMain } from 'electron'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const mockGenerateImage = vi.fn()
|
||||
const mockDownloadImageAsBase64 = vi.fn()
|
||||
const mockApplicationGet = vi.fn()
|
||||
|
||||
vi.mock('@main/core/application', () => ({
|
||||
application: {
|
||||
get: mockApplicationGet
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@main/utils/downloadAsBase64', () => ({
|
||||
downloadImageAsBase64: (...args: unknown[]) => mockDownloadImageAsBase64(...args)
|
||||
}))
|
||||
|
||||
vi.mock('@cherrystudio/ai-core', () => ({
|
||||
createAgent: vi.fn(),
|
||||
embedMany: vi.fn(),
|
||||
generateImage: (...args: unknown[]) => mockGenerateImage(...args)
|
||||
}))
|
||||
|
||||
const { AiService } = await import('../AiService')
|
||||
const { messageService } = await import('@main/data/services/MessageService')
|
||||
|
||||
/**
|
||||
* Instantiate `AiService` directly (without going through the lifecycle
|
||||
* container) so unit tests can drive its methods in isolation.
|
||||
*/
|
||||
function createService(): InstanceType<typeof AiService> {
|
||||
BaseService.resetInstances()
|
||||
return new (AiService as any)()
|
||||
}
|
||||
|
||||
describe('AiService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('routes agent-session runtime requests directly to the runtime service', async () => {
|
||||
const service = createService()
|
||||
const stream = new ReadableStream()
|
||||
const openTurnStream = vi.fn(() => stream)
|
||||
mockApplicationGet.mockReturnValue({ openTurnStream })
|
||||
|
||||
await expect(
|
||||
service.streamText({
|
||||
chatId: 'agent-session:session-1',
|
||||
trigger: 'submit-message',
|
||||
runtime: { kind: 'agent-session', sessionId: 'session-1', turnId: 'turn-1' },
|
||||
requestOptions: { signal: new AbortController().signal }
|
||||
} as any)
|
||||
).resolves.toBe(stream)
|
||||
|
||||
expect(mockApplicationGet).toHaveBeenCalledWith('AgentSessionRuntimeService')
|
||||
expect(openTurnStream).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
turnId: 'turn-1',
|
||||
signal: expect.any(AbortSignal)
|
||||
})
|
||||
})
|
||||
|
||||
it('rejects agent-session streams that do not carry a runtime request', async () => {
|
||||
const service = createService()
|
||||
const buildAgentParamsFor = vi.spyOn(service as any, 'buildAgentParamsFor')
|
||||
|
||||
await expect(
|
||||
service.streamText({
|
||||
chatId: 'agent-session:session-1',
|
||||
trigger: 'submit-message',
|
||||
requestOptions: { signal: new AbortController().signal }
|
||||
} as any)
|
||||
).rejects.toThrow('requires an agent-session runtime request')
|
||||
|
||||
expect(buildAgentParamsFor).not.toHaveBeenCalled()
|
||||
expect(mockApplicationGet).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('normalizes base64 and url images from ai-core generateImage', async () => {
|
||||
const service = createService()
|
||||
vi.spyOn(service as never, 'buildAgentParamsFor').mockResolvedValue({
|
||||
sdkConfig: {
|
||||
providerId: 'test-provider',
|
||||
providerSettings: {},
|
||||
modelId: 'test-model'
|
||||
}
|
||||
} as never)
|
||||
|
||||
mockGenerateImage.mockResolvedValue({
|
||||
images: [{ base64: 'abc123', mediaType: 'image/png' }, { nonsense: true }],
|
||||
providerMetadata: {
|
||||
testProvider: {
|
||||
images: [{ url: 'https://example.com/image.png' }]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
mockDownloadImageAsBase64.mockResolvedValue({
|
||||
data: 'url-base64',
|
||||
media_type: 'image/jpeg'
|
||||
})
|
||||
|
||||
const fileEntry = { id: 'file-1', origin: 'internal', ext: 'png', name: 'img', size: 3, createdAt: 0 }
|
||||
const createInternalEntry = vi.fn().mockResolvedValue(fileEntry)
|
||||
mockApplicationGet.mockImplementation((name: string) =>
|
||||
name === 'FileManager' ? { createInternalEntry } : undefined
|
||||
)
|
||||
|
||||
const result = await service.generateImage({
|
||||
uniqueModelId: 'test-provider::test-model',
|
||||
prompt: 'draw a cat',
|
||||
n: 2,
|
||||
size: '1024x1024',
|
||||
negativePrompt: 'blurry',
|
||||
seed: 7,
|
||||
quality: 'high',
|
||||
numInferenceSteps: 30,
|
||||
guidanceScale: 4.5,
|
||||
promptEnhancement: true,
|
||||
requestOptions: { signal: new AbortController().signal }
|
||||
})
|
||||
|
||||
expect(mockGenerateImage).toHaveBeenCalledWith(
|
||||
'test-provider',
|
||||
{},
|
||||
expect.objectContaining({
|
||||
model: 'test-model',
|
||||
prompt: 'draw a cat',
|
||||
n: 2,
|
||||
size: '1024x1024',
|
||||
negativePrompt: 'blurry',
|
||||
seed: 7,
|
||||
quality: 'high',
|
||||
numInferenceSteps: 30,
|
||||
guidanceScale: 4.5,
|
||||
promptEnhancement: true
|
||||
})
|
||||
)
|
||||
|
||||
const callOptions = mockGenerateImage.mock.calls[0]?.[2]
|
||||
expect(callOptions.experimental_download).toBeTypeOf('function')
|
||||
|
||||
const downloaded = await callOptions.experimental_download([
|
||||
{
|
||||
url: new URL('https://example.com/image.png'),
|
||||
isUrlSupportedByModel: false
|
||||
}
|
||||
])
|
||||
|
||||
expect(mockDownloadImageAsBase64).toHaveBeenCalledWith('https://example.com/image.png')
|
||||
expect(downloaded).toEqual([
|
||||
{
|
||||
data: Buffer.from('url-base64', 'base64'),
|
||||
mediaType: 'image/jpeg'
|
||||
}
|
||||
])
|
||||
|
||||
expect(createInternalEntry).toHaveBeenCalledWith({ source: 'base64', data: 'data:image/png;base64,abc123' })
|
||||
expect(result).toEqual({ files: [fileEntry] })
|
||||
})
|
||||
})
|
||||
|
||||
describe('AiService tool approval', () => {
|
||||
/** A fake renderer event whose `sender` satisfies `WebContentsListener`'s constructor. */
|
||||
function fakeEvent() {
|
||||
return {
|
||||
sender: {
|
||||
id: 1,
|
||||
once: vi.fn(),
|
||||
isDestroyed: () => false,
|
||||
send: vi.fn()
|
||||
}
|
||||
} as never
|
||||
}
|
||||
|
||||
/** A minimal `approval-requested` tool UI part (passes `isToolUIPart`). */
|
||||
function pendingToolPart(approvalId: string, toolName = 'mcp_write') {
|
||||
return {
|
||||
type: `tool-${toolName}`,
|
||||
toolCallId: `tc-${approvalId}`,
|
||||
state: 'approval-requested',
|
||||
input: {},
|
||||
approval: { id: approvalId }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Instantiate `AiService`, register its IPC handlers against the mocked
|
||||
* `ipcMain`, and return the captured `Ai_ToolApproval_Respond` listener.
|
||||
*/
|
||||
function getApprovalHandler() {
|
||||
const service = createService()
|
||||
;(service as unknown as { registerIpcHandlers(): void }).registerIpcHandlers()
|
||||
const call = vi
|
||||
.mocked(ipcMain.handle)
|
||||
.mock.calls.find(([channel]) => channel === IpcChannel.Ai_ToolApproval_Respond)
|
||||
if (!call) throw new Error('Ai_ToolApproval_Respond handler was not registered')
|
||||
return call[1] as (
|
||||
event: unknown,
|
||||
payload: {
|
||||
approvalId: string
|
||||
approved: boolean
|
||||
reason?: string
|
||||
updatedInput?: Record<string, unknown>
|
||||
topicId?: string
|
||||
anchorId?: string
|
||||
}
|
||||
) => Promise<{ ok: boolean }>
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('takes the Claude-Agent fast-path when the live registry dispatches the decision', async () => {
|
||||
const respondToolApproval = vi.fn(() => true)
|
||||
const dispatch = vi.fn()
|
||||
mockApplicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
|
||||
if (name === 'AiStreamManager') return { dispatch }
|
||||
return undefined
|
||||
})
|
||||
const getById = vi.spyOn(messageService, 'getById')
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
const result = await handler(fakeEvent(), {
|
||||
approvalId: 'agent-approval-1',
|
||||
approved: true
|
||||
})
|
||||
|
||||
expect(result).toEqual({ ok: true })
|
||||
expect(respondToolApproval).toHaveBeenCalledWith('agent-approval-1', {
|
||||
approved: true,
|
||||
reason: undefined,
|
||||
updatedInput: undefined
|
||||
})
|
||||
// Fast-path short-circuits before any DB read or continue dispatch.
|
||||
expect(getById).not.toHaveBeenCalled()
|
||||
expect(dispatch).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns { ok: false } when there is no live entry and no anchor context', async () => {
|
||||
const respondToolApproval = vi.fn(() => false)
|
||||
mockApplicationGet.mockImplementation((name: string) =>
|
||||
name === 'AgentSessionRuntimeService' ? { respondToolApproval } : undefined
|
||||
)
|
||||
const getById = vi.spyOn(messageService, 'getById')
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
const result = await handler(fakeEvent(), {
|
||||
approvalId: 'orphan-approval-1',
|
||||
approved: true
|
||||
// no topicId / anchorId
|
||||
})
|
||||
|
||||
expect(result).toEqual({ ok: false })
|
||||
expect(getById).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('persists the flipped parts and dispatches continue-conversation for an MCP approval present on the row', async () => {
|
||||
const respondToolApproval = vi.fn(() => false)
|
||||
const dispatch = vi.fn().mockResolvedValue(undefined)
|
||||
mockApplicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
|
||||
if (name === 'AiStreamManager') return { dispatch }
|
||||
return undefined
|
||||
})
|
||||
|
||||
const beforeParts = [{ type: 'text', text: 'hello' }, pendingToolPart('mcp-approval-1')]
|
||||
vi.spyOn(messageService, 'getById').mockResolvedValue({ data: { parts: beforeParts } } as never)
|
||||
const update = vi.spyOn(messageService, 'update').mockResolvedValue({} as never)
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
const result = await handler(fakeEvent(), {
|
||||
approvalId: 'mcp-approval-1',
|
||||
approved: true,
|
||||
topicId: 'topic-1',
|
||||
anchorId: 'anchor-1'
|
||||
})
|
||||
|
||||
expect(result).toEqual({ ok: true })
|
||||
// Target part was on the row → write the flipped parts.
|
||||
expect(update).toHaveBeenCalledTimes(1)
|
||||
const [updatedId, updateDto] = update.mock.calls[0]
|
||||
expect(updatedId).toBe('anchor-1')
|
||||
const writtenParts = (updateDto as { data: { parts: Array<{ state?: string }> } }).data.parts
|
||||
expect(writtenParts[1].state).toBe('approval-responded')
|
||||
// Nothing left pending → resume via continue-conversation.
|
||||
expect(dispatch).toHaveBeenCalledTimes(1)
|
||||
expect(dispatch).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
trigger: 'continue-conversation',
|
||||
topicId: 'topic-1',
|
||||
parentAnchorId: 'anchor-1',
|
||||
approvalDecisions: [{ approvalId: 'mcp-approval-1', approved: true }]
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('does not write parts when the approval is overlay-only (not present on the row) but still dispatches', async () => {
|
||||
const respondToolApproval = vi.fn(() => false)
|
||||
const dispatch = vi.fn().mockResolvedValue(undefined)
|
||||
mockApplicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
|
||||
if (name === 'AiStreamManager') return { dispatch }
|
||||
return undefined
|
||||
})
|
||||
|
||||
// Row carries no approval-requested part matching this approvalId.
|
||||
vi.spyOn(messageService, 'getById').mockResolvedValue({
|
||||
data: { parts: [{ type: 'text', text: 'hello' }] }
|
||||
} as never)
|
||||
const update = vi.spyOn(messageService, 'update').mockResolvedValue({} as never)
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
const result = await handler(fakeEvent(), {
|
||||
approvalId: 'mcp-approval-missing',
|
||||
approved: false,
|
||||
topicId: 'topic-1',
|
||||
anchorId: 'anchor-1'
|
||||
})
|
||||
|
||||
expect(result).toEqual({ ok: true })
|
||||
// Part absent on the row → no overwrite of the persisted parts...
|
||||
expect(update).not.toHaveBeenCalled()
|
||||
// ...but the decision still rides the continue dispatch idempotently.
|
||||
expect(dispatch).toHaveBeenCalledTimes(1)
|
||||
expect(dispatch).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
trigger: 'continue-conversation',
|
||||
approvalDecisions: [{ approvalId: 'mcp-approval-missing', approved: false }]
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('does not finalize while another approval on the turn is still pending', async () => {
|
||||
const respondToolApproval = vi.fn(() => false)
|
||||
const dispatch = vi.fn().mockResolvedValue(undefined)
|
||||
mockApplicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
|
||||
if (name === 'AiStreamManager') return { dispatch }
|
||||
return undefined
|
||||
})
|
||||
|
||||
// Two outstanding approvals on the same row; we only decide the first.
|
||||
const beforeParts = [pendingToolPart('mcp-approval-1'), pendingToolPart('mcp-approval-2', 'mcp_read')]
|
||||
vi.spyOn(messageService, 'getById').mockResolvedValue({ data: { parts: beforeParts } } as never)
|
||||
const update = vi.spyOn(messageService, 'update').mockResolvedValue({} as never)
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
const result = await handler(fakeEvent(), {
|
||||
approvalId: 'mcp-approval-1',
|
||||
approved: true,
|
||||
topicId: 'topic-1',
|
||||
anchorId: 'anchor-1'
|
||||
})
|
||||
|
||||
expect(result).toEqual({ ok: true })
|
||||
// The decided part is persisted...
|
||||
expect(update).toHaveBeenCalledTimes(1)
|
||||
// ...but the still-pending sibling gates the resume.
|
||||
expect(dispatch).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns { ok: false } when the anchor message is missing or deleted', async () => {
|
||||
const respondToolApproval = vi.fn(() => false)
|
||||
const dispatch = vi.fn().mockResolvedValue(undefined)
|
||||
mockApplicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
|
||||
if (name === 'AiStreamManager') return { dispatch }
|
||||
return undefined
|
||||
})
|
||||
|
||||
// A stale click on a deleted message: getById rejects.
|
||||
const getById = vi.spyOn(messageService, 'getById').mockRejectedValue(new Error('Message not found'))
|
||||
const update = vi.spyOn(messageService, 'update')
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
const result = await handler(fakeEvent(), {
|
||||
approvalId: 'mcp-approval-1',
|
||||
approved: true,
|
||||
topicId: 'topic-1',
|
||||
anchorId: 'deleted-anchor'
|
||||
})
|
||||
|
||||
// Resolves gracefully through the documented result shape instead of throwing.
|
||||
expect(result).toEqual({ ok: false })
|
||||
expect(getById).toHaveBeenCalledWith('deleted-anchor')
|
||||
expect(update).not.toHaveBeenCalled()
|
||||
expect(dispatch).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns { ok: false } when the IPC payload is invalid (rejected at the boundary)', async () => {
|
||||
const respondToolApproval = vi.fn(() => true)
|
||||
const dispatch = vi.fn()
|
||||
mockApplicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
|
||||
if (name === 'AiStreamManager') return { dispatch }
|
||||
return undefined
|
||||
})
|
||||
const getById = vi.spyOn(messageService, 'getById')
|
||||
|
||||
const handler = getApprovalHandler()
|
||||
// Missing `approved` boolean and empty `approvalId` → schema rejects.
|
||||
const result = await handler(fakeEvent(), { approvalId: '' } as never)
|
||||
|
||||
expect(result).toEqual({ ok: false })
|
||||
// Rejected before any registry dispatch or DB read.
|
||||
expect(respondToolApproval).not.toHaveBeenCalled()
|
||||
expect(getById).not.toHaveBeenCalled()
|
||||
expect(dispatch).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
12
src/main/ai/__tests__/fixtures/assistant.ts
Normal file
12
src/main/ai/__tests__/fixtures/assistant.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import type { Assistant, AssistantSettings } from '@shared/data/types/assistant'
|
||||
import { DEFAULT_ASSISTANT_SETTINGS } from '@shared/data/types/assistant'
|
||||
|
||||
export function makeAssistant(
|
||||
overrides: Partial<Omit<Assistant, 'settings'>> & { settings?: Partial<AssistantSettings> } = {}
|
||||
): Assistant {
|
||||
const { settings, ...rest } = overrides
|
||||
return {
|
||||
settings: { ...DEFAULT_ASSISTANT_SETTINGS, ...settings },
|
||||
...rest
|
||||
} as Assistant
|
||||
}
|
||||
3
src/main/ai/__tests__/fixtures/index.ts
Normal file
3
src/main/ai/__tests__/fixtures/index.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
export { makeAssistant } from './assistant'
|
||||
export { makeModel } from './model'
|
||||
export { makeEndpointConfig, makeProvider } from './provider'
|
||||
21
src/main/ai/__tests__/fixtures/model.ts
Normal file
21
src/main/ai/__tests__/fixtures/model.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { Model } from '@shared/data/types/model'
|
||||
|
||||
/**
|
||||
* Minimal valid Model fixture for main/ai tests.
|
||||
*
|
||||
* Defaults satisfy ModelSchema's required fields (id, providerId, name,
|
||||
* capabilities, supportsStreaming, isEnabled, isHidden). Pass overrides for
|
||||
* whatever the SUT actually reads.
|
||||
*/
|
||||
export function makeModel(overrides: Partial<Model> = {}): Model {
|
||||
return {
|
||||
id: 'openai::gpt-4',
|
||||
providerId: 'openai',
|
||||
name: 'GPT-4',
|
||||
capabilities: [],
|
||||
supportsStreaming: true,
|
||||
isEnabled: true,
|
||||
isHidden: false,
|
||||
...overrides
|
||||
} as Model
|
||||
}
|
||||
26
src/main/ai/__tests__/fixtures/provider.ts
Normal file
26
src/main/ai/__tests__/fixtures/provider.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import type { EndpointConfig, Provider } from '@shared/data/types/provider'
|
||||
import { DEFAULT_API_FEATURES, DEFAULT_PROVIDER_SETTINGS } from '@shared/data/types/provider'
|
||||
|
||||
/**
|
||||
* Minimal valid Provider fixture for main/ai tests.
|
||||
*
|
||||
* Defaults satisfy ProviderSchema's required fields (apiKeys, authType,
|
||||
* apiFeatures, settings, isEnabled). Pass overrides for whatever the SUT
|
||||
* actually reads.
|
||||
*/
|
||||
export function makeProvider(overrides: Partial<Provider> = {}): Provider {
|
||||
return {
|
||||
id: 'fake',
|
||||
name: 'Fake',
|
||||
apiKeys: [],
|
||||
authType: 'api-key',
|
||||
apiFeatures: { ...DEFAULT_API_FEATURES },
|
||||
settings: { ...DEFAULT_PROVIDER_SETTINGS },
|
||||
isEnabled: true,
|
||||
...overrides
|
||||
} as Provider
|
||||
}
|
||||
|
||||
export function makeEndpointConfig(overrides: Partial<EndpointConfig> = {}): EndpointConfig {
|
||||
return { ...overrides }
|
||||
}
|
||||
847
src/main/ai/agentSession/AgentSessionRuntimeService.ts
Normal file
847
src/main/ai/agentSession/AgentSessionRuntimeService.ts
Normal file
@@ -0,0 +1,847 @@
|
||||
import { agentService } from '@data/services/AgentService'
|
||||
import { agentSessionMessageService } from '@data/services/AgentSessionMessageService'
|
||||
import { loggerService } from '@logger'
|
||||
import { application } from '@main/core/application'
|
||||
import { BaseService, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
|
||||
import { topicNamingService } from '@main/services/TopicNamingService'
|
||||
import type { Span } from '@opentelemetry/api'
|
||||
import type { AgentEntity, AgentPermissionMode, UpdateAgentDto } from '@shared/data/api/schemas/agents'
|
||||
import type { AgentSessionMessageEntity } from '@shared/data/types/agent'
|
||||
import type { CherryUIMessage } from '@shared/data/types/message'
|
||||
import { parseUniqueModelId, type UniqueModelId } from '@shared/data/types/model'
|
||||
import { serializeError } from '@shared/types/error'
|
||||
import type { UIMessageChunk } from 'ai'
|
||||
import { v7 as uuidv7 } from 'uuid'
|
||||
|
||||
import { startAiTurnTrace } from '../observability'
|
||||
import {
|
||||
type AgentRuntimeConnection,
|
||||
type AgentRuntimeEvent,
|
||||
type AgentRuntimePolicyUpdate,
|
||||
type AgentRuntimeTraceContext,
|
||||
runtimeDriverRegistry
|
||||
} from '../runtime'
|
||||
import { type DispatchDecision, toolApprovalRegistry } from '../runtime/claudeCode/ToolApprovalRegistry'
|
||||
import { PersistenceListener } from '../streamManager/listeners/PersistenceListener'
|
||||
import { TraceFlushListener } from '../streamManager/listeners/TraceFlushListener'
|
||||
import type { StreamDoneResult, StreamErrorResult, StreamListener, StreamPausedResult } from '../streamManager/types'
|
||||
import { AgentSessionMessageBackend } from './persistence/AgentSessionMessageBackend'
|
||||
|
||||
const logger = loggerService.withContext('AgentSessionRuntimeService')
|
||||
const DEFAULT_IDLE_TTL_MS = 5 * 60 * 1000
|
||||
|
||||
export type AgentSessionRuntimeStatus = 'active' | 'idle'
|
||||
export type AgentSessionRuntimeTerminalStatus = 'success' | 'paused' | 'error'
|
||||
|
||||
/**
|
||||
* Why an in-flight turn is being stopped — selects how much runtime state to
|
||||
* tear down (see {@link STOP_POLICY}). Derived from the service's own typed
|
||||
* state (`turn.interruptRequested`), never from the abort signal's `reason`, so
|
||||
* a user Stop can't be misread as a steer interrupt. An abort with nothing
|
||||
* requested defaults to `user-stop` — the safe-failure direction (full teardown
|
||||
* closes the connection, killing the runtime/subagent).
|
||||
*/
|
||||
type AgentTurnStopIntent = 'interrupt' | 'user-stop'
|
||||
|
||||
interface TurnStopPolicy {
|
||||
/** Terminal status stamped on the turn when the session survives the stop. */
|
||||
turnStatus: AgentSessionRuntimeTerminalStatus
|
||||
/** Tear the whole session down (connection + entry), not just the turn. */
|
||||
closeSession: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* `interrupt` (steer): pause this turn but keep the connection + session so the
|
||||
* queued message can open the next turn (the runtime was gracefully interrupted,
|
||||
* not closed — the warm query survives). `user-stop`: tear the session down so
|
||||
* `connection.close()` kills the runtime query and its subagent.
|
||||
*/
|
||||
const STOP_POLICY: Record<AgentTurnStopIntent, TurnStopPolicy> = {
|
||||
interrupt: { turnStatus: 'paused', closeSession: false },
|
||||
'user-stop': { turnStatus: 'paused', closeSession: true }
|
||||
}
|
||||
|
||||
export interface BeginAgentSessionTurnInput {
|
||||
sessionId: string
|
||||
topicId: string
|
||||
agentId: string
|
||||
agentType: string
|
||||
modelId: UniqueModelId
|
||||
assistantMessageId?: string
|
||||
userMessage?: AgentSessionMessageEntity
|
||||
traceId?: string
|
||||
rootSpanId?: string
|
||||
}
|
||||
|
||||
export interface AgentSessionRuntimeHandle {
|
||||
listeners: StreamListener[]
|
||||
turnId: string
|
||||
}
|
||||
|
||||
export interface OpenAgentSessionTurnStreamInput {
|
||||
sessionId: string
|
||||
turnId: string
|
||||
signal: AbortSignal
|
||||
}
|
||||
|
||||
export interface AgentSessionRuntimeSnapshot {
|
||||
sessionId: string
|
||||
topicId?: string
|
||||
assistantMessageId?: string
|
||||
status: AgentSessionRuntimeStatus
|
||||
pendingMessageCount: number
|
||||
lastTerminalStatus?: AgentSessionRuntimeTerminalStatus
|
||||
resumeToken?: string
|
||||
activeToolCount: number
|
||||
interruptRequested: boolean
|
||||
}
|
||||
|
||||
type AgentSessionTurn = {
|
||||
turnId: string
|
||||
assistantMessageId?: string
|
||||
userMessage: AgentSessionMessageEntity
|
||||
modelId: UniqueModelId
|
||||
admitted: boolean
|
||||
terminalStatus?: AgentSessionRuntimeTerminalStatus
|
||||
controller?: ReadableStreamDefaultController<UIMessageChunk>
|
||||
activeToolIds: Set<string>
|
||||
interruptRequested: boolean
|
||||
trace?: AgentRuntimeTraceContext
|
||||
}
|
||||
|
||||
type AgentSessionRuntimeEntry = {
|
||||
sessionId: string
|
||||
topicId: string
|
||||
agentId: string
|
||||
agentType: string
|
||||
modelId: UniqueModelId
|
||||
status: AgentSessionRuntimeStatus
|
||||
pendingTurns: AgentSessionMessageEntity[]
|
||||
connection?: AgentRuntimeConnection
|
||||
connectionLoop?: Promise<void>
|
||||
/** In-flight {@link ensureConnection} promise — shared by concurrent callers so only one connect runs. */
|
||||
connecting?: Promise<boolean>
|
||||
currentTurn?: AgentSessionTurn
|
||||
lastResumeToken?: string
|
||||
lastTerminalStatus?: AgentSessionRuntimeTerminalStatus
|
||||
idleTimer?: ReturnType<typeof setTimeout>
|
||||
startingNextTurn?: boolean
|
||||
}
|
||||
|
||||
class AgentSessionRuntimeTerminalListener implements StreamListener {
|
||||
readonly id: string
|
||||
|
||||
constructor(
|
||||
private readonly service: AgentSessionRuntimeService,
|
||||
private readonly sessionId: string
|
||||
) {
|
||||
this.id = `agent-runtime:${sessionId}`
|
||||
}
|
||||
|
||||
onChunk(): void {}
|
||||
|
||||
onDone(result: StreamDoneResult): void {
|
||||
if (result.isTopicDone === false) return
|
||||
this.service.markTurnTerminal(this.sessionId, 'success')
|
||||
}
|
||||
|
||||
onPaused(result: StreamPausedResult): void {
|
||||
if (result.isTopicDone === false) return
|
||||
this.service.markTurnTerminal(this.sessionId, 'paused')
|
||||
}
|
||||
|
||||
onError(result: StreamErrorResult): void {
|
||||
if (result.isTopicDone === false) return
|
||||
this.service.markTurnTerminal(this.sessionId, 'error')
|
||||
}
|
||||
|
||||
isAlive(): boolean {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@Injectable('AgentSessionRuntimeService')
|
||||
@ServicePhase(Phase.WhenReady)
|
||||
export class AgentSessionRuntimeService extends BaseService {
|
||||
private readonly entries = new Map<string, AgentSessionRuntimeEntry>()
|
||||
|
||||
protected async onInit(): Promise<void> {
|
||||
// Resolve agent-session assistant rows a prior main-process crash left `pending` — at boot the
|
||||
// in-memory entry map is empty, so every such row is stale. Mirrors AiStreamManager's chat
|
||||
// reconcile so both message tables are settled on restart (neither stays a frozen "thinking"
|
||||
// bubble); agent sessions additionally recover conversation context via the resume token.
|
||||
await this.reconcileStalePendingMessages()
|
||||
|
||||
this.registerDisposable(
|
||||
agentService.onAgentUpdated(({ agentId, updates, agent }) => {
|
||||
void this.handleAgentUpdated(agentId, updates, agent).catch((error) => {
|
||||
logger.warn('Failed to apply live agent policy update', { agentId, error })
|
||||
})
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
private async reconcileStalePendingMessages(): Promise<void> {
|
||||
try {
|
||||
const staleIds = await agentSessionMessageService.findPendingAssistantMessageIds()
|
||||
if (staleIds.length === 0) return
|
||||
logger.info('Reconciling crash-orphaned pending agent-session messages', { count: staleIds.length })
|
||||
await agentSessionMessageService.markMessagesError(staleIds)
|
||||
} catch (error) {
|
||||
logger.error('Failed to reconcile stale pending agent-session messages', { error })
|
||||
}
|
||||
}
|
||||
|
||||
beginTurn(input: BeginAgentSessionTurnInput): AgentSessionRuntimeHandle {
|
||||
const turnId = crypto.randomUUID()
|
||||
const userMessage = input.userMessage ?? createSyntheticUserMessage(input.sessionId)
|
||||
const existing = this.entries.get(input.sessionId)
|
||||
const turn: AgentSessionTurn = {
|
||||
turnId,
|
||||
assistantMessageId: input.assistantMessageId,
|
||||
userMessage,
|
||||
modelId: input.modelId,
|
||||
admitted: false,
|
||||
activeToolIds: new Set(),
|
||||
interruptRequested: false,
|
||||
trace: this.createTraceContext(input, turnId, input.traceId, input.rootSpanId)
|
||||
}
|
||||
|
||||
if (existing?.status === 'idle') {
|
||||
this.clearIdleTimer(existing)
|
||||
existing.pendingTurns = []
|
||||
existing.topicId = input.topicId
|
||||
existing.agentId = input.agentId
|
||||
existing.agentType = input.agentType
|
||||
existing.modelId = input.modelId
|
||||
existing.status = 'active'
|
||||
existing.currentTurn = turn
|
||||
|
||||
return {
|
||||
listeners: [
|
||||
this.createPersistenceListener(existing, userMessage),
|
||||
new AgentSessionRuntimeTerminalListener(this, input.sessionId),
|
||||
new TraceFlushListener(input.topicId)
|
||||
],
|
||||
turnId
|
||||
}
|
||||
}
|
||||
|
||||
if (existing) this.closeSession(input.sessionId)
|
||||
|
||||
const entry: AgentSessionRuntimeEntry = {
|
||||
sessionId: input.sessionId,
|
||||
topicId: input.topicId,
|
||||
agentId: input.agentId,
|
||||
agentType: input.agentType,
|
||||
modelId: input.modelId,
|
||||
status: 'active',
|
||||
pendingTurns: [],
|
||||
currentTurn: turn
|
||||
}
|
||||
this.entries.set(input.sessionId, entry)
|
||||
|
||||
return {
|
||||
listeners: [
|
||||
this.createPersistenceListener(entry, userMessage),
|
||||
new AgentSessionRuntimeTerminalListener(this, input.sessionId),
|
||||
new TraceFlushListener(input.topicId)
|
||||
],
|
||||
turnId
|
||||
}
|
||||
}
|
||||
|
||||
async applyAgentPolicyUpdate(agentId: string, update: AgentRuntimePolicyUpdate): Promise<void> {
|
||||
const updates: Array<Promise<boolean> | boolean> = []
|
||||
for (const entry of this.entries.values()) {
|
||||
if (entry.agentId !== agentId || !entry.connection?.applyPolicyUpdate) continue
|
||||
updates.push(entry.connection.applyPolicyUpdate(update))
|
||||
}
|
||||
await Promise.allSettled(updates)
|
||||
}
|
||||
|
||||
private async handleAgentUpdated(agentId: string, updates: UpdateAgentDto, agent: AgentEntity): Promise<void> {
|
||||
const configuration = updates.configuration as { permission_mode?: unknown } | undefined
|
||||
const hasPermissionModeUpdate =
|
||||
configuration !== undefined && Object.prototype.hasOwnProperty.call(configuration, 'permission_mode')
|
||||
|
||||
if (hasPermissionModeUpdate) {
|
||||
await this.applyAgentPolicyUpdate(agentId, {
|
||||
type: 'permission-mode',
|
||||
permissionMode: configuration.permission_mode as AgentPermissionMode | undefined
|
||||
})
|
||||
}
|
||||
|
||||
if (
|
||||
Object.prototype.hasOwnProperty.call(updates, 'allowedTools') ||
|
||||
Object.prototype.hasOwnProperty.call(updates, 'mcps')
|
||||
) {
|
||||
await this.applyAgentPolicyUpdate(agentId, { type: 'tool-policy', agent })
|
||||
}
|
||||
}
|
||||
|
||||
openTurnStream(input: OpenAgentSessionTurnStreamInput): ReadableStream<UIMessageChunk> {
|
||||
const entry = this.entries.get(input.sessionId)
|
||||
const turn = entry?.currentTurn
|
||||
if (!entry || !turn || turn.turnId !== input.turnId) {
|
||||
throw new Error(`No active agent runtime turn ${input.turnId} for session ${input.sessionId}`)
|
||||
}
|
||||
|
||||
return new ReadableStream<UIMessageChunk>({
|
||||
start: async (controller) => {
|
||||
try {
|
||||
this.clearIdleTimer(entry)
|
||||
turn.controller = controller
|
||||
|
||||
const onAbort = () => this.stopTurn(entry, turn.interruptRequested ? 'interrupt' : 'user-stop')
|
||||
if (input.signal.aborted) {
|
||||
onAbort()
|
||||
return
|
||||
} else {
|
||||
input.signal.addEventListener('abort', onAbort, { once: true })
|
||||
}
|
||||
|
||||
controller.enqueue({ type: 'start' })
|
||||
const connected = await this.ensureConnection(entry)
|
||||
if (!connected || !this.isCurrentEntry(entry) || turn.terminalStatus) return
|
||||
await this.admitTurn(entry, turn)
|
||||
} catch (error) {
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
cancel: () => {
|
||||
this.closeCurrentTurn(entry, 'paused')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
enqueueUserMessage(sessionId: string, message: AgentSessionMessageEntity): void {
|
||||
const entry = this.entries.get(sessionId)
|
||||
if (!entry) return
|
||||
|
||||
entry.pendingTurns.push(message)
|
||||
entry.status = 'active'
|
||||
this.clearIdleTimer(entry)
|
||||
|
||||
const turn = entry.currentTurn
|
||||
if (!turn || turn.terminalStatus) {
|
||||
this.scheduleNextTurn(entry)
|
||||
return
|
||||
}
|
||||
|
||||
if (turn.activeToolIds.size > 0) return
|
||||
|
||||
queueMicrotask(() => {
|
||||
const latest = this.entries.get(sessionId)
|
||||
if (!latest?.currentTurn || latest.currentTurn.terminalStatus) {
|
||||
if (latest) this.scheduleNextTurn(latest)
|
||||
return
|
||||
}
|
||||
this.requestInterruptWhenSafe(latest)
|
||||
})
|
||||
}
|
||||
|
||||
markTurnTerminal(sessionId: string, status: AgentSessionRuntimeTerminalStatus): void {
|
||||
const entry = this.entries.get(sessionId)
|
||||
if (!entry) return
|
||||
|
||||
entry.status = 'idle'
|
||||
entry.lastTerminalStatus = status
|
||||
if (entry.currentTurn) entry.currentTurn.terminalStatus = status
|
||||
|
||||
if (this.shouldCloseConnectionAfterTurn(entry)) {
|
||||
// close() may be async on some drivers; swallow rejection so it can't become unhandled.
|
||||
void Promise.resolve(this.closeConnection(entry)?.close()).catch((error) =>
|
||||
logger.warn('Agent runtime connection close failed', { sessionId: entry.sessionId, error })
|
||||
)
|
||||
}
|
||||
|
||||
if (entry.pendingTurns.length > 0) {
|
||||
this.scheduleNextTurn(entry)
|
||||
} else {
|
||||
this.refreshIdleTimer(entry)
|
||||
}
|
||||
}
|
||||
|
||||
closeSession(sessionId: string): void {
|
||||
const entry = this.entries.get(sessionId)
|
||||
if (!entry) return
|
||||
this.entries.delete(sessionId)
|
||||
this.closeEntry(entry)
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the session has a turn in flight or about to start: a non-terminal current turn,
|
||||
* a next-turn drain in progress (`startingNextTurn`), or queued follow-ups. The dispatcher
|
||||
* uses this — NOT `AiStreamManager.hasLiveStream` — to decide enqueue-vs-begin, because
|
||||
* `hasLiveStream` is false during the inter-turn drain window while the entry is still
|
||||
* mid-transition; a fresh dispatch trusting `hasLiveStream` there would clobber the drain via
|
||||
* `beginTurn`.
|
||||
*/
|
||||
isSessionBusy(sessionId: string): boolean {
|
||||
const entry = this.entries.get(sessionId)
|
||||
if (!entry) return false
|
||||
return (
|
||||
entry.startingNextTurn === true ||
|
||||
entry.pendingTurns.length > 0 ||
|
||||
(entry.currentTurn !== undefined && entry.currentTurn.terminalStatus === undefined)
|
||||
)
|
||||
}
|
||||
|
||||
inspect(sessionId: string): AgentSessionRuntimeSnapshot | undefined {
|
||||
const entry = this.entries.get(sessionId)
|
||||
if (!entry) return undefined
|
||||
const turn = entry.currentTurn
|
||||
|
||||
return {
|
||||
sessionId: entry.sessionId,
|
||||
topicId: entry.topicId,
|
||||
assistantMessageId: turn?.assistantMessageId,
|
||||
status: entry.status,
|
||||
pendingMessageCount: entry.pendingTurns.length,
|
||||
lastTerminalStatus: entry.lastTerminalStatus,
|
||||
resumeToken: entry.lastResumeToken,
|
||||
activeToolCount: turn?.activeToolIds.size ?? 0,
|
||||
interruptRequested: turn?.interruptRequested ?? false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve a Claude `canUseTool` approval that was registered against the live
|
||||
* driver session. Returns `false` if no live entry matches — the caller
|
||||
* falls back to MCP/DB path.
|
||||
*/
|
||||
respondToolApproval(approvalId: string, decision: DispatchDecision): boolean {
|
||||
return toolApprovalRegistry.dispatch(approvalId, decision)
|
||||
}
|
||||
|
||||
protected onStop(): void {
|
||||
this.closeAll()
|
||||
toolApprovalRegistry.clear('agent-session-runtime-stop')
|
||||
}
|
||||
|
||||
protected onDestroy(): void {
|
||||
this.closeAll()
|
||||
toolApprovalRegistry.clear('agent-session-runtime-destroy')
|
||||
}
|
||||
|
||||
private isCurrentEntry(entry: AgentSessionRuntimeEntry): boolean {
|
||||
return this.entries.get(entry.sessionId) === entry
|
||||
}
|
||||
|
||||
private async ensureConnection(entry: AgentSessionRuntimeEntry): Promise<boolean> {
|
||||
if (!this.isCurrentEntry(entry)) return false
|
||||
if (entry.connection) return true
|
||||
// Share a single in-flight connect across concurrent callers so two streams opening at once
|
||||
// can't each spin up a connection (the second would leak/clobber the first).
|
||||
if (entry.connecting) return entry.connecting
|
||||
|
||||
const connecting = this.connect(entry).finally(() => {
|
||||
if (entry.connecting === connecting) entry.connecting = undefined
|
||||
})
|
||||
entry.connecting = connecting
|
||||
return connecting
|
||||
}
|
||||
|
||||
private async connect(entry: AgentSessionRuntimeEntry): Promise<boolean> {
|
||||
const driver = runtimeDriverRegistry.getAgentSessionDriver(entry.agentType)
|
||||
if (!driver) throw new Error(`Unsupported agent runtime type: ${entry.agentType}`)
|
||||
|
||||
await this.hydrateResumeToken(entry)
|
||||
if (!this.isCurrentEntry(entry)) return false
|
||||
|
||||
const connection = await driver.connect({
|
||||
sessionId: entry.sessionId,
|
||||
agentId: entry.agentId,
|
||||
modelId: entry.modelId,
|
||||
resumeToken: entry.lastResumeToken,
|
||||
trace: entry.currentTurn?.trace
|
||||
})
|
||||
if (!this.isCurrentEntry(entry)) {
|
||||
void Promise.resolve(connection.close()).catch((error) =>
|
||||
logger.warn('Agent runtime connection close failed', { sessionId: entry.sessionId, error })
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
entry.connection = connection
|
||||
entry.connectionLoop = this.runConnectionLoop(entry, connection).finally(() => {
|
||||
if (entry.connection === connection) entry.connection = undefined
|
||||
if (entry.connectionLoop) entry.connectionLoop = undefined
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
private async hydrateResumeToken(entry: AgentSessionRuntimeEntry): Promise<void> {
|
||||
if (entry.lastResumeToken) return
|
||||
const runtimeResumeToken = await agentSessionMessageService.getLastRuntimeResumeToken(entry.sessionId)
|
||||
if (runtimeResumeToken) entry.lastResumeToken = runtimeResumeToken
|
||||
}
|
||||
|
||||
private async runConnectionLoop(entry: AgentSessionRuntimeEntry, connection: AgentRuntimeConnection): Promise<void> {
|
||||
try {
|
||||
for await (const event of connection.events) {
|
||||
this.handleRuntimeEvent(entry, event)
|
||||
}
|
||||
} catch (error) {
|
||||
this.handleRuntimeError(entry, error)
|
||||
}
|
||||
}
|
||||
|
||||
private handleRuntimeEvent(entry: AgentSessionRuntimeEntry, event: AgentRuntimeEvent): void {
|
||||
switch (event.type) {
|
||||
case 'resume-token':
|
||||
entry.lastResumeToken = event.token
|
||||
break
|
||||
case 'chunk': {
|
||||
const turn = entry.currentTurn
|
||||
if (turn?.controller && !turn.terminalStatus) this.enqueueTurnChunk(entry, turn, event.chunk)
|
||||
break
|
||||
}
|
||||
case 'turn-complete':
|
||||
this.closeCurrentTurn(entry, 'success')
|
||||
break
|
||||
case 'error':
|
||||
this.handleRuntimeError(entry, event.error)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
private handleRuntimeError(entry: AgentSessionRuntimeEntry, error: unknown): void {
|
||||
const turn = entry.currentTurn
|
||||
if (turn?.controller && !turn.terminalStatus) {
|
||||
turn.controller.error(error)
|
||||
// Mark terminal synchronously: the listener's markTurnTerminal arrives async (after the
|
||||
// stream error propagates), so a trailing `chunk` event in the same connection loop would
|
||||
// otherwise hit enqueueTurnChunk and throw on the now-errored controller.
|
||||
turn.terminalStatus = 'error'
|
||||
} else if (isAbortError(error)) {
|
||||
// Expected when a turn was interrupted/closed — the connection ending is not a fault.
|
||||
logger.warn('Agent runtime connection ended without an active turn', { sessionId: entry.sessionId, error })
|
||||
} else {
|
||||
// No turn to surface this on, so a real runtime failure would otherwise vanish — log it loudly
|
||||
// so the next reconnect-into-the-same-failure is at least traceable.
|
||||
logger.error('Agent runtime connection ended without an active turn', { sessionId: entry.sessionId, error })
|
||||
}
|
||||
}
|
||||
|
||||
private async admitTurn(entry: AgentSessionRuntimeEntry, turn: AgentSessionTurn): Promise<void> {
|
||||
if (!this.isCurrentEntry(entry) || entry.currentTurn !== turn || turn.terminalStatus) return
|
||||
if (turn.admitted) return
|
||||
turn.admitted = true
|
||||
entry.status = 'active'
|
||||
await entry.connection?.send({ message: turn.userMessage })
|
||||
if (entry.pendingTurns.length > 0) {
|
||||
queueMicrotask(() => this.requestInterruptWhenSafe(entry))
|
||||
}
|
||||
}
|
||||
|
||||
private enqueueTurnChunk(entry: AgentSessionRuntimeEntry, turn: AgentSessionTurn, chunk: UIMessageChunk): void {
|
||||
const toolChunk = chunk as { type?: string; toolCallId?: string }
|
||||
if ((toolChunk.type === 'tool-input-start' || toolChunk.type === 'tool-input-available') && toolChunk.toolCallId) {
|
||||
turn.activeToolIds.add(toolChunk.toolCallId)
|
||||
} else if (
|
||||
(toolChunk.type === 'tool-output-available' ||
|
||||
toolChunk.type === 'tool-output-error' ||
|
||||
toolChunk.type === 'tool-output-denied') &&
|
||||
toolChunk.toolCallId
|
||||
) {
|
||||
turn.activeToolIds.delete(toolChunk.toolCallId)
|
||||
}
|
||||
|
||||
turn.controller?.enqueue(chunk)
|
||||
|
||||
if (turn.activeToolIds.size === 0 && entry.pendingTurns.length > 0) this.requestInterruptWhenSafe(entry)
|
||||
}
|
||||
|
||||
private requestInterruptWhenSafe(entry: AgentSessionRuntimeEntry): void {
|
||||
const turn = entry.currentTurn
|
||||
if (!turn || turn.terminalStatus || !turn.admitted || turn.interruptRequested) return
|
||||
const canInterrupt = entry.connection?.canInterruptNow?.() ?? turn.activeToolIds.size === 0
|
||||
if (!canInterrupt) return
|
||||
turn.interruptRequested = true
|
||||
this.interruptCurrentTurn(entry)
|
||||
}
|
||||
|
||||
private interruptCurrentTurn(entry: AgentSessionRuntimeEntry): void {
|
||||
const turn = entry.currentTurn
|
||||
if (!turn || turn.terminalStatus) return
|
||||
void entry.connection?.interrupt?.().catch((error) => {
|
||||
logger.warn('Agent runtime interrupt failed', { sessionId: entry.sessionId, error })
|
||||
})
|
||||
application.get('AiStreamManager').pauseRuntimeTurn(entry.topicId, 'agent-runtime-interrupt')
|
||||
}
|
||||
|
||||
private stopTurn(entry: AgentSessionRuntimeEntry, intent: AgentTurnStopIntent): void {
|
||||
const policy = STOP_POLICY[intent]
|
||||
if (policy.closeSession) {
|
||||
this.closeSession(entry.sessionId)
|
||||
return
|
||||
}
|
||||
this.closeCurrentTurn(entry, policy.turnStatus)
|
||||
}
|
||||
|
||||
private closeCurrentTurn(entry: AgentSessionRuntimeEntry, status: AgentSessionRuntimeTerminalStatus): void {
|
||||
const turn = entry.currentTurn
|
||||
if (!turn || turn.terminalStatus) return
|
||||
turn.terminalStatus = status
|
||||
try {
|
||||
turn.controller?.close()
|
||||
} catch {
|
||||
// Already closed by the stream reader.
|
||||
}
|
||||
turn.controller = undefined
|
||||
turn.activeToolIds.clear()
|
||||
}
|
||||
|
||||
private scheduleNextTurn(entry: AgentSessionRuntimeEntry): void {
|
||||
if (entry.startingNextTurn) return
|
||||
entry.startingNextTurn = true
|
||||
// Keep `startingNextTurn` set for the WHOLE drain — `startNextTurn` spans a DB round-trip,
|
||||
// and `isSessionBusy` relies on this flag so a concurrent dispatch landing in the inter-turn
|
||||
// window enqueues instead of beginning a clobbering fresh turn. Clear it only once the drain
|
||||
// settles (turn established, bailed, or errored).
|
||||
queueMicrotask(() => {
|
||||
void this.startNextTurn(entry)
|
||||
.catch((error) => {
|
||||
logger.error('Failed to start next agent runtime turn', { sessionId: entry.sessionId, error })
|
||||
})
|
||||
.finally(() => {
|
||||
entry.startingNextTurn = false
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
private async startNextTurn(entry: AgentSessionRuntimeEntry): Promise<void> {
|
||||
const nextMessage = entry.pendingTurns.shift()
|
||||
if (!nextMessage) {
|
||||
this.refreshIdleTimer(entry)
|
||||
return
|
||||
}
|
||||
|
||||
const { rootSpan, traceId, rootSpanId } = this.startRuntimeRootSpan(entry)
|
||||
let assistantMessage: Awaited<ReturnType<typeof agentSessionMessageService.saveMessage>>
|
||||
try {
|
||||
assistantMessage = await agentSessionMessageService.saveMessage({
|
||||
sessionId: entry.sessionId,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
status: 'pending',
|
||||
data: { parts: [] },
|
||||
modelId: entry.modelId,
|
||||
traceId
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
// The placeholder save failed, so there is no assistant row to drive to `error` and no
|
||||
// point re-queuing the message — the retry would just fail the same way, and a re-queued
|
||||
// message is silently cleared by the idle TTL anyway. Instead surface the failure to the
|
||||
// live renderer and settle the turn so the session doesn't sit idle on a doomed message.
|
||||
rootSpan.end()
|
||||
application.get('AiStreamManager').broadcastTopicError(entry.topicId, entry.modelId, serializeError(error))
|
||||
this.markTurnTerminal(entry.sessionId, 'error')
|
||||
return
|
||||
}
|
||||
|
||||
// The DB save above yields the event loop; the session may have been torn down
|
||||
// (shutdown / a fresh beginTurn) in the meantime. Re-check before mutating the entry,
|
||||
// mirroring every other async method here — otherwise a dead entry gets resurrected
|
||||
// into a doomed runtime turn with no backing agent connection.
|
||||
if (!this.isCurrentEntry(entry)) {
|
||||
rootSpan.end()
|
||||
return
|
||||
}
|
||||
|
||||
const assistantMessageId = assistantMessage.id
|
||||
|
||||
const turnId = crypto.randomUUID()
|
||||
entry.currentTurn = {
|
||||
turnId,
|
||||
assistantMessageId,
|
||||
userMessage: nextMessage,
|
||||
modelId: entry.modelId,
|
||||
admitted: false,
|
||||
activeToolIds: new Set(),
|
||||
interruptRequested: false,
|
||||
trace: this.createTraceContext(entry, turnId, traceId, rootSpanId)
|
||||
}
|
||||
|
||||
application.get('AiStreamManager').startRuntimeTurn({
|
||||
topicId: entry.topicId,
|
||||
modelId: entry.modelId,
|
||||
rootSpan,
|
||||
request: {
|
||||
chatId: entry.topicId,
|
||||
trigger: 'submit-message',
|
||||
messageId: assistantMessageId,
|
||||
messages: createRuntimeSeedMessages(nextMessage, assistantMessageId),
|
||||
runtime: { kind: 'agent-session', sessionId: entry.sessionId, turnId }
|
||||
},
|
||||
listeners: [
|
||||
this.createPersistenceListener(entry, nextMessage),
|
||||
new AgentSessionRuntimeTerminalListener(this, entry.sessionId),
|
||||
new TraceFlushListener(entry.topicId)
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
private startRuntimeRootSpan(entry: AgentSessionRuntimeEntry): {
|
||||
rootSpan: Span
|
||||
traceId: string
|
||||
rootSpanId: string
|
||||
} {
|
||||
const turnTrace = startAiTurnTrace(
|
||||
'chat.turn',
|
||||
{
|
||||
attributes: {
|
||||
'cs.topic_id': entry.topicId,
|
||||
'cs.trigger': 'submit-message',
|
||||
'cs.model_id': entry.modelId,
|
||||
'cs.role': 'assistant',
|
||||
'cs.agent_id': entry.agentId,
|
||||
'cs.session_id': entry.sessionId
|
||||
}
|
||||
},
|
||||
{ topicId: entry.topicId, modelName: parseUniqueModelId(entry.modelId).modelId }
|
||||
)
|
||||
return { rootSpan: turnTrace.rootSpan, traceId: turnTrace.traceId, rootSpanId: turnTrace.rootSpanId }
|
||||
}
|
||||
|
||||
private createTraceContext(
|
||||
input: Pick<BeginAgentSessionTurnInput, 'topicId' | 'sessionId' | 'modelId'>,
|
||||
turnId: string,
|
||||
traceId?: string,
|
||||
rootSpanId?: string
|
||||
): AgentRuntimeTraceContext | undefined {
|
||||
if (!traceId || !rootSpanId) return undefined
|
||||
return {
|
||||
topicId: input.topicId,
|
||||
traceId,
|
||||
rootSpanId,
|
||||
sessionId: input.sessionId,
|
||||
turnId,
|
||||
modelName: parseUniqueModelId(input.modelId).modelId
|
||||
}
|
||||
}
|
||||
|
||||
private shouldCloseConnectionAfterTurn(entry: AgentSessionRuntimeEntry): boolean {
|
||||
return entry.connection?.shouldCloseAfterTurn?.() ?? false
|
||||
}
|
||||
|
||||
private createPersistenceListener(
|
||||
entry: AgentSessionRuntimeEntry,
|
||||
userMessage: AgentSessionMessageEntity
|
||||
): StreamListener {
|
||||
const userText = extractMessageText(userMessage)
|
||||
return new PersistenceListener({
|
||||
topicId: entry.topicId,
|
||||
modelId: entry.modelId,
|
||||
backend: new AgentSessionMessageBackend({
|
||||
sessionId: entry.sessionId,
|
||||
modelId: entry.modelId,
|
||||
runtimeResumeToken: () => entry.lastResumeToken,
|
||||
afterPersist: async (finalMessage) => {
|
||||
await topicNamingService.maybeRenameAgentSession(entry.agentId, entry.sessionId, userText, finalMessage)
|
||||
}
|
||||
}),
|
||||
onPersistFailed: (error) =>
|
||||
application.get('AiStreamManager').broadcastTopicError(entry.topicId, entry.modelId, error)
|
||||
})
|
||||
}
|
||||
|
||||
private refreshIdleTimer(entry: AgentSessionRuntimeEntry): void {
|
||||
this.clearIdleTimer(entry)
|
||||
entry.idleTimer = setTimeout(() => {
|
||||
const { sessionId, agentType, lastResumeToken } = entry
|
||||
this.closeSession(sessionId)
|
||||
if (lastResumeToken) {
|
||||
runtimeDriverRegistry.getAgentSessionDriver(agentType)?.onSessionIdle?.(sessionId)
|
||||
}
|
||||
}, DEFAULT_IDLE_TTL_MS)
|
||||
entry.idleTimer.unref?.()
|
||||
}
|
||||
|
||||
private clearIdleTimer(entry: AgentSessionRuntimeEntry): void {
|
||||
if (entry.idleTimer) {
|
||||
clearTimeout(entry.idleTimer)
|
||||
entry.idleTimer = undefined
|
||||
}
|
||||
}
|
||||
|
||||
private closeAll(): void {
|
||||
for (const sessionId of [...this.entries.keys()]) {
|
||||
this.closeSession(sessionId)
|
||||
}
|
||||
}
|
||||
|
||||
private closeEntry(entry: AgentSessionRuntimeEntry): void {
|
||||
this.clearIdleTimer(entry)
|
||||
this.closeCurrentTurn(entry, 'paused')
|
||||
entry.pendingTurns = []
|
||||
|
||||
const connection = this.closeConnection(entry)
|
||||
entry.currentTurn = undefined
|
||||
entry.startingNextTurn = false
|
||||
|
||||
void Promise.resolve(connection?.close()).catch((error) =>
|
||||
logger.warn('Agent runtime connection close failed', { sessionId: entry.sessionId, error })
|
||||
)
|
||||
}
|
||||
|
||||
private closeConnection(entry: AgentSessionRuntimeEntry): AgentRuntimeConnection | undefined {
|
||||
const connection = entry.connection
|
||||
entry.connection = undefined
|
||||
entry.connectionLoop = undefined
|
||||
return connection
|
||||
}
|
||||
}
|
||||
|
||||
function isAbortError(error: unknown): boolean {
|
||||
return !!error && typeof error === 'object' && 'name' in error && (error as { name: unknown }).name === 'AbortError'
|
||||
}
|
||||
|
||||
function createRuntimeSeedMessages(
|
||||
userMessage: AgentSessionMessageEntity,
|
||||
assistantMessageId: string
|
||||
): CherryUIMessage[] {
|
||||
return [
|
||||
{
|
||||
id: userMessage.id,
|
||||
role: 'user',
|
||||
parts: userMessage.data?.parts ?? []
|
||||
},
|
||||
{
|
||||
id: assistantMessageId,
|
||||
role: 'assistant',
|
||||
parts: []
|
||||
}
|
||||
] as CherryUIMessage[]
|
||||
}
|
||||
|
||||
function createSyntheticUserMessage(sessionId: string): AgentSessionMessageEntity {
|
||||
const now = new Date().toISOString()
|
||||
return {
|
||||
id: uuidv7(),
|
||||
sessionId,
|
||||
role: 'user',
|
||||
data: { parts: [] },
|
||||
status: 'success',
|
||||
searchableText: '',
|
||||
modelId: null,
|
||||
modelSnapshot: null,
|
||||
traceId: null,
|
||||
stats: null,
|
||||
runtimeResumeToken: null,
|
||||
createdAt: now,
|
||||
updatedAt: now
|
||||
}
|
||||
}
|
||||
|
||||
function extractMessageText(message: AgentSessionMessageEntity): string {
|
||||
return (
|
||||
message.data?.parts
|
||||
?.filter((part): part is { type: 'text'; text: string } => part.type === 'text' && 'text' in part)
|
||||
.map((part) => part.text)
|
||||
.join('\n') ?? ''
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,940 @@
|
||||
import { BaseService } from '@main/core/lifecycle/BaseService'
|
||||
import { mockMainLoggerService } from '@test-mocks/MainLoggerService'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
saveMessage: vi.fn(),
|
||||
getLastRuntimeResumeToken: vi.fn(),
|
||||
findPendingAssistantMessageIds: vi.fn(),
|
||||
markMessagesError: vi.fn(),
|
||||
maybeRenameAgentSession: vi.fn(),
|
||||
applicationGet: vi.fn(),
|
||||
startRuntimeTurn: vi.fn(),
|
||||
pauseRuntimeTurn: vi.fn(),
|
||||
broadcastTopicError: vi.fn(),
|
||||
spanCacheSetTopicId: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@data/services/AgentSessionMessageService', () => ({
|
||||
agentSessionMessageService: {
|
||||
saveMessage: mocks.saveMessage,
|
||||
getLastRuntimeResumeToken: mocks.getLastRuntimeResumeToken,
|
||||
findPendingAssistantMessageIds: mocks.findPendingAssistantMessageIds,
|
||||
markMessagesError: mocks.markMessagesError
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@main/services/TopicNamingService', () => ({
|
||||
topicNamingService: { maybeRenameAgentSession: mocks.maybeRenameAgentSession }
|
||||
}))
|
||||
|
||||
vi.mock('@main/core/application', () => ({
|
||||
application: { get: mocks.applicationGet }
|
||||
}))
|
||||
|
||||
const { AgentSessionRuntimeService } = await import('../AgentSessionRuntimeService')
|
||||
const { runtimeDriverRegistry } = await import('../../runtime')
|
||||
const baseTurnInput = {
|
||||
sessionId: 'session-1',
|
||||
topicId: 'agent-session:session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'test-runtime',
|
||||
modelId: 'claude-code::claude-sonnet-4-5' as any,
|
||||
assistantMessageId: 'assistant-1'
|
||||
}
|
||||
|
||||
function userMessage(id: string) {
|
||||
return {
|
||||
id,
|
||||
topicId: 'agent-session:session-1',
|
||||
parentId: null,
|
||||
role: 'user',
|
||||
data: { parts: [{ type: 'text', text: 'hello' }] },
|
||||
status: 'success',
|
||||
createdAt: '',
|
||||
updatedAt: ''
|
||||
} as any
|
||||
}
|
||||
|
||||
function terminalListener(handle: { listeners: any[] }) {
|
||||
const listener = handle.listeners.find((item) => item.id === 'agent-runtime:session-1')
|
||||
if (!listener) throw new Error('terminal listener missing')
|
||||
return listener
|
||||
}
|
||||
|
||||
function persistenceListener(handle: { listeners: any[] }) {
|
||||
const listener = handle.listeners.find((item) => String(item.id).startsWith('persistence:agents-db:'))
|
||||
if (!listener) throw new Error('persistence listener missing')
|
||||
return listener
|
||||
}
|
||||
|
||||
function getEntry(service: InstanceType<typeof AgentSessionRuntimeService>) {
|
||||
return (service as any).entries.get('session-1')
|
||||
}
|
||||
|
||||
function createAsyncQueue<T>() {
|
||||
const items: T[] = []
|
||||
const waiters: Array<(value: IteratorResult<T>) => void> = []
|
||||
|
||||
return {
|
||||
push(item: T) {
|
||||
const waiter = waiters.shift()
|
||||
if (waiter) waiter({ value: item, done: false })
|
||||
else items.push(item)
|
||||
},
|
||||
iterable: {
|
||||
[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||
return {
|
||||
next: () => {
|
||||
const item = items.shift()
|
||||
if (item) return Promise.resolve({ value: item, done: false })
|
||||
return new Promise<IteratorResult<T>>((resolve) => waiters.push(resolve))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function createDeferred<T>() {
|
||||
let resolve!: (value: T) => void
|
||||
let reject!: (reason?: unknown) => void
|
||||
const promise = new Promise<T>((resolvePromise, rejectPromise) => {
|
||||
resolve = resolvePromise
|
||||
reject = rejectPromise
|
||||
})
|
||||
return { promise, resolve, reject }
|
||||
}
|
||||
|
||||
describe('AgentSessionRuntimeService', () => {
|
||||
beforeEach(() => {
|
||||
BaseService.resetInstances()
|
||||
runtimeDriverRegistry.clearForTest()
|
||||
vi.clearAllMocks()
|
||||
mocks.saveMessage.mockImplementation(async ({ message }) => ({
|
||||
...message,
|
||||
id: message.id ?? 'generated-message-id'
|
||||
}))
|
||||
mocks.getLastRuntimeResumeToken.mockResolvedValue(null)
|
||||
mocks.findPendingAssistantMessageIds.mockResolvedValue([])
|
||||
mocks.markMessagesError.mockResolvedValue(undefined)
|
||||
mocks.applicationGet.mockImplementation((name: string) => {
|
||||
if (name === 'AiStreamManager') {
|
||||
return {
|
||||
startRuntimeTurn: mocks.startRuntimeTurn,
|
||||
pauseRuntimeTurn: mocks.pauseRuntimeTurn,
|
||||
broadcastTopicError: mocks.broadcastTopicError
|
||||
}
|
||||
}
|
||||
if (name === 'SpanCacheService') return { setTopicId: mocks.spanCacheSetTopicId }
|
||||
throw new Error(`Unexpected application.get(${name})`)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isSessionBusy — inter-turn drain window (issue ①)', () => {
|
||||
it('is false with no entry and true while a turn is live', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
expect(service.isSessionBusy('session-1')).toBe(false)
|
||||
service.beginTurn(baseTurnInput)
|
||||
expect(service.isSessionBusy('session-1')).toBe(true)
|
||||
})
|
||||
|
||||
it('is false once a turn settles with no queued follow-ups', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
service.markTurnTerminal('session-1', 'success')
|
||||
expect(service.isSessionBusy('session-1')).toBe(false)
|
||||
})
|
||||
|
||||
it('stays busy throughout the next-turn drain, closing the clobber window', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
service.enqueueUserMessage('session-1', userMessage('user-2'))
|
||||
|
||||
// Hold the drain's assistant-placeholder save so we can observe the in-flight window.
|
||||
const deferred = createDeferred<any>()
|
||||
mocks.saveMessage.mockImplementationOnce(() => deferred.promise)
|
||||
|
||||
service.markTurnTerminal('session-1', 'success') // current turn → terminal, schedules the drain
|
||||
await new Promise((resolve) => setTimeout(resolve, 0)) // flush microtasks → drain parks on saveMessage
|
||||
|
||||
const entry = getEntry(service)
|
||||
// The bug window: the queued message was shifted (pendingTurns empty) and the old turn is
|
||||
// terminal — pre-fix nothing reported the session busy here.
|
||||
expect(entry.pendingTurns.length).toBe(0)
|
||||
expect(entry.currentTurn.terminalStatus).toBe('success')
|
||||
expect(entry.startingNextTurn).toBe(true) // flag now spans the whole drain
|
||||
expect(service.isSessionBusy('session-1')).toBe(true)
|
||||
|
||||
deferred.resolve({ id: 'assistant-2' })
|
||||
await new Promise((resolve) => setTimeout(resolve, 0)) // drain completes → fresh live turn
|
||||
expect(service.isSessionBusy('session-1')).toBe(true)
|
||||
expect(getEntry(service).startingNextTurn).toBe(false)
|
||||
})
|
||||
|
||||
it('does not resurrect a session torn down during the next-turn placeholder save (REGRESSION agent-session-1)', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
service.enqueueUserMessage('session-1', userMessage('user-2'))
|
||||
|
||||
// Hold the drain's placeholder save so we can tear the session down mid-await.
|
||||
const deferred = createDeferred<any>()
|
||||
mocks.saveMessage.mockImplementationOnce(() => deferred.promise)
|
||||
|
||||
service.markTurnTerminal('session-1', 'success') // schedules the drain → parks on saveMessage
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
|
||||
const startCallsBefore = mocks.startRuntimeTurn.mock.calls.length
|
||||
|
||||
// Session is torn down (shutdown / a fresh beginTurn) while the save is still in flight.
|
||||
service.closeSession('session-1')
|
||||
|
||||
deferred.resolve({ id: 'assistant-2' })
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
|
||||
// The dead entry must NOT be resurrected into a runtime turn.
|
||||
expect(mocks.startRuntimeTurn.mock.calls.length).toBe(startCallsBefore)
|
||||
expect(getEntry(service)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('reconcileStalePendingMessages — boot crash recovery', () => {
|
||||
it('marks crash-orphaned pending assistant messages as errored on init', async () => {
|
||||
mocks.findPendingAssistantMessageIds.mockResolvedValue(['stale-1', 'stale-2'])
|
||||
const service = new AgentSessionRuntimeService()
|
||||
|
||||
await (service as any).onInit()
|
||||
|
||||
expect(mocks.findPendingAssistantMessageIds).toHaveBeenCalledOnce()
|
||||
expect(mocks.markMessagesError).toHaveBeenCalledWith(['stale-1', 'stale-2'])
|
||||
})
|
||||
|
||||
it('does not mark anything when there are no stale messages', async () => {
|
||||
mocks.findPendingAssistantMessageIds.mockResolvedValue([])
|
||||
const service = new AgentSessionRuntimeService()
|
||||
|
||||
await (service as any).onInit()
|
||||
|
||||
expect(mocks.markMessagesError).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('logs and does not rethrow when the reconcile lookup throws, so boot is not blocked', async () => {
|
||||
const failure = new Error('db down')
|
||||
mocks.findPendingAssistantMessageIds.mockRejectedValue(failure)
|
||||
const service = new AgentSessionRuntimeService()
|
||||
|
||||
await expect((service as any).onInit()).resolves.toBeUndefined()
|
||||
|
||||
expect(mocks.markMessagesError).not.toHaveBeenCalled()
|
||||
expect(mockMainLoggerService.error).toHaveBeenCalledWith(
|
||||
'Failed to reconcile stale pending agent-session messages',
|
||||
{ error: failure }
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('creates an active runtime with a session-level pending queue', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
|
||||
const handle = service.beginTurn(baseTurnInput)
|
||||
service.enqueueUserMessage('session-1', userMessage('user-2'))
|
||||
|
||||
expect(terminalListener(handle).id).toBe('agent-runtime:session-1')
|
||||
expect(persistenceListener(handle).id).toContain('persistence:agents-db:agent-session:session-1')
|
||||
expect(service.inspect('session-1')).toMatchObject({
|
||||
sessionId: 'session-1',
|
||||
topicId: 'agent-session:session-1',
|
||||
assistantMessageId: 'assistant-1',
|
||||
status: 'active',
|
||||
pendingMessageCount: 1,
|
||||
lastTerminalStatus: undefined,
|
||||
activeToolCount: 0,
|
||||
interruptRequested: false
|
||||
})
|
||||
})
|
||||
|
||||
it('marks the runtime idle when the terminal listener observes done', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn(baseTurnInput)
|
||||
|
||||
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
|
||||
|
||||
expect(service.inspect('session-1')).toMatchObject({
|
||||
status: 'idle',
|
||||
pendingMessageCount: 0,
|
||||
lastTerminalStatus: 'success'
|
||||
})
|
||||
})
|
||||
|
||||
it('hands an idle session with a resume token to the driver onSessionIdle hook', () => {
|
||||
vi.useFakeTimers()
|
||||
try {
|
||||
const onSessionIdle = vi.fn()
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn(),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([]),
|
||||
onSessionIdle
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn(baseTurnInput)
|
||||
getEntry(service).lastResumeToken = 'resume-1'
|
||||
|
||||
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(onSessionIdle).toHaveBeenCalledWith('session-1')
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
} finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
|
||||
it('does not call onSessionIdle for an idle session without a resume token', () => {
|
||||
vi.useFakeTimers()
|
||||
try {
|
||||
const onSessionIdle = vi.fn()
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn(),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([]),
|
||||
onSessionIdle
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn(baseTurnInput)
|
||||
|
||||
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(onSessionIdle).not.toHaveBeenCalled()
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
} finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
|
||||
it('reuses an idle runtime for the next fresh turn', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const first = service.beginTurn(baseTurnInput)
|
||||
const entry = getEntry(service)
|
||||
const connection = { close: vi.fn(), send: vi.fn(), events: [] }
|
||||
entry.lastResumeToken = 'resume-1'
|
||||
entry.connection = connection
|
||||
|
||||
void terminalListener(first).onDone({ status: 'success', isTopicDone: true })
|
||||
const second = service.beginTurn({
|
||||
...baseTurnInput,
|
||||
assistantMessageId: 'assistant-2',
|
||||
userMessage: userMessage('user-2')
|
||||
})
|
||||
|
||||
expect(second).not.toBe(first)
|
||||
expect(getEntry(service).connection).toBe(connection)
|
||||
expect(getEntry(service).pendingTurns).toEqual([])
|
||||
expect(service.inspect('session-1')).toMatchObject({
|
||||
assistantMessageId: 'assistant-2',
|
||||
status: 'active',
|
||||
pendingMessageCount: 0,
|
||||
resumeToken: 'resume-1'
|
||||
})
|
||||
})
|
||||
|
||||
it('ignores per-execution terminal events until the topic is done', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn(baseTurnInput)
|
||||
|
||||
void terminalListener(handle).onPaused({ status: 'paused', isTopicDone: false })
|
||||
|
||||
expect(service.inspect('session-1')).toMatchObject({
|
||||
status: 'active',
|
||||
lastTerminalStatus: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('clears the runtime and closes the connection on closeSession', () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
const connection = { close: vi.fn(), send: vi.fn(), events: [] }
|
||||
const entry = getEntry(service)
|
||||
entry.connection = connection
|
||||
entry.connectionLoop = Promise.resolve()
|
||||
entry.startingNextTurn = true
|
||||
|
||||
service.closeSession('session-1')
|
||||
|
||||
expect(connection.close).toHaveBeenCalled()
|
||||
expect(entry.connection).toBeUndefined()
|
||||
expect(entry.connectionLoop).toBeUndefined()
|
||||
expect(entry.currentTurn).toBeUndefined()
|
||||
expect(entry.startingNextTurn).toBe(false)
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not throw and logs a warning when the connection close rejects on closeSession (REGRESSION agent-session-5)', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
const closeError = new Error('close failed')
|
||||
const connection = { close: vi.fn().mockRejectedValue(closeError), send: vi.fn(), events: [] }
|
||||
const entry = getEntry(service)
|
||||
entry.connection = connection
|
||||
entry.connectionLoop = Promise.resolve()
|
||||
|
||||
expect(() => service.closeSession('session-1')).not.toThrow()
|
||||
|
||||
expect(connection.close).toHaveBeenCalled()
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
await vi.waitFor(() =>
|
||||
expect(mockMainLoggerService.warn).toHaveBeenCalledWith(
|
||||
'Agent runtime connection close failed',
|
||||
expect.objectContaining({ sessionId: 'session-1', error: closeError })
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
it('persists assistant turns with the latest resume token', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
getEntry(service).lastResumeToken = 'resume-1'
|
||||
|
||||
await persistenceListener(handle).onDone({
|
||||
status: 'success',
|
||||
isTopicDone: true,
|
||||
finalMessage: { id: 'assistant-1', role: 'assistant', parts: [{ type: 'text', text: 'hi' }] }
|
||||
})
|
||||
|
||||
expect(mocks.saveMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
runtimeResumeToken: 'resume-1',
|
||||
message: {
|
||||
id: 'assistant-1',
|
||||
role: 'assistant',
|
||||
status: 'success',
|
||||
data: { parts: [{ type: 'text', text: 'hi' }] },
|
||||
modelId: 'claude-code::claude-sonnet-4-5'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('routes runtime events from the selected driver into the active turn', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
interrupt: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
const connect = vi.fn().mockResolvedValue(connection)
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect,
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: new AbortController().signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
|
||||
|
||||
events.push({ type: 'resume-token', token: 'resume-1' })
|
||||
await vi.waitFor(() => expect(service.inspect('session-1')).toMatchObject({ resumeToken: 'resume-1' }))
|
||||
|
||||
events.push({ type: 'chunk', chunk: { type: 'text-delta', id: 'text-1', delta: 'hello' } })
|
||||
await expect(reader.read()).resolves.toMatchObject({
|
||||
value: { type: 'text-delta', id: 'text-1', delta: 'hello' },
|
||||
done: false
|
||||
})
|
||||
|
||||
events.push({ type: 'turn-complete' })
|
||||
await expect(reader.read()).resolves.toMatchObject({ done: true })
|
||||
})
|
||||
|
||||
it('surfaces a runtime error event via controller.error and drops trailing chunks (REGRESSION agent-session-3)', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
interrupt: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
const connect = vi.fn().mockResolvedValue(connection)
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect,
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: new AbortController().signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalled())
|
||||
|
||||
// A runtime `error` event surfaces through the active turn's controller.
|
||||
events.push({ type: 'error', error: new Error('runtime boom') })
|
||||
await expect(reader.read()).rejects.toThrow('runtime boom')
|
||||
|
||||
// The turn is marked terminal synchronously, so a trailing chunk in the same connection
|
||||
// loop is dropped instead of being enqueued on the now-errored controller (which would throw).
|
||||
await vi.waitFor(() => expect(getEntry(service).currentTurn?.terminalStatus).toBe('error'))
|
||||
events.push({ type: 'chunk', chunk: { type: 'text-delta', id: 't', delta: 'late' } })
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
expect(getEntry(service).currentTurn?.terminalStatus).toBe('error')
|
||||
})
|
||||
|
||||
it('passes trace context to the runtime driver and closes the connection after trace turns', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
shouldCloseAfterTurn: () => true,
|
||||
close: vi.fn()
|
||||
}
|
||||
const connect = vi.fn().mockResolvedValue(connection)
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect,
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({
|
||||
...baseTurnInput,
|
||||
userMessage: userMessage('user-1'),
|
||||
traceId: '0'.repeat(32),
|
||||
rootSpanId: '1'.repeat(16)
|
||||
})
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: new AbortController().signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() =>
|
||||
expect(connect).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
modelId: 'claude-code::claude-sonnet-4-5',
|
||||
resumeToken: undefined,
|
||||
trace: {
|
||||
topicId: 'agent-session:session-1',
|
||||
traceId: '0'.repeat(32),
|
||||
rootSpanId: '1'.repeat(16),
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
modelName: 'claude-sonnet-4-5'
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
|
||||
|
||||
expect(connection.close).toHaveBeenCalledOnce()
|
||||
expect(getEntry(service).connection).toBeUndefined()
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
|
||||
it('hydrates the persisted resume token before connecting a cold historical session', async () => {
|
||||
mocks.getLastRuntimeResumeToken.mockResolvedValue('resume-db')
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
const connect = vi.fn().mockResolvedValue(connection)
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect,
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: new AbortController().signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() =>
|
||||
expect(connect).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
modelId: 'claude-code::claude-sonnet-4-5',
|
||||
resumeToken: 'resume-db',
|
||||
trace: undefined
|
||||
})
|
||||
)
|
||||
|
||||
expect(mocks.getLastRuntimeResumeToken).toHaveBeenCalledWith('session-1')
|
||||
expect(service.inspect('session-1')).toMatchObject({ resumeToken: 'resume-db' })
|
||||
service.closeSession('session-1')
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
|
||||
it('closes the runtime session when the active turn is aborted by the user', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn().mockResolvedValue(connection),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const controller = new AbortController()
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: controller.signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
|
||||
|
||||
controller.abort('user-requested')
|
||||
|
||||
await vi.waitFor(() => expect(connection.close).toHaveBeenCalledOnce())
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
|
||||
it('closes a late runtime connection when the user aborts before connect resolves', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
const pendingConnection = createDeferred<typeof connection>()
|
||||
const connect = vi.fn().mockReturnValue(pendingConnection.promise)
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect,
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const controller = new AbortController()
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: controller.signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connect).toHaveBeenCalledOnce())
|
||||
|
||||
controller.abort('user-requested')
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
|
||||
pendingConnection.resolve(connection)
|
||||
|
||||
await vi.waitFor(() => expect(connection.close).toHaveBeenCalledOnce())
|
||||
expect(connection.send).not.toHaveBeenCalled()
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
|
||||
describe('interrupt-when-safe — live follow-up', () => {
|
||||
it('defers the interrupt while a tool is mid-flight, then fires once the tool settles', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
interrupt: vi.fn().mockResolvedValue(undefined),
|
||||
close: vi.fn()
|
||||
}
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn().mockResolvedValue(connection),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: new AbortController().signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
|
||||
|
||||
// A tool is now in flight — the turn is not safe to interrupt.
|
||||
events.push({ type: 'chunk', chunk: { type: 'tool-input-start', toolCallId: 'tool-1' } })
|
||||
await vi.waitFor(() => expect(getEntry(service).currentTurn.activeToolIds.has('tool-1')).toBe(true))
|
||||
|
||||
// The follow-up queues but must NOT interrupt while the tool runs.
|
||||
service.enqueueUserMessage('session-1', userMessage('user-2'))
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
expect(connection.interrupt).not.toHaveBeenCalled()
|
||||
expect(mocks.pauseRuntimeTurn).not.toHaveBeenCalled()
|
||||
|
||||
// Tool settles → now safe → interrupt fires and the runtime turn is paused.
|
||||
events.push({ type: 'chunk', chunk: { type: 'tool-output-available', toolCallId: 'tool-1' } })
|
||||
await vi.waitFor(() => expect(connection.interrupt).toHaveBeenCalledOnce())
|
||||
expect(mocks.pauseRuntimeTurn).toHaveBeenCalledWith('agent-session:session-1', 'agent-runtime-interrupt')
|
||||
|
||||
service.closeSession('session-1')
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
|
||||
it('interrupts immediately on the next microtask when no tool is active', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
interrupt: vi.fn().mockResolvedValue(undefined),
|
||||
close: vi.fn()
|
||||
}
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn().mockResolvedValue(connection),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: new AbortController().signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
|
||||
|
||||
// No tool in flight (activeToolIds empty) → the queued follow-up interrupts on the next microtask.
|
||||
expect(getEntry(service).currentTurn.activeToolIds.size).toBe(0)
|
||||
service.enqueueUserMessage('session-1', userMessage('user-2'))
|
||||
expect(connection.interrupt).not.toHaveBeenCalled()
|
||||
|
||||
await vi.waitFor(() => expect(connection.interrupt).toHaveBeenCalledOnce())
|
||||
expect(mocks.pauseRuntimeTurn).toHaveBeenCalledWith('agent-session:session-1', 'agent-runtime-interrupt')
|
||||
|
||||
service.closeSession('session-1')
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
})
|
||||
|
||||
it('keeps the runtime session alive when a steer interrupt pauses the turn', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn().mockResolvedValue(connection),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const controller = new AbortController()
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: controller.signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
|
||||
|
||||
// The steer path marks the turn before aborting; the abort reason is irrelevant.
|
||||
getEntry(service).currentTurn.interruptRequested = true
|
||||
controller.abort()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ done: true })
|
||||
expect(connection.close).not.toHaveBeenCalled()
|
||||
expect(service.inspect('session-1')).toMatchObject({
|
||||
sessionId: 'session-1',
|
||||
status: 'active'
|
||||
})
|
||||
service.closeSession('session-1')
|
||||
})
|
||||
|
||||
it('tears the session down on abort with an interrupt-looking reason when none was requested', async () => {
|
||||
const events = createAsyncQueue<any>()
|
||||
const connection = {
|
||||
events: events.iterable,
|
||||
send: vi.fn(),
|
||||
close: vi.fn()
|
||||
}
|
||||
runtimeDriverRegistry.register({
|
||||
type: 'test-runtime',
|
||||
capabilities: ['agent-session'],
|
||||
connect: vi.fn().mockResolvedValue(connection),
|
||||
validateSession: vi.fn(),
|
||||
listAvailableTools: vi.fn().mockResolvedValue([])
|
||||
})
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
const controller = new AbortController()
|
||||
const stream = service.openTurnStream({
|
||||
sessionId: 'session-1',
|
||||
turnId: handle.turnId,
|
||||
signal: controller.signal
|
||||
})
|
||||
const reader = stream.getReader()
|
||||
|
||||
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
|
||||
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
|
||||
|
||||
// Reason matches the old interrupt sentinel, but no interrupt was requested —
|
||||
// teardown is driven by `interruptRequested`, not the signal reason.
|
||||
controller.abort('agent-runtime-interrupt')
|
||||
|
||||
await vi.waitFor(() => expect(connection.close).toHaveBeenCalledOnce())
|
||||
expect(service.inspect('session-1')).toBeUndefined()
|
||||
await reader.cancel().catch(() => undefined)
|
||||
})
|
||||
|
||||
it('persists errored assistant turns with the latest resume token', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
|
||||
getEntry(service).lastResumeToken = 'resume-init'
|
||||
|
||||
await persistenceListener(handle).onError({
|
||||
status: 'error',
|
||||
isTopicDone: true,
|
||||
error: { name: 'Error', message: 'boom' },
|
||||
finalMessage: { id: 'assistant-1', role: 'assistant', parts: [] }
|
||||
})
|
||||
|
||||
expect(mocks.saveMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
runtimeResumeToken: 'resume-init',
|
||||
message: {
|
||||
id: 'assistant-1',
|
||||
role: 'assistant',
|
||||
status: 'error',
|
||||
data: { parts: [{ type: 'data-error', data: { name: 'Error', message: 'boom' } }] },
|
||||
modelId: 'claude-code::claude-sonnet-4-5'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('starts queued turns with runtime request metadata and assistant seed', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
const entry = getEntry(service)
|
||||
entry.lastResumeToken = 'resume-1'
|
||||
entry.currentTurn.activeToolIds.add('tool-1')
|
||||
entry.pendingTurns.push(userMessage('user-2'))
|
||||
|
||||
await (service as any).startNextTurn(entry)
|
||||
|
||||
const savedMessage = mocks.saveMessage.mock.calls[0][0].message
|
||||
expect(mocks.saveMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
message: {
|
||||
role: 'assistant',
|
||||
status: 'pending',
|
||||
data: { parts: [] },
|
||||
modelId: 'claude-code::claude-sonnet-4-5',
|
||||
traceId: expect.any(String)
|
||||
}
|
||||
})
|
||||
expect(mocks.spanCacheSetTopicId).toHaveBeenCalledWith(savedMessage.traceId, 'agent-session:session-1')
|
||||
expect(mocks.startRuntimeTurn).toHaveBeenCalledWith({
|
||||
topicId: 'agent-session:session-1',
|
||||
modelId: 'claude-code::claude-sonnet-4-5',
|
||||
rootSpan: expect.anything(),
|
||||
request: {
|
||||
chatId: 'agent-session:session-1',
|
||||
trigger: 'submit-message',
|
||||
messageId: 'generated-message-id',
|
||||
messages: [
|
||||
{ id: 'user-2', role: 'user', parts: [{ type: 'text', text: 'hello' }] },
|
||||
{ id: 'generated-message-id', role: 'assistant', parts: [] }
|
||||
],
|
||||
runtime: { kind: 'agent-session', sessionId: 'session-1', turnId: expect.any(String) }
|
||||
},
|
||||
listeners: [
|
||||
expect.objectContaining({ id: expect.stringContaining('persistence:agents-db:') }),
|
||||
expect.objectContaining({ id: 'agent-runtime:session-1' }),
|
||||
expect.objectContaining({ id: 'persistence:trace:agent-session:session-1' })
|
||||
]
|
||||
})
|
||||
const request = mocks.startRuntimeTurn.mock.calls[0][0].request
|
||||
expect(request.messageId).toBe(request.messages[1].id)
|
||||
expect(getEntry(service).currentTurn.trace).toMatchObject({
|
||||
topicId: 'agent-session:session-1',
|
||||
traceId: savedMessage.traceId,
|
||||
rootSpanId: expect.any(String),
|
||||
sessionId: 'session-1',
|
||||
turnId: request.runtime.turnId,
|
||||
modelName: 'claude-sonnet-4-5'
|
||||
})
|
||||
})
|
||||
|
||||
it('surfaces the error and settles the turn when the next-turn placeholder save rejects (R3)', async () => {
|
||||
const service = new AgentSessionRuntimeService()
|
||||
service.beginTurn(baseTurnInput)
|
||||
const entry = getEntry(service)
|
||||
const queued = userMessage('user-2')
|
||||
entry.pendingTurns.push(queued)
|
||||
|
||||
const saveError = new Error('db down')
|
||||
mocks.saveMessage.mockRejectedValueOnce(saveError)
|
||||
|
||||
// The placeholder save failed: re-queuing would just fail again and the idle TTL would
|
||||
// silently clear it, so the message is dropped, the failure is surfaced to the live renderer,
|
||||
// and the turn is settled to `error` (not left silently idle).
|
||||
await expect((service as any).startNextTurn(entry)).resolves.toBeUndefined()
|
||||
|
||||
expect(entry.pendingTurns).toEqual([])
|
||||
expect(mocks.startRuntimeTurn).not.toHaveBeenCalled()
|
||||
expect(mocks.broadcastTopicError).toHaveBeenCalledWith(
|
||||
entry.topicId,
|
||||
entry.modelId,
|
||||
expect.objectContaining({ message: expect.stringContaining('db down') })
|
||||
)
|
||||
expect(entry.status).toBe('idle')
|
||||
expect(entry.lastTerminalStatus).toBe('error')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Agent-session DB backend — writes assistant turns to the `agent_session_message`
|
||||
* table via `agentSessionMessageService`. The user message is persisted
|
||||
* by AgentChatContextProvider before streaming starts (not here).
|
||||
*
|
||||
* The listener folds any error into `finalMessage.parts` upstream, so a
|
||||
* single `persistAssistant` handles success / paused / error uniformly.
|
||||
*/
|
||||
|
||||
import { agentSessionMessageService } from '@data/services/AgentSessionMessageService'
|
||||
import type { CherryMessagePart, CherryUIMessage } from '@shared/data/types/message'
|
||||
import type { UniqueModelId } from '@shared/data/types/model'
|
||||
import { v7 as uuidv7 } from 'uuid'
|
||||
|
||||
import {
|
||||
finalizeInterruptedParts,
|
||||
type PersistAssistantInput,
|
||||
type PersistenceBackend
|
||||
} from '../../streamManager/persistence/PersistenceBackend'
|
||||
|
||||
export interface AgentSessionMessageBackendOptions {
|
||||
/** Cherry Studio agent-session id. */
|
||||
sessionId: string
|
||||
/** Model id used for this assistant message. */
|
||||
modelId?: UniqueModelId
|
||||
/** Opaque runtime resume token persisted for future recovery; `undefined` when unknown. */
|
||||
runtimeResumeToken?: string | (() => string | undefined)
|
||||
/** Post-success hook — typically session auto-rename. */
|
||||
afterPersist?: (finalMessage: CherryUIMessage) => Promise<void>
|
||||
}
|
||||
|
||||
export class AgentSessionMessageBackend implements PersistenceBackend {
|
||||
readonly kind = 'agents-db'
|
||||
readonly afterPersist?: (finalMessage: CherryUIMessage) => Promise<void>
|
||||
|
||||
constructor(private readonly opts: AgentSessionMessageBackendOptions) {
|
||||
this.afterPersist = opts.afterPersist
|
||||
}
|
||||
|
||||
async persistAssistant(input: PersistAssistantInput): Promise<void> {
|
||||
const { finalMessage, status, stats } = input
|
||||
const parts = finalizeInterruptedParts((finalMessage?.parts ?? []) as CherryMessagePart[], status)
|
||||
const runtimeResumeToken = this.getRuntimeResumeToken()
|
||||
await agentSessionMessageService.saveMessage({
|
||||
sessionId: this.opts.sessionId,
|
||||
...(runtimeResumeToken ? { runtimeResumeToken } : {}),
|
||||
message: {
|
||||
id: finalMessage?.id ?? uuidv7(),
|
||||
role: 'assistant',
|
||||
status,
|
||||
data: { parts },
|
||||
modelId: this.opts.modelId,
|
||||
...(stats ? { stats } : {})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private getRuntimeResumeToken(): string | undefined {
|
||||
return typeof this.opts.runtimeResumeToken === 'function'
|
||||
? this.opts.runtimeResumeToken()
|
||||
: this.opts.runtimeResumeToken
|
||||
}
|
||||
}
|
||||
19
src/main/ai/agentSession/topic.ts
Normal file
19
src/main/ai/agentSession/topic.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
const AGENT_SESSION_PREFIX = 'agent-session:'
|
||||
|
||||
/** Check if a topicId represents an agent session (vs a normal chat). */
|
||||
export function isAgentSessionTopic(topicId: string): boolean {
|
||||
return topicId.startsWith(AGENT_SESSION_PREFIX)
|
||||
}
|
||||
|
||||
/** Extract the agent session ID from a topic ID. Throws if not an agent session topic. */
|
||||
export function extractAgentSessionId(topicId: string): string {
|
||||
if (!isAgentSessionTopic(topicId)) {
|
||||
throw new Error(`Not an agent session topicId: ${topicId}`)
|
||||
}
|
||||
return topicId.slice(AGENT_SESSION_PREFIX.length)
|
||||
}
|
||||
|
||||
/** Build the topic id for an agent session. */
|
||||
export function buildAgentSessionTopicId(sessionId: string): string {
|
||||
return `${AGENT_SESSION_PREFIX}${sessionId}`
|
||||
}
|
||||
19
src/main/ai/agents/AgentJobsService.ts
Normal file
19
src/main/ai/agents/AgentJobsService.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { application } from '@main/core/application'
|
||||
import { BaseService, DependsOn, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
|
||||
import { AgentTaskJobHandler } from './AgentTaskJobHandler'
|
||||
|
||||
@Injectable('AgentJobsService')
|
||||
@ServicePhase(Phase.WhenReady)
|
||||
@DependsOn(['JobManager'])
|
||||
export class AgentJobsService extends BaseService {
|
||||
protected async onInit(): Promise<void> {
|
||||
const jobManager = application.get('JobManager')
|
||||
jobManager.registerHandler('agent.task', AgentTaskJobHandler)
|
||||
|
||||
this.ipcHandle(IpcChannel.Ai_Agent_RunTask, async (_event, taskId: string) => {
|
||||
return jobManager.triggerJobScheduleNowById(taskId)
|
||||
})
|
||||
}
|
||||
}
|
||||
84
src/main/ai/agents/AgentTaskJobHandler.ts
Normal file
84
src/main/ai/agents/AgentTaskJobHandler.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Job handler for `agent.task` — scheduled agent prompts.
|
||||
*
|
||||
* Thin metadata + execute wrapper; business logic lives in `./runAgentTask`.
|
||||
* Failure backstop: after three consecutive failed terminal jobs on the same
|
||||
* schedule, pauses the schedule via `JobManager.pauseJobScheduleById`. The
|
||||
* `jobTable` rows are the single source of truth — no in-memory counter
|
||||
* (the legacy `SchedulerService.consecutiveErrors` map reset on every process
|
||||
* restart, making the breaker effectively unreachable in practice).
|
||||
*/
|
||||
|
||||
import { application } from '@application'
|
||||
import { jobService } from '@data/services/JobService'
|
||||
import { loggerService } from '@logger'
|
||||
import type { JobHandler } from '@main/core/job/types'
|
||||
|
||||
import { runAgentTask } from './runAgentTask'
|
||||
|
||||
declare module '@main/core/job/jobRegistry' {
|
||||
interface JobRegistry {
|
||||
'agent.task': {
|
||||
agentId: string
|
||||
prompt: string
|
||||
/** Per-task timeout in minutes. Enforced inside `runAgentTask`; handler-level
|
||||
* `defaultTimeoutMs` is intentionally unset so each task may set its own value. */
|
||||
timeoutMinutes: number
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const logger = loggerService.withContext('AgentTaskJobHandler')
|
||||
|
||||
const RECENT_TERMINAL_WINDOW = 3
|
||||
|
||||
export const AgentTaskJobHandler: JobHandler<{ agentId: string; prompt: string; timeoutMinutes: number }> = {
|
||||
/**
|
||||
* 'retry': non-terminal jobs from a previous run are re-pended on startup
|
||||
* so the recovered job dispatches against the latest agent configuration.
|
||||
* This matches the legacy poll-loop semantics where a task missed by a
|
||||
* crash was simply picked up on the next 60s tick.
|
||||
*/
|
||||
recovery: 'retry',
|
||||
|
||||
/**
|
||||
* Per-agent serialization queue: a single agent never runs two scheduled
|
||||
* tasks concurrently (Claude Code subprocess + workspace state would
|
||||
* collide). Cross-agent parallelism is unaffected.
|
||||
*/
|
||||
defaultQueue: (input) => `agent:${input.agentId}`,
|
||||
|
||||
defaultConcurrency: 1,
|
||||
|
||||
/**
|
||||
* Schedule-driven tasks do not retry inside the Job runtime — failure
|
||||
* surfaces to `onSettled` and the circuit breaker decides whether to pause.
|
||||
* Re-attempting an LLM call automatically is rarely helpful and can rack
|
||||
* up token spend without diagnostic value.
|
||||
*/
|
||||
defaultRetryPolicy: { maxAttempts: 1, backoff: 'none', baseDelayMs: 0, maxDelayMs: 0 },
|
||||
|
||||
async execute(ctx) {
|
||||
return await runAgentTask(ctx)
|
||||
},
|
||||
|
||||
async onSettled(event) {
|
||||
if (event.status !== 'failed' || !event.scheduleId) return
|
||||
|
||||
const recent = await jobService.listRecentTerminalByScheduleId(event.scheduleId, RECENT_TERMINAL_WINDOW)
|
||||
if (recent.length < RECENT_TERMINAL_WINDOW) return
|
||||
if (!recent.every((j) => j.status === 'failed')) return
|
||||
|
||||
logger.warn('Agent task schedule failed in last N terminal runs — pausing', {
|
||||
scheduleId: event.scheduleId,
|
||||
window: RECENT_TERMINAL_WINDOW
|
||||
})
|
||||
try {
|
||||
await application.get('JobManager').pauseJobScheduleById(event.scheduleId)
|
||||
} catch (err) {
|
||||
logger.error('Failed to pause schedule after consecutive failures', err as Error, {
|
||||
scheduleId: event.scheduleId
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
176
src/main/ai/agents/__tests__/AgentTaskJobHandler.test.ts
Normal file
176
src/main/ai/agents/__tests__/AgentTaskJobHandler.test.ts
Normal file
@@ -0,0 +1,176 @@
|
||||
import type { JobContext, JobSettledEvent } from '@main/core/job/types'
|
||||
import type { JobSnapshot } from '@shared/data/api/schemas/jobs'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@application', async () => {
|
||||
const mod = await import('@test-mocks/main/application')
|
||||
return mod.mockApplicationFactory()
|
||||
})
|
||||
|
||||
vi.mock('@data/services/JobService', () => ({
|
||||
jobService: {
|
||||
listRecentTerminalByScheduleId: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../runAgentTask', () => ({
|
||||
runAgentTask: vi.fn()
|
||||
}))
|
||||
|
||||
import { application } from '@application'
|
||||
import { jobService } from '@data/services/JobService'
|
||||
|
||||
import { AgentTaskJobHandler } from '../AgentTaskJobHandler'
|
||||
import { runAgentTask } from '../runAgentTask'
|
||||
|
||||
function makeTerminal(status: 'completed' | 'failed' | 'cancelled', id = `job-${status}`): JobSnapshot {
|
||||
return {
|
||||
id,
|
||||
type: 'agent.task',
|
||||
status,
|
||||
priority: 0,
|
||||
queue: 'agent:a1',
|
||||
idempotencyKey: null,
|
||||
scheduleId: 's1',
|
||||
scheduledAt: '2026-05-20T00:00:00.000Z',
|
||||
startedAt: '2026-05-20T00:00:01.000Z',
|
||||
finishedAt: '2026-05-20T00:00:02.000Z',
|
||||
attempt: 0,
|
||||
maxAttempts: 1,
|
||||
input: {},
|
||||
output: null,
|
||||
error: null,
|
||||
parentId: null,
|
||||
cancelRequested: false,
|
||||
metadata: {},
|
||||
timeoutMs: null,
|
||||
createdAt: '2026-05-20T00:00:00.000Z',
|
||||
updatedAt: '2026-05-20T00:00:02.000Z'
|
||||
}
|
||||
}
|
||||
|
||||
function makeSettled(overrides: Partial<JobSettledEvent>): JobSettledEvent {
|
||||
return {
|
||||
jobId: 'job-1',
|
||||
type: 'agent.task',
|
||||
scheduleId: 's1',
|
||||
status: 'failed',
|
||||
output: null,
|
||||
error: { code: 'TEST', message: 'boom', retryable: false },
|
||||
attempt: 0,
|
||||
...overrides
|
||||
} as JobSettledEvent
|
||||
}
|
||||
|
||||
describe('AgentTaskJobHandler', () => {
|
||||
const pauseSpy = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.mocked(application.get).mockImplementation((name: string) => {
|
||||
if (name === 'JobManager') return { pauseJobScheduleById: pauseSpy } as never
|
||||
throw new Error(`Unexpected application.get('${name}')`)
|
||||
})
|
||||
pauseSpy.mockReset()
|
||||
pauseSpy.mockResolvedValue(true)
|
||||
vi.mocked(jobService.listRecentTerminalByScheduleId).mockReset()
|
||||
vi.mocked(runAgentTask).mockReset()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('metadata', () => {
|
||||
it('declares per-agent queue + concurrency 1 + retry-once policy', () => {
|
||||
expect(AgentTaskJobHandler.recovery).toBe('retry')
|
||||
expect(AgentTaskJobHandler.defaultConcurrency).toBe(1)
|
||||
expect(AgentTaskJobHandler.defaultRetryPolicy).toEqual({
|
||||
maxAttempts: 1,
|
||||
backoff: 'none',
|
||||
baseDelayMs: 0,
|
||||
maxDelayMs: 0
|
||||
})
|
||||
expect(AgentTaskJobHandler.defaultQueue?.({ agentId: 'a-42', prompt: 'x', timeoutMinutes: 2 })).toBe('agent:a-42')
|
||||
})
|
||||
})
|
||||
|
||||
describe('execute', () => {
|
||||
it('delegates to runAgentTask with the JobContext', async () => {
|
||||
vi.mocked(runAgentTask).mockResolvedValueOnce({ sessionId: 'sess-1', result: 'ok' })
|
||||
const ctx = { jobId: 'j1', input: { agentId: 'a', prompt: 'p', timeoutMinutes: 2 } } as JobContext<{
|
||||
agentId: string
|
||||
prompt: string
|
||||
timeoutMinutes: number
|
||||
}>
|
||||
|
||||
const out = await AgentTaskJobHandler.execute(ctx)
|
||||
|
||||
expect(out).toEqual({ sessionId: 'sess-1', result: 'ok' })
|
||||
expect(runAgentTask).toHaveBeenCalledWith(ctx)
|
||||
})
|
||||
})
|
||||
|
||||
describe('onSettled circuit breaker', () => {
|
||||
it('pauses schedule after 3 consecutive failures', async () => {
|
||||
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
|
||||
makeTerminal('failed', 'a'),
|
||||
makeTerminal('failed', 'b'),
|
||||
makeTerminal('failed', 'c')
|
||||
])
|
||||
|
||||
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))
|
||||
|
||||
expect(jobService.listRecentTerminalByScheduleId).toHaveBeenCalledWith('s1', 3)
|
||||
expect(pauseSpy).toHaveBeenCalledWith('s1')
|
||||
})
|
||||
|
||||
it('does not pause when the latest is failed but a recent one is completed', async () => {
|
||||
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
|
||||
makeTerminal('failed', 'a'),
|
||||
makeTerminal('completed', 'b'),
|
||||
makeTerminal('failed', 'c')
|
||||
])
|
||||
|
||||
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))
|
||||
|
||||
expect(pauseSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not pause when the recent-terminal window is not yet full', async () => {
|
||||
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
|
||||
makeTerminal('failed', 'a'),
|
||||
makeTerminal('failed', 'b')
|
||||
])
|
||||
|
||||
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))
|
||||
|
||||
expect(pauseSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not act on non-failed terminal events', async () => {
|
||||
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'completed' }))
|
||||
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'cancelled' }))
|
||||
|
||||
expect(jobService.listRecentTerminalByScheduleId).not.toHaveBeenCalled()
|
||||
expect(pauseSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not act when the failed job has no scheduleId (ad-hoc enqueue)', async () => {
|
||||
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed', scheduleId: null }))
|
||||
|
||||
expect(jobService.listRecentTerminalByScheduleId).not.toHaveBeenCalled()
|
||||
expect(pauseSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('swallows pauseJobScheduleById errors so onSettled cannot throw', async () => {
|
||||
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
|
||||
makeTerminal('failed', 'a'),
|
||||
makeTerminal('failed', 'b'),
|
||||
makeTerminal('failed', 'c')
|
||||
])
|
||||
pauseSpy.mockRejectedValueOnce(new Error('db lost'))
|
||||
|
||||
await expect(AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))).resolves.not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
353
src/main/ai/agents/__tests__/runAgentTask.test.ts
Normal file
353
src/main/ai/agents/__tests__/runAgentTask.test.ts
Normal file
@@ -0,0 +1,353 @@
|
||||
/**
|
||||
* Phase 1 coverage: focuses on the pure branches that do not engage the
|
||||
* Claude Code subprocess (heartbeat skip + agent-not-found). The full
|
||||
* streaming path is exercised by integration tests / Phase 5 manual e2e.
|
||||
*
|
||||
* Each fire creates a fresh session — there is no cross-fire session reuse.
|
||||
*/
|
||||
|
||||
import type { JobContext } from '@main/core/job/types'
|
||||
import type { AgentEntity } from '@shared/data/api/schemas/agents'
|
||||
import type { JobSnapshot } from '@shared/data/api/schemas/jobs'
|
||||
import type { AgentSessionEntity } from '@shared/data/api/schemas/sessions'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { mockAbort, mockGetAdapter, mockStartRun, captured } = vi.hoisted(() => {
|
||||
const captured: { listeners: Array<Record<string, (arg?: unknown) => void>> } = { listeners: [] }
|
||||
return {
|
||||
mockAbort: vi.fn(),
|
||||
mockGetAdapter: vi.fn(() => undefined),
|
||||
mockStartRun: vi.fn(async (opts: { listeners: typeof captured.listeners }) => {
|
||||
captured.listeners = opts.listeners
|
||||
}),
|
||||
captured
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@application', async () => {
|
||||
const mod = await import('@test-mocks/main/application')
|
||||
return mod.mockApplicationFactory({
|
||||
// ChannelManager + AiStreamManager aren't in the default mock service set; the
|
||||
// streaming path (post heartbeat-skip) reads both, so wire minimal stubs here.
|
||||
ChannelManager: { getAdapter: mockGetAdapter },
|
||||
AiStreamManager: { abort: mockAbort }
|
||||
} as never)
|
||||
})
|
||||
|
||||
vi.mock('@main/ai/streamManager/api/startAgentSessionRun', () => ({
|
||||
startAgentSessionRun: mockStartRun
|
||||
}))
|
||||
|
||||
vi.mock('@data/services/AgentChannelService', () => ({
|
||||
agentChannelService: { getSubscribedChannels: vi.fn() }
|
||||
}))
|
||||
vi.mock('@data/services/AgentService', () => ({
|
||||
agentService: { getAgent: vi.fn() }
|
||||
}))
|
||||
vi.mock('@data/services/SessionService', () => ({
|
||||
sessionService: { createSession: vi.fn(), findAgentWorkspacePath: vi.fn() }
|
||||
}))
|
||||
vi.mock('@data/services/JobScheduleService', () => ({
|
||||
jobScheduleService: { getById: vi.fn() }
|
||||
}))
|
||||
vi.mock('@data/services/JobService', () => ({
|
||||
jobService: { getById: vi.fn() }
|
||||
}))
|
||||
vi.mock('@main/ai/agents/cherryclaw/heartbeat', () => ({
|
||||
readHeartbeat: vi.fn()
|
||||
}))
|
||||
|
||||
import { agentChannelService } from '@data/services/AgentChannelService'
|
||||
import { agentService } from '@data/services/AgentService'
|
||||
import { jobScheduleService } from '@data/services/JobScheduleService'
|
||||
import { jobService } from '@data/services/JobService'
|
||||
import { sessionService } from '@data/services/SessionService'
|
||||
import { readHeartbeat } from '@main/ai/agents/cherryclaw/heartbeat'
|
||||
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
|
||||
|
||||
import { runAgentTask } from '../runAgentTask'
|
||||
|
||||
function makeJobSnapshot(scheduleId: string | null = 's1'): JobSnapshot {
|
||||
return {
|
||||
id: 'j1',
|
||||
type: 'agent.task',
|
||||
status: 'running',
|
||||
priority: 0,
|
||||
queue: 'agent:a1',
|
||||
idempotencyKey: null,
|
||||
scheduleId,
|
||||
scheduledAt: '2026-05-20T00:00:00.000Z',
|
||||
startedAt: '2026-05-20T00:00:00.000Z',
|
||||
finishedAt: null,
|
||||
attempt: 0,
|
||||
maxAttempts: 1,
|
||||
input: {},
|
||||
output: null,
|
||||
error: null,
|
||||
parentId: null,
|
||||
cancelRequested: false,
|
||||
metadata: {},
|
||||
timeoutMs: null,
|
||||
createdAt: '2026-05-20T00:00:00.000Z',
|
||||
updatedAt: '2026-05-20T00:00:00.000Z'
|
||||
}
|
||||
}
|
||||
|
||||
function makeCtx(overrides: Partial<JobContext<{ agentId: string; prompt: string; timeoutMinutes: number }>> = {}) {
|
||||
return {
|
||||
jobId: 'j1',
|
||||
input: { agentId: 'a1', prompt: '__heartbeat__', timeoutMinutes: 2 },
|
||||
attempt: 0,
|
||||
signal: new AbortController().signal,
|
||||
metadata: {},
|
||||
patchMetadata: vi.fn(),
|
||||
reportProgress: vi.fn(),
|
||||
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() } as never,
|
||||
...overrides
|
||||
} as JobContext<{ agentId: string; prompt: string; timeoutMinutes: number }>
|
||||
}
|
||||
|
||||
function makeAgent(config: Record<string, unknown> = {}): AgentEntity {
|
||||
return {
|
||||
id: 'a1',
|
||||
type: 'claude-code',
|
||||
name: 'Agent A',
|
||||
model: 'sonnet' as never,
|
||||
configuration: config as never,
|
||||
createdAt: '2026-05-20T00:00:00.000Z',
|
||||
updatedAt: '2026-05-20T00:00:00.000Z',
|
||||
modelName: null
|
||||
}
|
||||
}
|
||||
|
||||
function makeSession(workspacePath: string | null = '/ws/a'): AgentSessionEntity {
|
||||
return {
|
||||
id: 'sess-new',
|
||||
agentId: 'a1',
|
||||
name: 'Scheduled task',
|
||||
workspaceId: workspacePath ? 'ws-1' : null,
|
||||
workspace: workspacePath
|
||||
? {
|
||||
id: 'ws-1',
|
||||
name: 'ws',
|
||||
path: workspacePath,
|
||||
orderKey: 'k',
|
||||
createdAt: '2026-05-20T00:00:00.000Z',
|
||||
updatedAt: '2026-05-20T00:00:00.000Z'
|
||||
}
|
||||
: null,
|
||||
orderKey: 'k',
|
||||
createdAt: '2026-05-20T00:00:00.000Z',
|
||||
updatedAt: '2026-05-20T00:00:00.000Z'
|
||||
} as AgentSessionEntity
|
||||
}
|
||||
|
||||
function makeSchedule(name: string | null = 'heartbeat') {
|
||||
return {
|
||||
id: 's1',
|
||||
type: 'agent.task',
|
||||
name,
|
||||
trigger: { kind: 'interval', ms: 60_000 },
|
||||
jobInputTemplate: {},
|
||||
enabled: true,
|
||||
nextRun: null,
|
||||
lastRun: null,
|
||||
catchUpPolicy: { kind: 'skip-missed' },
|
||||
metadata: {},
|
||||
createdAt: '2026-05-20T00:00:00.000Z',
|
||||
updatedAt: '2026-05-20T00:00:00.000Z'
|
||||
} as never
|
||||
}
|
||||
|
||||
describe('runAgentTask', () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(jobService.getById).mockReset()
|
||||
vi.mocked(jobScheduleService.getById).mockReset()
|
||||
vi.mocked(agentService.getAgent).mockReset()
|
||||
vi.mocked(sessionService.createSession).mockReset()
|
||||
vi.mocked(sessionService.findAgentWorkspacePath).mockReset()
|
||||
vi.mocked(readHeartbeat).mockReset()
|
||||
vi.mocked(agentChannelService.getSubscribedChannels).mockReset().mockResolvedValue([])
|
||||
mockStartRun.mockClear()
|
||||
mockAbort.mockClear()
|
||||
mockGetAdapter.mockClear()
|
||||
captured.listeners = []
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('throws when the agent cannot be found', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(null as never)
|
||||
|
||||
await expect(runAgentTask(makeCtx())).rejects.toThrow('Agent not found: a1')
|
||||
})
|
||||
|
||||
// A disabled heartbeat must short-circuit BEFORE createSession — that call also
|
||||
// lazily provisions a workspace on first fire, so creating a session for a fire
|
||||
// we're going to drop would accrete a session row (and workspace) every interval.
|
||||
it('skips a disabled heartbeat WITHOUT creating a session', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: false }))
|
||||
|
||||
const out = await runAgentTask(makeCtx())
|
||||
|
||||
expect(out).toEqual({ sessionId: null, result: 'Skipped (disabled)' })
|
||||
expect(sessionService.createSession).not.toHaveBeenCalled()
|
||||
expect(readHeartbeat).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips an enabled heartbeat with no workspace WITHOUT creating a session', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: true }))
|
||||
vi.mocked(sessionService.findAgentWorkspacePath).mockResolvedValueOnce(null)
|
||||
|
||||
const out = await runAgentTask(makeCtx())
|
||||
|
||||
expect(out).toEqual({ sessionId: null, result: 'Skipped (no file)' })
|
||||
expect(sessionService.createSession).not.toHaveBeenCalled()
|
||||
expect(readHeartbeat).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips an enabled heartbeat with no heartbeat.md WITHOUT creating a session', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: true }))
|
||||
vi.mocked(sessionService.findAgentWorkspacePath).mockResolvedValueOnce('/ws/a')
|
||||
vi.mocked(readHeartbeat).mockResolvedValueOnce(undefined)
|
||||
|
||||
const out = await runAgentTask(makeCtx())
|
||||
|
||||
expect(out).toEqual({ sessionId: null, result: 'Skipped (no file)' })
|
||||
expect(sessionService.createSession).not.toHaveBeenCalled()
|
||||
expect(readHeartbeat).toHaveBeenCalledWith('/ws/a')
|
||||
})
|
||||
|
||||
it('creates a session and runs when an enabled heartbeat has content', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: true }))
|
||||
vi.mocked(sessionService.findAgentWorkspacePath).mockResolvedValueOnce('/ws/a')
|
||||
vi.mocked(readHeartbeat).mockResolvedValueOnce('check the inbox')
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
|
||||
|
||||
const promise = runAgentTask(makeCtx())
|
||||
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
|
||||
captured.listeners[0].onDone({ status: 'completed' })
|
||||
await promise
|
||||
|
||||
expect(readHeartbeat).toHaveBeenCalledWith('/ws/a')
|
||||
expect(sessionService.createSession).toHaveBeenCalledWith({ agentId: 'a1', name: 'heartbeat' })
|
||||
})
|
||||
|
||||
// C1 (agents-jobs-3): a `text-delta` chunk's payload is on `.delta`, not `.text`.
|
||||
// The previous `as { text }` cast silently accumulated nothing, so every run
|
||||
// persisted the `'Completed'` fallback instead of the model's reply.
|
||||
it('accumulates text-delta chunks via .delta into the result', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
|
||||
|
||||
const promise = runAgentTask(makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0 } }))
|
||||
|
||||
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
|
||||
const sentinel = captured.listeners[0]
|
||||
sentinel.onChunk({ type: 'text-delta', delta: 'Hello ' })
|
||||
sentinel.onChunk({ type: 'text-delta', delta: 'world' })
|
||||
sentinel.onChunk({ type: 'reasoning-delta', delta: 'ignored' })
|
||||
sentinel.onDone({ status: 'completed' })
|
||||
|
||||
const out = await promise
|
||||
expect(out).toEqual({ sessionId: 'sess-new', result: 'Hello world' })
|
||||
})
|
||||
|
||||
// agents-jobs-4: on a non-abort error, a subscribed channel must be notified exactly
|
||||
// once. The channel listener's generic `Error: …` is suppressed for task runs so only
|
||||
// the richer `[Task failed]` summary from notifyTaskError is delivered (no double-send).
|
||||
it('notifies a subscribed channel exactly once on a non-abort run error', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot('s1'))
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
|
||||
vi.mocked(agentChannelService.getSubscribedChannels).mockResolvedValueOnce([{ id: 'ch1' }] as never)
|
||||
|
||||
const adapter = {
|
||||
channelId: 'ch1',
|
||||
connected: true,
|
||||
notifyChatIds: ['chat-1'],
|
||||
sendMessage: vi.fn<(chatId: string, text: string) => Promise<void>>(async () => {}),
|
||||
onTextUpdate: vi.fn(async () => {}),
|
||||
onStreamComplete: vi.fn(async () => true)
|
||||
}
|
||||
mockGetAdapter.mockReturnValue(adapter as never)
|
||||
|
||||
const promise = runAgentTask(makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0 } }))
|
||||
|
||||
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
|
||||
// Simulate the stream manager dispatching the error to every listener (sentinel + channel).
|
||||
const errorResult = { error: new Error('boom'), status: 'error' }
|
||||
for (const listener of captured.listeners) {
|
||||
listener.onError?.(errorResult as never)
|
||||
}
|
||||
|
||||
await expect(promise).rejects.toThrow('boom')
|
||||
|
||||
// Exactly one channel message, and it's the task-framed summary — not the bare `Error: …`.
|
||||
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
|
||||
expect(adapter.sendMessage.mock.calls[0][1]).toContain('[Task failed]')
|
||||
expect(adapter.sendMessage.mock.calls[0][1]).not.toMatch(/^Error:/)
|
||||
})
|
||||
|
||||
// C2 (agents-jobs-1) + agents-jobs-7: aborting the run (JobManager cancel or
|
||||
// per-task timeout) must abort the upstream stream AND settle the handler
|
||||
// promise — otherwise it leaks until the JobManager force-finalize timeout.
|
||||
it('aborts the upstream stream and rejects when the run signal aborts', async () => {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
|
||||
|
||||
const controller = new AbortController()
|
||||
const promise = runAgentTask(
|
||||
makeCtx({ signal: controller.signal, input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0 } })
|
||||
)
|
||||
|
||||
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
|
||||
controller.abort(new Error('cancelled by manager'))
|
||||
|
||||
await expect(promise).rejects.toThrow('cancelled by manager')
|
||||
expect(mockAbort).toHaveBeenCalledWith(buildAgentSessionTopicId('sess-new'), 'cancelled by manager')
|
||||
})
|
||||
|
||||
// agents-jobs-5: a non-zero `timeoutMinutes` arms a per-task timeout timer in
|
||||
// makeRunSignal. When the stream never settles, the timer must fire, abort the
|
||||
// upstream stream, and reject the handler with the timeout error.
|
||||
it('aborts the upstream stream and rejects when the per-task timeout fires', async () => {
|
||||
vi.useFakeTimers()
|
||||
try {
|
||||
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
|
||||
|
||||
const promise = runAgentTask(makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 1 } }))
|
||||
const assertion = expect(promise).rejects.toThrow('Task timed out after 1 minute(s)')
|
||||
|
||||
// Flush the awaited setup chain (getById/getAgent/createSession/startRun) and
|
||||
// arm the timer, then advance past the 1-minute timeout so it fires. Never
|
||||
// settle the stream — the timeout is the only thing that resolves the run.
|
||||
await vi.advanceTimersByTimeAsync(60_000)
|
||||
|
||||
await assertion
|
||||
expect(mockAbort).toHaveBeenCalledWith(buildAgentSessionTopicId('sess-new'), 'Task timed out after 1 minute(s)')
|
||||
} finally {
|
||||
vi.useRealTimers()
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -62,7 +62,7 @@ export interface BuiltinAgentConfig {
|
||||
* Writes .claude/skills/ and .claude/plugins.json to the agent's
|
||||
* working directory so the SDK can auto-discover them.
|
||||
*
|
||||
* @param workspacePath - The agent's working directory (accessible_paths[0])
|
||||
* @param workspacePath - The agent session's workspace directory
|
||||
* @param builtinRole - The built-in role identifier ('assistant' or 'skill-creator')
|
||||
* @returns The parsed agent.json config, or undefined if not found
|
||||
*/
|
||||
@@ -14,11 +14,11 @@ vi.mock('node:fs/promises', () => ({
|
||||
|
||||
import { readdir, readFile, stat } from 'node:fs/promises'
|
||||
|
||||
import type { CherryClawConfiguration } from '@types'
|
||||
import type { AgentConfiguration } from '@shared/data/types/agent'
|
||||
|
||||
import { PromptBuilder } from '../prompt'
|
||||
|
||||
const baseConfig: CherryClawConfiguration = {
|
||||
const baseConfig: AgentConfiguration = {
|
||||
permission_mode: 'bypassPermissions',
|
||||
max_turns: 100,
|
||||
env_vars: {},
|
||||
@@ -2,7 +2,7 @@ import { readdir, readFile, stat } from 'node:fs/promises'
|
||||
import path from 'node:path'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { CherryClawConfiguration } from '@types'
|
||||
import type { AgentConfiguration } from '@shared/data/types/agent'
|
||||
|
||||
import { BOOTSTRAP_INSTRUCTIONS, SOUL_CONTENT_THRESHOLD } from './seedWorkspace'
|
||||
|
||||
@@ -142,15 +142,15 @@ ${sections}`
|
||||
* instructions. Returns a synchronous string — no I/O.
|
||||
*
|
||||
* Memory files layout (Soul Mode only):
|
||||
* {workspace}/soul.md — personality, tone, communication style
|
||||
* {workspace}/user.md — user profile, preferences, context
|
||||
* {workspace}/SOUL.md — personality, tone, communication style
|
||||
* {workspace}/USER.md — user profile, preferences, context
|
||||
* {workspace}/memory/FACT.md — durable project knowledge, technical decisions
|
||||
* {workspace}/memory/JOURNAL.jsonl — timestamped event log (managed by memory tool)
|
||||
*/
|
||||
export class PromptBuilder {
|
||||
private cache = new Map<string, CacheEntry>()
|
||||
|
||||
async buildSystemPrompt(workspacePath: string, config?: CherryClawConfiguration): Promise<string> {
|
||||
async buildSystemPrompt(workspacePath: string, config?: AgentConfiguration): Promise<string> {
|
||||
const parts: string[] = []
|
||||
|
||||
// Basic prompt: workspace system.md (case-insensitive) > embedded default
|
||||
@@ -227,7 +227,7 @@ ${content}
|
||||
* - If SOUL.md has substantial non-template content, skip (legacy agent migration).
|
||||
* - Otherwise, run bootstrap.
|
||||
*/
|
||||
private async shouldRunBootstrap(workspacePath: string, config?: CherryClawConfiguration): Promise<boolean> {
|
||||
private async shouldRunBootstrap(workspacePath: string, config?: AgentConfiguration): Promise<boolean> {
|
||||
if (config?.bootstrap_completed === true) {
|
||||
return false
|
||||
}
|
||||
@@ -63,7 +63,7 @@ Your goal in this conversation is to:
|
||||
- Rename yourself using \`mcp__claw__config\` (action: "rename", name: the chosen name)
|
||||
- Update \`SOUL.md\` with your role definition, personality, tone, principles, and boundaries using the Edit tool
|
||||
- Update \`USER.md\` with everything you learned about the user using the Edit tool
|
||||
- Log the bootstrap completion using \`mcp__claw__memory\` (append action, tag: "bootstrap")
|
||||
- Log the bootstrap completion using \`mcp__agent-memory__memory\` (append action, tag: "bootstrap")
|
||||
- Mark bootstrap as complete using \`mcp__claw__config\` (action: "complete_bootstrap")
|
||||
|
||||
Guidelines:
|
||||
251
src/main/ai/agents/runAgentTask.ts
Normal file
251
src/main/ai/agents/runAgentTask.ts
Normal file
@@ -0,0 +1,251 @@
|
||||
/**
|
||||
* Business logic for `agent.task` jobs — owned by `AgentTaskJobHandler`.
|
||||
*
|
||||
* Each fire creates a fresh agent session. Per-fire sessions are recorded in
|
||||
* `job.output.sessionId` for audit only — there is no cross-fire session
|
||||
* reuse pointer on the schedule. Scheduled tasks are discrete background
|
||||
* invocations (heartbeat, periodic summary, polling), not conversations, so
|
||||
* carrying context across fires would only stuff the model's window with
|
||||
* stale state. Persistent agent memory belongs in workspace files
|
||||
* (`heartbeat.md`, agent memory) instead of session history.
|
||||
*/
|
||||
|
||||
import { agentChannelService } from '@data/services/AgentChannelService'
|
||||
import { agentService } from '@data/services/AgentService'
|
||||
import { jobScheduleService } from '@data/services/JobScheduleService'
|
||||
import { jobService } from '@data/services/JobService'
|
||||
import { sessionService } from '@data/services/SessionService'
|
||||
import { loggerService } from '@logger'
|
||||
import { readHeartbeat } from '@main/ai/agents/cherryclaw/heartbeat'
|
||||
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
|
||||
import { ChannelAdapterListener, type StreamListener } from '@main/ai/streamManager'
|
||||
import { startAgentSessionRun } from '@main/ai/streamManager/api/startAgentSessionRun'
|
||||
import { application } from '@main/core/application'
|
||||
import type { JobContext } from '@main/core/job/types'
|
||||
|
||||
const logger = loggerService.withContext('runAgentTask')
|
||||
|
||||
const HEARTBEAT_PROMPT_SENTINEL = '__heartbeat__'
|
||||
const HEARTBEAT_TASK_NAME = 'heartbeat'
|
||||
|
||||
export type AgentTaskInput = {
|
||||
agentId: string
|
||||
prompt: string
|
||||
timeoutMinutes: number
|
||||
}
|
||||
|
||||
export type AgentTaskOutput = {
|
||||
/** Session created for this fire. Persisted to `jobTable.output` purely as
|
||||
* an audit trail — the task scheduler never reads this back for continuity. */
|
||||
sessionId: string | null
|
||||
/** First 200 chars of the assistant reply, or a status marker for skipped runs. */
|
||||
result: string
|
||||
}
|
||||
|
||||
/** Combine the JobManager-provided abort signal with an optional per-task timeout. */
|
||||
function makeRunSignal(
|
||||
outerSignal: AbortSignal,
|
||||
timeoutMinutes: number | undefined
|
||||
): { signal: AbortSignal; dispose: () => void } {
|
||||
if (!timeoutMinutes || timeoutMinutes <= 0) {
|
||||
return { signal: outerSignal, dispose: () => {} }
|
||||
}
|
||||
// Own the timeout so `dispose()` can actually release the timer on normal
|
||||
// completion (an `AbortSignal.timeout` keeps a live timer until it fires).
|
||||
const timeoutController = new AbortController()
|
||||
const timer = setTimeout(
|
||||
() => timeoutController.abort(new Error(`Task timed out after ${timeoutMinutes} minute(s)`)),
|
||||
timeoutMinutes * 60_000
|
||||
)
|
||||
const signal = AbortSignal.any([outerSignal, timeoutController.signal])
|
||||
return { signal, dispose: () => clearTimeout(timer) }
|
||||
}
|
||||
|
||||
export async function runAgentTask(ctx: JobContext<AgentTaskInput>): Promise<AgentTaskOutput> {
|
||||
const { agentId, prompt, timeoutMinutes } = ctx.input
|
||||
|
||||
// schedule-fired jobs carry `scheduleId` on the row; manual ad-hoc enqueues
|
||||
// (no schedule) degrade gracefully: skip channel notification.
|
||||
const jobSnapshot = await jobService.getById(ctx.jobId)
|
||||
const scheduleId = jobSnapshot?.scheduleId ?? null
|
||||
const scheduleSnapshot = scheduleId ? await jobScheduleService.getById(scheduleId) : null
|
||||
const taskName = scheduleSnapshot?.name ?? null
|
||||
|
||||
const agent = await agentService.getAgent(agentId)
|
||||
if (!agent) {
|
||||
throw new Error(`Agent not found: ${agentId}`)
|
||||
}
|
||||
|
||||
const config = agent.configuration ?? {}
|
||||
|
||||
const isHeartbeat = taskName === HEARTBEAT_TASK_NAME && prompt === HEARTBEAT_PROMPT_SENTINEL
|
||||
|
||||
let effectivePrompt = prompt
|
||||
|
||||
// All heartbeat skip decisions happen BEFORE we create a session — `createSession`
|
||||
// lazily provisions a workspace, so creating one for a fire we're going to drop
|
||||
// accretes a session row (and workspace) every interval. The agent's workspace is
|
||||
// shared across its sessions, so we can read `heartbeat.md` without creating one.
|
||||
if (isHeartbeat) {
|
||||
if (config.heartbeat_enabled === false) {
|
||||
logger.debug('Heartbeat skipped (disabled)', { agentId, scheduleId })
|
||||
return { sessionId: null, result: 'Skipped (disabled)' }
|
||||
}
|
||||
const workspacePath = await sessionService.findAgentWorkspacePath(agentId)
|
||||
if (!workspacePath) {
|
||||
logger.debug('Heartbeat skipped (no workspace)', { agentId, scheduleId })
|
||||
return { sessionId: null, result: 'Skipped (no file)' }
|
||||
}
|
||||
const content = await readHeartbeat(workspacePath)
|
||||
if (!content) {
|
||||
logger.debug('Heartbeat skipped (no heartbeat.md)', { agentId, scheduleId })
|
||||
return { sessionId: null, result: 'Skipped (no file)' }
|
||||
}
|
||||
effectivePrompt = [
|
||||
'[Heartbeat]',
|
||||
'This is a periodic heartbeat. The instructions below are from your heartbeat.md file.',
|
||||
'Process each item, take action where possible, and use the notify tool to alert the user of important results.',
|
||||
'',
|
||||
'---',
|
||||
content
|
||||
].join('\n')
|
||||
}
|
||||
|
||||
// Always create a fresh session per fire. Scheduled tasks are discrete
|
||||
// invocations; cross-fire session reuse would only carry stale model
|
||||
// context. Persistent state lives in workspace files (heartbeat.md, etc.).
|
||||
const session = await sessionService.createSession({ agentId, name: taskName ?? 'Scheduled task' })
|
||||
|
||||
const subscribedChannels = scheduleId ? await agentChannelService.getSubscribedChannels(scheduleId) : []
|
||||
|
||||
const channelManager = application.get('ChannelManager')
|
||||
const channelListeners: StreamListener[] = subscribedChannels.flatMap((ch) => {
|
||||
const adapter = channelManager.getAdapter(ch.id)
|
||||
if (!adapter) return []
|
||||
// Suppress the listener's generic `Error: …` — `notifyTaskError` below sends a richer
|
||||
// `[Task failed]` summary to the same chats, so leaving it on would double-notify.
|
||||
return adapter.notifyChatIds.map((chatId) => new ChannelAdapterListener(adapter, chatId, true))
|
||||
})
|
||||
|
||||
const { signal: runSignal, dispose } = makeRunSignal(ctx.signal, timeoutMinutes)
|
||||
const startTimeMs = Date.now()
|
||||
|
||||
let resolveExecution!: (text: string) => void
|
||||
let rejectExecution!: (err: unknown) => void
|
||||
const executionDone = new Promise<string>((resolve, reject) => {
|
||||
resolveExecution = resolve
|
||||
rejectExecution = reject
|
||||
})
|
||||
let accumulatedText = ''
|
||||
const sentinel: StreamListener = {
|
||||
id: `agent-task:${scheduleId ?? ctx.jobId}`,
|
||||
onChunk(chunk) {
|
||||
// `text-delta`'s field is `delta`, not `text` (AI SDK `UIMessageChunk`) — the
|
||||
// previous `as { text }` cast silently never accumulated, so the persisted
|
||||
// result was always the `'Completed'` fallback.
|
||||
if (chunk.type === 'text-delta') accumulatedText += chunk.delta
|
||||
},
|
||||
onDone() {
|
||||
resolveExecution(accumulatedText.trim())
|
||||
},
|
||||
onPaused() {
|
||||
if (runSignal.aborted) {
|
||||
const reason = runSignal.reason
|
||||
rejectExecution(reason instanceof Error ? reason : new Error(String(reason ?? 'Task aborted')))
|
||||
return
|
||||
}
|
||||
resolveExecution(accumulatedText.trim())
|
||||
},
|
||||
onError(result) {
|
||||
rejectExecution(new Error(result.error.message ?? 'Execution failed'))
|
||||
},
|
||||
// Keep `true`: the manager prunes a listener whose `isAlive()` is false BEFORE
|
||||
// firing its terminal callback, so gating on `runSignal` here would make an
|
||||
// aborted run's terminal event never settle `executionDone`. Abort is handled
|
||||
// explicitly via `onRunAbort` below.
|
||||
isAlive: () => true
|
||||
}
|
||||
|
||||
const topicId = buildAgentSessionTopicId(session.id)
|
||||
// On JobManager cancel or per-task timeout, stop the upstream run: the execution's
|
||||
// own controller never sees `runSignal`, so abort the live stream and settle
|
||||
// `executionDone` here — otherwise the handler promise leaks until the JobManager's
|
||||
// force-finalize timeout.
|
||||
const onRunAbort = () => {
|
||||
const reason = runSignal.reason
|
||||
application
|
||||
.get('AiStreamManager')
|
||||
.abort(topicId, reason instanceof Error ? reason.message : String(reason ?? 'task-aborted'))
|
||||
rejectExecution(reason instanceof Error ? reason : new Error(String(reason ?? 'Task aborted')))
|
||||
}
|
||||
if (runSignal.aborted) onRunAbort()
|
||||
else runSignal.addEventListener('abort', onRunAbort, { once: true })
|
||||
|
||||
let runError: Error | null = null
|
||||
let resultText = ''
|
||||
try {
|
||||
await startAgentSessionRun({
|
||||
sessionId: session.id,
|
||||
userParts: [{ type: 'text', text: effectivePrompt }],
|
||||
listeners: [sentinel, ...channelListeners]
|
||||
})
|
||||
|
||||
resultText = await executionDone
|
||||
|
||||
if (runSignal.aborted) {
|
||||
const reason = runSignal.reason
|
||||
throw reason instanceof Error ? reason : new Error(String(reason ?? 'Task aborted'))
|
||||
}
|
||||
} catch (err) {
|
||||
runError = err instanceof Error ? err : new Error(String(err))
|
||||
if (!runSignal.aborted && subscribedChannels.length > 0) {
|
||||
await notifyTaskError(
|
||||
{ id: scheduleId, name: taskName, durationMs: Date.now() - startTimeMs },
|
||||
runError.message,
|
||||
subscribedChannels
|
||||
)
|
||||
}
|
||||
throw runError
|
||||
} finally {
|
||||
runSignal.removeEventListener('abort', onRunAbort)
|
||||
dispose()
|
||||
}
|
||||
|
||||
return {
|
||||
sessionId: session.id,
|
||||
result: resultText.slice(0, 200) || 'Completed'
|
||||
}
|
||||
}
|
||||
|
||||
async function notifyTaskError(
|
||||
task: { id: string | null; name: string | null; durationMs: number },
|
||||
error: string,
|
||||
subscribedChannels: Array<{ id: string }>
|
||||
): Promise<void> {
|
||||
const channelManager = application.get('ChannelManager')
|
||||
try {
|
||||
const durationSec = Math.round(task.durationMs / 1000)
|
||||
const label = task.name ?? task.id ?? '(unknown)'
|
||||
const text = `[Task failed] ${label}\nDuration: ${durationSec}s\nError: ${error}`
|
||||
|
||||
for (const ch of subscribedChannels) {
|
||||
const adapter = channelManager.getAdapter(ch.id)
|
||||
if (!adapter) continue
|
||||
for (const chatId of adapter.notifyChatIds) {
|
||||
adapter.sendMessage(chatId, text).catch((err) => {
|
||||
logger.warn('Failed to deliver task error notification', {
|
||||
scheduleId: task.id,
|
||||
channelId: ch.id,
|
||||
chatId,
|
||||
error: err instanceof Error ? err.message : String(err)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
logger.warn('Error while building task error notification', {
|
||||
scheduleId: task.id,
|
||||
error: err instanceof Error ? err.message : String(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,94 +1,9 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { FileAttachment, ImageAttachment } from '@main/utils/downloadAsBase64'
|
||||
import type { ChannelLogEntry, ChannelLogLevel, ChannelStatusEvent } from '@shared/config/types'
|
||||
import { net } from 'electron'
|
||||
import type { AgentChannelEntity, AgentChannelType } from '@shared/data/api/schemas/agentChannels'
|
||||
import { EventEmitter } from 'events'
|
||||
|
||||
const logger = loggerService.withContext('ChannelAdapter')
|
||||
|
||||
/** Pre-downloaded, base64-encoded image ready for multimodal AI input. */
|
||||
export type ImageAttachment = {
|
||||
data: string // base64-encoded image bytes
|
||||
media_type: string // e.g. 'image/png', 'image/jpeg', 'image/gif', 'image/webp'
|
||||
}
|
||||
|
||||
/** Pre-downloaded, base64-encoded file attachment from an IM channel. */
|
||||
export type FileAttachment = {
|
||||
filename: string // original filename, e.g. 'report.pdf'
|
||||
data: string // base64-encoded file bytes
|
||||
media_type: string // MIME type, e.g. 'application/pdf', 'text/plain'
|
||||
size: number // raw byte size (before base64 encoding)
|
||||
}
|
||||
|
||||
/** Maximum file size we'll download from IM channels (20 MB). */
|
||||
export const MAX_FILE_SIZE_BYTES = 20 * 1024 * 1024
|
||||
|
||||
/**
|
||||
* Download an image URL via Electron's net.fetch (respects system proxy) and
|
||||
* return base64-encoded data. Returns null on failure.
|
||||
*/
|
||||
export async function downloadImageAsBase64(url: string): Promise<ImageAttachment | null> {
|
||||
try {
|
||||
const response = await net.fetch(url)
|
||||
if (!response.ok) {
|
||||
logger.warn('Failed to download image', { url, status: response.status })
|
||||
return null
|
||||
}
|
||||
const contentType = response.headers.get('content-type') || 'image/png'
|
||||
const mediaType = contentType.split(';')[0].trim()
|
||||
const buffer = Buffer.from(await response.arrayBuffer())
|
||||
return { data: buffer.toString('base64'), media_type: mediaType }
|
||||
} catch (error) {
|
||||
logger.warn('Failed to fetch image', {
|
||||
url,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Download a file URL via Electron's net.fetch and return base64-encoded data.
|
||||
* Enforces MAX_FILE_SIZE_BYTES. Returns null on failure or if the file is too large.
|
||||
*/
|
||||
export async function downloadFileAsBase64(url: string, filename: string): Promise<FileAttachment | null> {
|
||||
try {
|
||||
const response = await net.fetch(url)
|
||||
if (!response.ok) {
|
||||
logger.warn('Failed to download file', { url, filename, status: response.status })
|
||||
return null
|
||||
}
|
||||
|
||||
const contentLength = response.headers.get('content-length')
|
||||
if (contentLength && parseInt(contentLength, 10) > MAX_FILE_SIZE_BYTES) {
|
||||
logger.warn('File too large, skipping download', { filename, size: contentLength })
|
||||
return null
|
||||
}
|
||||
|
||||
const buffer = Buffer.from(await response.arrayBuffer())
|
||||
if (buffer.length > MAX_FILE_SIZE_BYTES) {
|
||||
logger.warn('File too large after download', { filename, size: buffer.length })
|
||||
return null
|
||||
}
|
||||
|
||||
const contentType = response.headers.get('content-type') || 'application/octet-stream'
|
||||
const mediaType = contentType.split(';')[0].trim()
|
||||
|
||||
return {
|
||||
filename,
|
||||
data: buffer.toString('base64'),
|
||||
media_type: mediaType,
|
||||
size: buffer.length
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to fetch file', {
|
||||
url,
|
||||
filename,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export type ChannelMessageEvent = {
|
||||
chatId: string
|
||||
userId: string
|
||||
@@ -113,11 +28,19 @@ export type SendMessageOptions = {
|
||||
replyToMessageId?: number
|
||||
}
|
||||
|
||||
export type ChannelAdapterConfig = {
|
||||
/** Channel type → its config payload, projected from the `AgentChannelEntity` discriminated union. */
|
||||
type ChannelConfigByType = {
|
||||
[T in AgentChannelType]: Extract<AgentChannelEntity, { type: T }>['config']
|
||||
}
|
||||
|
||||
/** The config payload for a specific channel type. Indexing keeps `ChannelAdapterConfig` covariant in `T`. */
|
||||
export type ChannelConfigForType<T extends AgentChannelType = AgentChannelType> = ChannelConfigByType[T]
|
||||
|
||||
export type ChannelAdapterConfig<T extends AgentChannelType = AgentChannelType> = {
|
||||
channelId: string
|
||||
channelType: string
|
||||
channelType: T
|
||||
agentId: string
|
||||
channelConfig: Record<string, unknown>
|
||||
channelConfig: ChannelConfigForType<T>
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -139,7 +62,7 @@ export type ChannelAdapterConfig = {
|
||||
*/
|
||||
export abstract class ChannelAdapter extends EventEmitter {
|
||||
readonly channelId: string
|
||||
readonly channelType: string
|
||||
readonly channelType: AgentChannelType
|
||||
readonly agentId: string
|
||||
/**
|
||||
* Chat IDs that this adapter can send notifications/task results to.
|
||||
@@ -1,24 +1,34 @@
|
||||
import { application } from '@application'
|
||||
import { agentChannelService as channelService } from '@data/services/AgentChannelService'
|
||||
import { loggerService } from '@logger'
|
||||
import { BaseService, DependsOn, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
|
||||
import { WindowType } from '@main/core/window/types'
|
||||
import type { ChannelLogEntry, ChannelStatusEvent } from '@shared/config/types'
|
||||
import type { AgentChannelEntity as ChannelRow } from '@shared/data/api/schemas/agentChannels'
|
||||
import type { AgentChannelEntity as ChannelRow, AgentChannelType } from '@shared/data/api/schemas/agentChannels'
|
||||
import type { ChannelConfig } from '@shared/data/types/channel'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
|
||||
import type { ChannelAdapter } from './ChannelAdapter'
|
||||
import type { ChannelConfig } from './channelConfig'
|
||||
import { ChannelLogBuffer } from './ChannelLogBuffer'
|
||||
import { channelMessageHandler } from './ChannelMessageHandler'
|
||||
|
||||
const logger = loggerService.withContext('ChannelManager')
|
||||
|
||||
// Adapter factory registry -- adapters register themselves here
|
||||
type AdapterFactory = (channel: ChannelRow, agentId: string) => ChannelAdapter
|
||||
const adapterFactories = new Map<string, AdapterFactory>()
|
||||
// Adapter factory registry -- adapters register themselves here. The factory
|
||||
// for a given channel type receives the matching variant of the discriminated
|
||||
// `ChannelRow` union, so `channel.config` is strongly typed per adapter.
|
||||
type AdapterFactory<T extends AgentChannelType = AgentChannelType> = (
|
||||
channel: Extract<ChannelRow, { type: T }>,
|
||||
agentId: string
|
||||
) => ChannelAdapter
|
||||
const adapterFactories = new Map<AgentChannelType, AdapterFactory>()
|
||||
|
||||
export function registerAdapterFactory(type: string, factory: AdapterFactory): void {
|
||||
adapterFactories.set(type, factory)
|
||||
export function registerAdapterFactory<T extends AgentChannelType>(type: T, factory: AdapterFactory<T>): void {
|
||||
// A factory is always stored under, and looked up by, its own channel type
|
||||
// (see `connectChannelFromRow`), so the row handed to it is guaranteed to be
|
||||
// this variant. That invariant is the one thing the type system can't see, so
|
||||
// we narrow the row to the factory's variant here — nothing wider is asserted.
|
||||
adapterFactories.set(type, (channel, agentId) => factory(channel as Extract<ChannelRow, { type: T }>, agentId))
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -26,7 +36,7 @@ export function registerAdapterFactory(type: string, factory: AdapterFactory): v
|
||||
* Each module registers itself via `registerAdapterFactory()` as a side effect.
|
||||
* This avoids eagerly importing all 6 heavy adapter modules at startup.
|
||||
*/
|
||||
const adapterImportMap: Record<string, () => Promise<unknown>> = {
|
||||
const adapterImportMap: Record<AgentChannelType, () => Promise<unknown>> = {
|
||||
discord: () => import('./adapters/discord/DiscordAdapter'),
|
||||
feishu: () => import('./adapters/feishu/FeishuAdapter'),
|
||||
qq: () => import('./adapters/qq/QqAdapter'),
|
||||
@@ -36,15 +46,15 @@ const adapterImportMap: Record<string, () => Promise<unknown>> = {
|
||||
}
|
||||
|
||||
/** Ensure the adapter factory for the given type is loaded (idempotent). */
|
||||
async function ensureAdapterLoaded(type: string): Promise<void> {
|
||||
async function ensureAdapterLoaded(type: AgentChannelType): Promise<void> {
|
||||
if (adapterFactories.has(type)) return
|
||||
const loader = adapterImportMap[type]
|
||||
if (!loader) return
|
||||
await loader()
|
||||
await adapterImportMap[type]()
|
||||
}
|
||||
|
||||
class ChannelManager {
|
||||
private static instance: ChannelManager | null = null
|
||||
@Injectable('ChannelManager')
|
||||
@ServicePhase(Phase.WhenReady)
|
||||
@DependsOn(['WindowManager'])
|
||||
export class ChannelManager extends BaseService {
|
||||
private readonly adapters = new Map<string, ChannelAdapter>() // key: `${agentId}:${channelId}`
|
||||
private readonly qrWaiters = new Map<
|
||||
string,
|
||||
@@ -53,11 +63,12 @@ class ChannelManager {
|
||||
private readonly channelLogs = new ChannelLogBuffer()
|
||||
private readonly channelStatuses = new Map<string, ChannelStatusEvent>()
|
||||
|
||||
static getInstance(): ChannelManager {
|
||||
if (!ChannelManager.instance) {
|
||||
ChannelManager.instance = new ChannelManager()
|
||||
}
|
||||
return ChannelManager.instance
|
||||
protected async onReady(): Promise<void> {
|
||||
await this.start()
|
||||
}
|
||||
|
||||
protected async onStop(): Promise<void> {
|
||||
await this.stop()
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
@@ -394,5 +405,3 @@ class ChannelManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const channelManager = ChannelManager.getInstance()
|
||||
@@ -4,27 +4,22 @@ import path from 'node:path'
|
||||
|
||||
import { agentChannelService as channelService } from '@data/services/AgentChannelService'
|
||||
import { agentService } from '@data/services/AgentService'
|
||||
import { agentSessionService as sessionService } from '@data/services/AgentSessionService'
|
||||
import { sessionService } from '@data/services/SessionService'
|
||||
import { loggerService } from '@logger'
|
||||
import { sessionMessageOrchestrator } from '@main/services/agents/services/SessionMessageOrchestrator'
|
||||
import type { GetAgentSessionResponse, PermissionMode } from '@types'
|
||||
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
|
||||
import { isAgentSessionWorkspaceError } from '@main/ai/runtime/claudeCode/settingsBuilder'
|
||||
import { ChannelAdapterListener, type StreamListener } from '@main/ai/streamManager'
|
||||
import { startAgentSessionRun } from '@main/ai/streamManager/api/startAgentSessionRun'
|
||||
import { application } from '@main/core/application'
|
||||
import type { FileAttachment, ImageAttachment } from '@main/utils/downloadAsBase64'
|
||||
import type { AgentSessionEntity } from '@shared/data/api/schemas/sessions'
|
||||
|
||||
import { sanitizeChannelOutput, wrapExternalContent } from '../security'
|
||||
import type {
|
||||
ChannelAdapter,
|
||||
ChannelCommandEvent,
|
||||
ChannelMessageEvent,
|
||||
FileAttachment,
|
||||
ImageAttachment
|
||||
} from './ChannelAdapter'
|
||||
import type { ChannelAdapter, ChannelCommandEvent, ChannelMessageEvent } from './ChannelAdapter'
|
||||
import { SLASH_COMMANDS } from './constants'
|
||||
import { sessionStreamBus } from './SessionStreamBus'
|
||||
import { broadcastSessionChanged } from './sessionStreamIpc'
|
||||
import { splitMessage } from './utils'
|
||||
import { wrapExternalContent } from './security'
|
||||
|
||||
const logger = loggerService.withContext('ChannelMessageHandler')
|
||||
|
||||
const MAX_MESSAGE_LENGTH = 4096
|
||||
const TYPING_INTERVAL_MS = 4000
|
||||
|
||||
/** Max number of entries in the session tracker before evicting oldest entries. */
|
||||
@@ -54,7 +49,7 @@ export class ChannelMessageHandler {
|
||||
private static instance: ChannelMessageHandler | null = null
|
||||
// TODO: in v2 use cacheService
|
||||
private readonly sessionTracker = new Map<string, string>() // `${agentId}:${channelId}:${chatId}` -> sessionId
|
||||
private readonly pendingResolutions = new Map<string, Promise<GetAgentSessionResponse | null>>()
|
||||
private readonly pendingResolutions = new Map<string, Promise<AgentSessionEntity | null>>()
|
||||
/** Per-chat debounce buffer — accumulates rapid messages before flushing */
|
||||
private readonly pendingBatches = new Map<string, PendingBatch>()
|
||||
/** Per-chat serial queue — ensures only one stream runs at a time per chat */
|
||||
@@ -192,12 +187,24 @@ export class ChannelMessageHandler {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply channel-level permission mode override on every message (not just session creation).
|
||||
// This ensures changes to the channel's permission_mode take effect immediately,
|
||||
// even for sessions created before the setting was changed.
|
||||
await this.applyChannelPermissionMode(session, adapter.channelId)
|
||||
// Resolve agent for cognitive config (model / configuration / mcps / allowedTools).
|
||||
// Workspace is read from the session itself (CMA Environment binding).
|
||||
// An orphan session (`agentId === null`) cannot run; skip it.
|
||||
if (!session.agentId) {
|
||||
logger.error('Channel message hit an orphan session', { sessionId: session.id })
|
||||
return
|
||||
}
|
||||
const agent = await agentService.getAgent(session.agentId)
|
||||
if (!agent) {
|
||||
logger.error('Agent not found for session', { sessionId: session.id, agentId: session.agentId })
|
||||
return
|
||||
}
|
||||
|
||||
const workDir = session.accessiblePaths[0]
|
||||
// TODO(channel-perm-override): channel-level permission_mode used to mutate
|
||||
// session.configuration in-place; with config now living on agent, this
|
||||
// override needs to flow as a per-dispatch option instead. Tracked separately.
|
||||
|
||||
const workDir = session.workspace?.path
|
||||
|
||||
// Save images to agent workspace so the agent can read them via the Read tool
|
||||
let imagePaths: string[] = []
|
||||
@@ -252,36 +259,6 @@ export class ChannelMessageHandler {
|
||||
channelType: adapter.channelType
|
||||
})
|
||||
|
||||
// Build display text: append filenames so the user can see them in the UI
|
||||
let displayText = message.text
|
||||
if (message.files && message.files.length > 0) {
|
||||
const names = message.files.map((f) => `📎 ${f.filename}`).join('\n')
|
||||
displayText = displayText ? `${displayText}\n${names}` : names
|
||||
}
|
||||
|
||||
// Snapshot subscriber state ONCE — this single check drives:
|
||||
// 1. Whether user-message is published to the renderer
|
||||
// 2. The persist flag (renderer persistence vs headless persistence)
|
||||
// 3. Whether stream chunks / complete events are forwarded
|
||||
// Checking once eliminates the race where subscribe() IPC completes
|
||||
// between the user-message publish and the persist decision.
|
||||
const rendererIsWatching = sessionStreamBus.hasSubscribers(session.id)
|
||||
|
||||
if (rendererIsWatching) {
|
||||
sessionStreamBus.publish(session.id, {
|
||||
sessionId: session.id,
|
||||
agentId: session.agentId,
|
||||
type: 'user-message',
|
||||
userMessage: {
|
||||
chatId: message.chatId,
|
||||
userId: message.userId,
|
||||
userName: message.userName,
|
||||
text: displayText,
|
||||
images: message.images
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const abortController = new AbortController()
|
||||
this.activeAbortControllers.set(session.id, abortController)
|
||||
|
||||
@@ -293,30 +270,22 @@ export class ChannelMessageHandler {
|
||||
)
|
||||
|
||||
try {
|
||||
const responseText = await this.collectStreamResponse(
|
||||
session,
|
||||
securedContent,
|
||||
abortController,
|
||||
adapter,
|
||||
message.chatId,
|
||||
message.text,
|
||||
message.images,
|
||||
rendererIsWatching
|
||||
)
|
||||
|
||||
if (responseText) {
|
||||
// Sanitize output to prevent accidental secret leakage through channels
|
||||
const { text: sanitizedText } = sanitizeChannelOutput(responseText)
|
||||
const finalized = await adapter.onStreamComplete(message.chatId, sanitizedText).catch(() => false)
|
||||
if (!finalized) {
|
||||
await this.sendChunked(adapter, message.chatId, sanitizedText)
|
||||
}
|
||||
}
|
||||
// Delivery (streaming updates + the sanitized finalize) is owned by the
|
||||
// `ChannelAdapterListener` registered inside `collectStreamResponse`; we only await
|
||||
// turn completion here. (The old post-hoc finalize was dead — the sentinel's `c.text`
|
||||
// read never accumulated — and reviving it would double-send.)
|
||||
await this.collectStreamResponse(session, securedContent, abortController, adapter, message.chatId)
|
||||
} catch (streamError) {
|
||||
// Notify adapter of the error so it can update streaming UI
|
||||
adapter
|
||||
.onStreamError(message.chatId, streamError instanceof Error ? streamError.message : String(streamError))
|
||||
.catch(() => {})
|
||||
const streamErrorMessage = streamError instanceof Error ? streamError.message : String(streamError)
|
||||
if (isAgentSessionWorkspaceError(streamError)) {
|
||||
// Thrown before streaming starts (validateSession), so no controller exists yet and
|
||||
// onStreamError is a no-op on most adapters — send a plain message so the inbound
|
||||
// message isn't silently dropped on Telegram/WeChat/QQ/Discord/Slack.
|
||||
adapter.sendMessage(message.chatId, streamErrorMessage).catch(() => {})
|
||||
} else {
|
||||
// Mid-stream error: let the adapter update its streaming UI.
|
||||
adapter.onStreamError(message.chatId, streamErrorMessage).catch(() => {})
|
||||
}
|
||||
throw streamError
|
||||
} finally {
|
||||
this.activeAbortControllers.delete(session.id)
|
||||
@@ -336,28 +305,14 @@ export class ChannelMessageHandler {
|
||||
try {
|
||||
switch (command.command) {
|
||||
case 'new': {
|
||||
const agent = await agentService.getAgent(agentId)
|
||||
const channelRow = await channelService.getChannel(adapter.channelId)
|
||||
const permMode = channelRow?.permissionMode as PermissionMode | undefined
|
||||
|
||||
const newSession = await sessionService.createSession(agentId, {
|
||||
...(agent?.configuration
|
||||
? {
|
||||
configuration: {
|
||||
...agent.configuration,
|
||||
...(permMode ? { permission_mode: permMode } : {})
|
||||
}
|
||||
}
|
||||
: {})
|
||||
})
|
||||
if (newSession) {
|
||||
// Update channel's session_id to point to the new session
|
||||
await channelService.updateChannel(adapter.channelId, { sessionId: newSession.id })
|
||||
const trackerKey = `${agentId}:${adapter.channelId}:${command.chatId}`
|
||||
this.sessionTracker.set(trackerKey, newSession.id)
|
||||
this.evictSessionTracker()
|
||||
await adapter.sendMessage(command.chatId, 'New session created.')
|
||||
}
|
||||
// TODO(channel-perm-override): channel.permissionMode no longer
|
||||
// applied here — config lives on agent now. Tracked separately.
|
||||
const newSession = await sessionService.createSession({ agentId, name: 'Channel session' })
|
||||
await channelService.updateChannel(adapter.channelId, { sessionId: newSession.id })
|
||||
const trackerKey = `${agentId}:${adapter.channelId}:${command.chatId}`
|
||||
this.sessionTracker.set(trackerKey, newSession.id)
|
||||
this.evictSessionTracker()
|
||||
await adapter.sendMessage(command.chatId, 'New session created.')
|
||||
break
|
||||
}
|
||||
case 'compact': {
|
||||
@@ -380,7 +335,12 @@ export class ChannelMessageHandler {
|
||||
adapter,
|
||||
command.chatId
|
||||
)
|
||||
await adapter.sendMessage(command.chatId, response || 'Session compacted.')
|
||||
// The `ChannelAdapterListener` registered inside `collectStreamResponse` already
|
||||
// delivered any non-empty output; only send an explicit fallback when compact
|
||||
// produced no text, so we don't double-send.
|
||||
if (!response) {
|
||||
await adapter.sendMessage(command.chatId, 'Session compacted.')
|
||||
}
|
||||
} finally {
|
||||
clearInterval(typingInterval)
|
||||
}
|
||||
@@ -431,22 +391,6 @@ export class ChannelMessageHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Look up the channel's current permission_mode from the agent config and
|
||||
* override the session's configuration in-place. This ensures that changes
|
||||
* to the channel permission mode take effect immediately — even for sessions
|
||||
* that were created before the setting was changed.
|
||||
*/
|
||||
private async applyChannelPermissionMode(session: GetAgentSessionResponse, channelId: string): Promise<void> {
|
||||
const channel = await channelService.getChannel(channelId)
|
||||
if (channel?.permissionMode && session.configuration) {
|
||||
session.configuration = {
|
||||
...session.configuration,
|
||||
permission_mode: channel.permissionMode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Evict oldest session tracker entries when the map exceeds the size limit. */
|
||||
private evictSessionTracker(): void {
|
||||
if (this.sessionTracker.size <= SESSION_TRACKER_MAX_SIZE) return
|
||||
@@ -460,15 +404,27 @@ export class ChannelMessageHandler {
|
||||
|
||||
/** Clear session tracking for an agent (used when agent is deleted/updated) */
|
||||
clearSessionTracker(agentId: string): void {
|
||||
for (const key of this.sessionTracker.keys()) {
|
||||
// Abort any in-flight stream owned by a tracked session of this agent
|
||||
// before dropping the tracker entries — otherwise the stream keeps
|
||||
// running on a deleted agent and `sendMessage` to a now-detached
|
||||
// channel will throw.
|
||||
const sessionIdsToAbort: string[] = []
|
||||
for (const [key, sessionId] of this.sessionTracker.entries()) {
|
||||
if (key.startsWith(`${agentId}:`)) {
|
||||
sessionIdsToAbort.push(sessionId)
|
||||
this.sessionTracker.delete(key)
|
||||
}
|
||||
}
|
||||
for (const sessionId of sessionIdsToAbort) {
|
||||
this.abortSessionStream(sessionId, 'agent-cleared')
|
||||
}
|
||||
for (const [key, batch] of this.pendingBatches.entries()) {
|
||||
if (key.startsWith(`${agentId}:`)) {
|
||||
clearTimeout(batch.timer)
|
||||
this.pendingBatches.delete(key)
|
||||
// Settle the discarded batch's callers so their .catch handlers fire
|
||||
// instead of leaving handleIncoming promises hanging forever.
|
||||
batch.resolvers.forEach((r) => r.reject(new Error('Agent removed; batch discarded')))
|
||||
}
|
||||
}
|
||||
for (const key of this.chatQueues.keys()) {
|
||||
@@ -478,14 +434,22 @@ export class ChannelMessageHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/** Abort an active stream for the given session. Returns true if aborted. */
|
||||
/** Abort an active stream for the given session. Returns true if a stream was in flight. */
|
||||
abortSession(sessionId: string): boolean {
|
||||
const controller = this.activeAbortControllers.get(sessionId)
|
||||
if (controller) {
|
||||
controller.abort()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
if (!this.activeAbortControllers.has(sessionId)) return false
|
||||
this.abortSessionStream(sessionId, 'channel-session-aborted')
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the upstream agent-session turn for a session. The local `AbortController`
|
||||
* is never passed to the running stream — it only flips a listener's `isAlive()`,
|
||||
* which (because the manager prunes dead listeners before firing their terminal
|
||||
* callback) would strand the completion sentinel. So abort through the manager,
|
||||
* which settles the turn as `paused` and lets the still-alive sentinel resolve.
|
||||
*/
|
||||
private abortSessionStream(sessionId: string, reason: string): void {
|
||||
application.get('AiStreamManager').abort(buildAgentSessionTopicId(sessionId), reason)
|
||||
}
|
||||
|
||||
private async resolveSession(
|
||||
@@ -493,7 +457,7 @@ export class ChannelMessageHandler {
|
||||
channelId: string,
|
||||
channelType: string,
|
||||
chatId: string
|
||||
): Promise<GetAgentSessionResponse | null> {
|
||||
): Promise<AgentSessionEntity | null> {
|
||||
const trackerKey = `${agentId}:${channelId}:${chatId}`
|
||||
|
||||
// Coalesce concurrent resolutions for the same chat to avoid duplicate sessions
|
||||
@@ -515,15 +479,15 @@ export class ChannelMessageHandler {
|
||||
_channelType: string,
|
||||
_chatId: string,
|
||||
trackerKey: string
|
||||
): Promise<GetAgentSessionResponse | null> {
|
||||
): Promise<AgentSessionEntity | null> {
|
||||
const channelRow = await channelService.getChannel(channelId)
|
||||
const lookup = async (sessionId: string) => sessionService.getById(sessionId).catch(() => null)
|
||||
|
||||
// Check tracker first
|
||||
const trackedId = this.sessionTracker.get(trackerKey)
|
||||
if (trackedId) {
|
||||
const session = await sessionService.getSession(agentId, trackedId)
|
||||
if (session) {
|
||||
// Ensure channel's session_id stays in sync
|
||||
const session = await lookup(trackedId)
|
||||
if (session && session.agentId === agentId) {
|
||||
if (channelRow && channelRow.sessionId !== session.id) {
|
||||
channelService
|
||||
.updateChannel(channelId, { sessionId: session.id })
|
||||
@@ -531,19 +495,18 @@ export class ChannelMessageHandler {
|
||||
logger.warn('Failed to sync channel-session link', err instanceof Error ? err : new Error(String(err)))
|
||||
)
|
||||
}
|
||||
return session as GetAgentSessionResponse
|
||||
return session
|
||||
}
|
||||
// Tracked session gone, clear it
|
||||
this.sessionTracker.delete(trackerKey)
|
||||
}
|
||||
|
||||
// Look up existing session via channel's session_id
|
||||
if (channelRow?.sessionId) {
|
||||
const existingSession = await sessionService.getSession(agentId, channelRow.sessionId)
|
||||
if (existingSession) {
|
||||
const existingSession = await lookup(channelRow.sessionId)
|
||||
if (existingSession && existingSession.agentId === agentId) {
|
||||
this.sessionTracker.set(trackerKey, existingSession.id)
|
||||
this.evictSessionTracker()
|
||||
return existingSession as GetAgentSessionResponse
|
||||
return existingSession
|
||||
}
|
||||
}
|
||||
|
||||
@@ -554,136 +517,57 @@ export class ChannelMessageHandler {
|
||||
channelSessionId: channelRow?.sessionId ?? null,
|
||||
trackerKey
|
||||
})
|
||||
const agent = await agentService.getAgent(agentId)
|
||||
const channelPermissionMode = channelRow?.permissionMode as PermissionMode | undefined
|
||||
|
||||
const newSession = await sessionService.createSession(agentId, {
|
||||
...(agent?.configuration
|
||||
? {
|
||||
configuration: {
|
||||
...agent.configuration,
|
||||
...(channelPermissionMode ? { permission_mode: channelPermissionMode } : {})
|
||||
}
|
||||
}
|
||||
: {})
|
||||
})
|
||||
if (newSession) {
|
||||
// Link channel to the new session
|
||||
await channelService.updateChannel(channelId, { sessionId: newSession.id })
|
||||
this.sessionTracker.set(trackerKey, newSession.id)
|
||||
this.evictSessionTracker()
|
||||
return newSession as GetAgentSessionResponse
|
||||
}
|
||||
|
||||
return null
|
||||
const newSession = await sessionService.createSession({ agentId, name: 'Channel session' })
|
||||
await channelService.updateChannel(channelId, { sessionId: newSession.id })
|
||||
this.sessionTracker.set(trackerKey, newSession.id)
|
||||
this.evictSessionTracker()
|
||||
return newSession
|
||||
}
|
||||
|
||||
private async collectStreamResponse(
|
||||
session: GetAgentSessionResponse,
|
||||
session: AgentSessionEntity,
|
||||
content: string,
|
||||
abortController: AbortController,
|
||||
adapter: ChannelAdapter,
|
||||
chatId: string,
|
||||
displayContent?: string,
|
||||
images?: ImageAttachment[],
|
||||
rendererIsWatching: boolean = false
|
||||
chatId: string
|
||||
): Promise<string> {
|
||||
// Use the pre-computed rendererIsWatching flag from processIncoming.
|
||||
// When renderer is watching: persist=false (renderer handles rich block persistence),
|
||||
// stream chunks and events are forwarded to the renderer via the bus.
|
||||
// When renderer is NOT watching: persist=true (main persists via persistHeadlessExchange),
|
||||
// stream events are NOT forwarded (no subscriber or subscriber arrived late).
|
||||
const { stream, completion } = await sessionMessageOrchestrator.createSessionMessage(
|
||||
session,
|
||||
{ content },
|
||||
abortController,
|
||||
{ persist: !rendererIsWatching, displayContent, images }
|
||||
)
|
||||
|
||||
const reader = stream.getReader()
|
||||
let completedText = '' // text from finished blocks/turns
|
||||
let currentBlockText = '' // cumulative text within the current block
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
// Only forward chunks to renderer when it was confirmed watching at stream start.
|
||||
// This prevents late-subscribing renderers from receiving partial chunks
|
||||
// while main process is also persisting (which would cause duplicates).
|
||||
if (rendererIsWatching) {
|
||||
sessionStreamBus.publish(session.id, {
|
||||
sessionId: session.id,
|
||||
agentId: session.agentId,
|
||||
type: 'chunk',
|
||||
chunk: value
|
||||
})
|
||||
}
|
||||
|
||||
// Skip user message echoes — only accumulate assistant text for the channel reply
|
||||
const rawType = (value as any).providerMetadata?.raw?.type
|
||||
if (rawType === 'user') continue
|
||||
|
||||
switch (value.type) {
|
||||
case 'text-delta':
|
||||
// text-delta values are cumulative within a block
|
||||
if (value.text) {
|
||||
currentBlockText = value.text
|
||||
// Notify adapter of text update — adapter owns its own throttle/flush
|
||||
const fullText = completedText + currentBlockText
|
||||
adapter.onTextUpdate(chatId, fullText).catch(() => {})
|
||||
}
|
||||
break
|
||||
case 'text-end':
|
||||
// Block finished — commit current block text and reset for next turn
|
||||
if (currentBlockText) {
|
||||
completedText += currentBlockText + '\n\n'
|
||||
currentBlockText = ''
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
await completion
|
||||
|
||||
if (rendererIsWatching) {
|
||||
// Notify renderer that stream is complete and data is persisted
|
||||
sessionStreamBus.publish(session.id, {
|
||||
sessionId: session.id,
|
||||
agentId: session.agentId,
|
||||
type: 'complete'
|
||||
})
|
||||
}
|
||||
// headless=true means main process persisted; renderer should force-reload from DB.
|
||||
// headless=false means renderer handled persistence; no reload needed.
|
||||
broadcastSessionChanged(session.agentId, session.id, !rendererIsWatching)
|
||||
|
||||
// Trim trailing separator
|
||||
return (completedText + currentBlockText).replace(/\n+$/, '')
|
||||
} catch (error) {
|
||||
if (rendererIsWatching) {
|
||||
sessionStreamBus.publish(session.id, {
|
||||
sessionId: session.id,
|
||||
agentId: session.agentId,
|
||||
type: 'error',
|
||||
error: { message: error instanceof Error ? error.message : String(error) }
|
||||
})
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private async sendChunked(adapter: ChannelAdapter, chatId: string, text: string): Promise<void> {
|
||||
if (text.length <= MAX_MESSAGE_LENGTH) {
|
||||
await adapter.sendMessage(chatId, text)
|
||||
return
|
||||
if (!session.agentId) {
|
||||
throw new Error(`Cannot stream on orphan session ${session.id} — its agent was deleted`)
|
||||
}
|
||||
|
||||
const chunks = splitMessage(text, MAX_MESSAGE_LENGTH)
|
||||
for (const chunk of chunks) {
|
||||
await adapter.sendMessage(chatId, chunk)
|
||||
let resolveExecution!: (text: string) => void
|
||||
let rejectExecution!: (err: unknown) => void
|
||||
const executionDone = new Promise<string>((resolve, reject) => {
|
||||
resolveExecution = resolve
|
||||
rejectExecution = reject
|
||||
})
|
||||
let accumulatedText = ''
|
||||
const sentinel: StreamListener = {
|
||||
id: `channel-completion:${chatId}`,
|
||||
onChunk(chunk) {
|
||||
// `text-delta`'s field is `delta`, not `text` (AI SDK `UIMessageChunk`).
|
||||
if (chunk.type === 'text-delta') accumulatedText += chunk.delta
|
||||
},
|
||||
onDone() {
|
||||
resolveExecution(accumulatedText.trim())
|
||||
},
|
||||
onPaused() {
|
||||
resolveExecution(accumulatedText.trim())
|
||||
},
|
||||
onError(result) {
|
||||
rejectExecution(new Error(result.error.message ?? 'Execution failed'))
|
||||
},
|
||||
isAlive: () => !abortController.signal.aborted
|
||||
}
|
||||
|
||||
await startAgentSessionRun({
|
||||
sessionId: session.id,
|
||||
userParts: [{ type: 'text', text: content }],
|
||||
listeners: [sentinel, new ChannelAdapterListener(adapter, chatId)]
|
||||
})
|
||||
|
||||
return executionDone
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2,9 +2,11 @@ import { agentChannelService as channelService } from '@data/services/AgentChann
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ChannelAdapter, type ChannelAdapterConfig } from '../ChannelAdapter'
|
||||
import { channelManager, registerAdapterFactory } from '../ChannelManager'
|
||||
import { ChannelManager, registerAdapterFactory } from '../ChannelManager'
|
||||
import { channelMessageHandler } from '../ChannelMessageHandler'
|
||||
|
||||
const channelManager = new ChannelManager()
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() })
|
||||
512
src/main/ai/channels/__tests__/ChannelMessageHandler.test.ts
Normal file
512
src/main/ai/channels/__tests__/ChannelMessageHandler.test.ts
Normal file
@@ -0,0 +1,512 @@
|
||||
import { agentChannelService as channelService } from '@data/services/AgentChannelService'
|
||||
import { agentService } from '@data/services/AgentService'
|
||||
import { sessionService } from '@data/services/SessionService'
|
||||
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
|
||||
import { AgentSessionWorkspaceError } from '@main/ai/runtime/claudeCode/settingsBuilder'
|
||||
import { EventEmitter } from 'events'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { channelMessageHandler } from '../ChannelMessageHandler'
|
||||
import { sanitizeChannelOutput } from '../security'
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() })
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../security', () => ({
|
||||
wrapExternalContent: vi.fn((text: string) => text),
|
||||
sanitizeChannelOutput: vi.fn((text: string) => ({ text, redacted: false }))
|
||||
}))
|
||||
|
||||
// The global mock (tests/main.setup.ts) wires the default service set, which omits
|
||||
// AiStreamManager; the abort path reads it, so override locally with a captured spy.
|
||||
const { mockStreamAbort } = vi.hoisted(() => ({ mockStreamAbort: vi.fn() }))
|
||||
vi.mock('@application', async () => {
|
||||
const { mockApplicationFactory } = await import('@test-mocks/main/application')
|
||||
return mockApplicationFactory({ AiStreamManager: { abort: mockStreamAbort } } as never)
|
||||
})
|
||||
|
||||
vi.mock('@data/services/AgentService', () => ({
|
||||
agentService: {
|
||||
getAgent: vi.fn().mockResolvedValue({
|
||||
id: 'agent-1',
|
||||
configuration: {},
|
||||
model: 'openai::gpt-4'
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@data/services/SessionService', () => ({
|
||||
sessionService: {
|
||||
getById: vi.fn(),
|
||||
createSession: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@shared/data/types/model', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as any
|
||||
return {
|
||||
...actual,
|
||||
createUniqueModelId: vi.fn((providerId: string, modelId: string) => `${providerId}::${modelId}`)
|
||||
}
|
||||
})
|
||||
|
||||
const { mockStartAgentSessionRun } = vi.hoisted(() => ({ mockStartAgentSessionRun: vi.fn() }))
|
||||
vi.mock('@main/ai/streamManager/api/startAgentSessionRun', () => ({
|
||||
startAgentSessionRun: (...args: unknown[]) => mockStartAgentSessionRun(...args)
|
||||
}))
|
||||
|
||||
vi.mock('@data/services/AgentChannelService', () => ({
|
||||
agentChannelService: {
|
||||
getChannel: vi.fn().mockResolvedValue({ id: 'channel-1', sessionId: null, permissionMode: null }),
|
||||
updateChannel: vi.fn().mockResolvedValue(null),
|
||||
findBySessionId: vi.fn().mockResolvedValue(null)
|
||||
}
|
||||
}))
|
||||
|
||||
/**
|
||||
* Helper: configure mockStartAgentSessionRun to simulate streaming chunks to ALL
|
||||
* registered listeners (both the `channel-completion:` sentinel and the
|
||||
* `ChannelAdapterListener` that owns delivery), then call onDone on each so the
|
||||
* `executionDone` promise inside `collectStreamResponse` resolves and the listener
|
||||
* finalizes delivery. `text-delta` chunks carry the payload on `delta` (AI SDK
|
||||
* `UIMessageChunk`), not `text`.
|
||||
*/
|
||||
function simulateStream(parts: Array<{ type: string; delta?: string }>) {
|
||||
mockStartAgentSessionRun.mockImplementationOnce(
|
||||
async ({
|
||||
listeners
|
||||
}: {
|
||||
listeners: Array<{
|
||||
id: string
|
||||
onChunk: (chunk: unknown) => void
|
||||
onDone: (result: { status: string }) => void | Promise<void>
|
||||
}>
|
||||
}) => {
|
||||
for (const listener of listeners) {
|
||||
for (const part of parts) {
|
||||
listener.onChunk(part)
|
||||
}
|
||||
await listener.onDone({ status: 'success' })
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
function createMockAdapter(overrides: Record<string, unknown> = {}) {
|
||||
const adapter = new EventEmitter() as any
|
||||
adapter.agentId = overrides.agentId ?? 'agent-1'
|
||||
adapter.channelId = overrides.channelId ?? 'channel-1'
|
||||
adapter.channelType = overrides.channelType ?? 'telegram'
|
||||
adapter.connected = true
|
||||
adapter.sendMessage = vi.fn().mockResolvedValue(undefined)
|
||||
adapter.sendTypingIndicator = vi.fn().mockResolvedValue(undefined)
|
||||
adapter.onTextUpdate = vi.fn().mockResolvedValue(undefined)
|
||||
adapter.onStreamComplete = vi.fn().mockResolvedValue(false)
|
||||
adapter.onStreamError = vi.fn().mockResolvedValue(undefined)
|
||||
adapter.notifyChatIds = []
|
||||
return adapter
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper: call handleIncoming and advance fake timers so the debounce fires,
|
||||
* then await the returned promise to wait for processing to complete.
|
||||
*/
|
||||
async function handleIncomingAndFlush(
|
||||
adapter: ReturnType<typeof createMockAdapter>,
|
||||
message: { chatId: string; userId: string; userName: string; text: string }
|
||||
) {
|
||||
const promise = channelMessageHandler.handleIncoming(adapter, message)
|
||||
// Advance past the MESSAGE_BATCH_DELAY_MS debounce (10 000 ms)
|
||||
await vi.advanceTimersByTimeAsync(10500)
|
||||
return promise
|
||||
}
|
||||
|
||||
describe('ChannelMessageHandler', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.clearAllMocks()
|
||||
// Restore default agent mock after clearAllMocks
|
||||
vi.mocked(agentService.getAgent).mockResolvedValue({
|
||||
id: 'agent-1',
|
||||
configuration: {},
|
||||
model: 'openai::gpt-4'
|
||||
} as any)
|
||||
// Clear session tracker to ensure clean state
|
||||
channelMessageHandler.clearSessionTracker('agent-1')
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('collectStreamResponse accumulates text across turns and sends via adapter', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
|
||||
simulateStream([
|
||||
{ type: 'text-delta', delta: 'Hello ' },
|
||||
{ type: 'text-delta', delta: 'world!' },
|
||||
{ type: 'text-end' },
|
||||
{ type: 'text-delta', delta: '\n\nDone.' },
|
||||
{ type: 'text-end' }
|
||||
])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'Hi'
|
||||
})
|
||||
|
||||
// Delivery is owned by ChannelAdapterListener (the handler no longer post-sends);
|
||||
// it accumulates all text-delta chunks via `.delta`, trims, and sends once.
|
||||
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'Hello world!\n\nDone.')
|
||||
})
|
||||
|
||||
// channels-core-3: the streaming delivery path (real ChannelAdapterListener) must route
|
||||
// output through the OutputSanitizer before sending — otherwise secrets in the model reply
|
||||
// leak to the IM platform. simulateStream drives the real listener, so a redacting sanitizer
|
||||
// must be reflected in what the adapter sends.
|
||||
it('routes channel output through the OutputSanitizer before delivery (REGRESSION channels-core-3)', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
|
||||
|
||||
vi.mocked(sanitizeChannelOutput).mockImplementation((text: string) => ({
|
||||
text: text.replace('sk-SECRET', '<redacted>'),
|
||||
redacted: text.includes('sk-SECRET')
|
||||
}))
|
||||
simulateStream([{ type: 'text-delta', delta: 'the key is sk-SECRET' }])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'Hi'
|
||||
})
|
||||
|
||||
expect(sanitizeChannelOutput).toHaveBeenCalled()
|
||||
// The redacted text — not the raw secret — is what reaches the adapter.
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'the key is <redacted>')
|
||||
|
||||
// Restore the identity default so later tests are unaffected.
|
||||
vi.mocked(sanitizeChannelOutput).mockImplementation((text: string) => ({ text, redacted: false }))
|
||||
})
|
||||
|
||||
// stream-context-5: a workspace error is thrown before streaming starts, so onStreamError
|
||||
// (a no-op without a live controller on most adapters) can't surface it. The handler must
|
||||
// fall back to a plain sendMessage so the inbound message isn't silently dropped.
|
||||
it('surfaces a pre-stream workspace error as a plain message (REGRESSION stream-context-5)', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
|
||||
mockStartAgentSessionRun.mockRejectedValueOnce(new AgentSessionWorkspaceError('workspace is missing'))
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'Hi'
|
||||
})
|
||||
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'workspace is missing')
|
||||
expect(adapter.onStreamError).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips final send when adapter handles stream completion', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
|
||||
adapter.onStreamComplete.mockResolvedValueOnce(true)
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
|
||||
simulateStream([{ type: 'text-delta', delta: 'Hello world!' }])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'Hi'
|
||||
})
|
||||
|
||||
expect(adapter.onStreamComplete).toHaveBeenCalledWith('chat-1', 'Hello world!')
|
||||
expect(adapter.sendMessage).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('delivers a long response in a single send (platform splitting is the adapter concern)', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
|
||||
|
||||
const longText = 'A'.repeat(5000)
|
||||
simulateStream([{ type: 'text-delta', delta: longText }])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'Hi'
|
||||
})
|
||||
|
||||
// The handler-level 4096-char chunking was dead code (post-hoc path never ran)
|
||||
// and has been removed; ChannelAdapterListener delivers the full text once and
|
||||
// each adapter splits per its own platform limit.
|
||||
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', longText)
|
||||
})
|
||||
|
||||
it('handleCommand /new creates a new session', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce({ id: 'new-session' } as any)
|
||||
|
||||
await channelMessageHandler.handleCommand(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
command: 'new'
|
||||
})
|
||||
|
||||
expect(sessionService.createSession).toHaveBeenCalledWith({
|
||||
agentId: 'agent-1',
|
||||
name: 'Channel session'
|
||||
})
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'New session created.')
|
||||
})
|
||||
|
||||
it('handleCommand /compact sends /compact as message content', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
|
||||
simulateStream([{ type: 'text-delta', delta: 'Compacted.' }])
|
||||
|
||||
await channelMessageHandler.handleCommand(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
command: 'compact'
|
||||
})
|
||||
|
||||
expect(mockStartAgentSessionRun).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
sessionId: 'session-1',
|
||||
userParts: [{ type: 'text', text: '/compact' }],
|
||||
listeners: expect.arrayContaining([
|
||||
expect.objectContaining({ id: expect.stringContaining('channel-completion:') })
|
||||
])
|
||||
})
|
||||
)
|
||||
// ChannelAdapterListener delivers the compact output once; the handler no longer
|
||||
// also sends it (would have been a double-send once the `.delta` read was fixed).
|
||||
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'Compacted.')
|
||||
})
|
||||
|
||||
it('handleCommand /help sends help text with agent info', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
vi.mocked(agentService.getAgent).mockResolvedValueOnce({
|
||||
name: 'TestAgent',
|
||||
description: 'A test agent'
|
||||
} as any)
|
||||
|
||||
await channelMessageHandler.handleCommand(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
command: 'help'
|
||||
})
|
||||
|
||||
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
|
||||
const helpText = adapter.sendMessage.mock.calls[0][1] as string
|
||||
expect(helpText).toContain('*TestAgent*')
|
||||
expect(helpText).toContain('_A test agent_')
|
||||
expect(helpText).toContain('/new')
|
||||
expect(helpText).toContain('/compact')
|
||||
expect(helpText).toContain('/help')
|
||||
expect(helpText).toContain('/whoami')
|
||||
})
|
||||
|
||||
it('handleCommand /whoami sends the current chat ID', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
|
||||
await channelMessageHandler.handleCommand(adapter, {
|
||||
chatId: 'oc_123',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
command: 'whoami'
|
||||
})
|
||||
|
||||
expect(adapter.sendMessage).toHaveBeenCalledWith(
|
||||
'oc_123',
|
||||
'Current chat ID: `oc_123`\n\nAdd this value to `allow_ids` in settings to receive notifications.'
|
||||
)
|
||||
})
|
||||
|
||||
it('resolveSession tracks sessions after /new', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const newSession = {
|
||||
id: 'new-session',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(newSession as any)
|
||||
|
||||
await channelMessageHandler.handleCommand(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
command: 'new'
|
||||
})
|
||||
|
||||
// Now send a message — should use the tracked session
|
||||
vi.mocked(sessionService.getById).mockResolvedValueOnce(newSession as any)
|
||||
simulateStream([{ type: 'text-delta', delta: 'OK' }])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'test'
|
||||
})
|
||||
|
||||
expect(sessionService.getById).toHaveBeenCalledWith('new-session')
|
||||
})
|
||||
|
||||
it('clearSessionTracker causes fresh session resolution', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
const session1 = {
|
||||
id: 'session-1',
|
||||
agentId: 'agent-1',
|
||||
agentType: 'claude-code',
|
||||
model: 'openai::gpt-4',
|
||||
workspace: { path: '/tmp/test-workspace' },
|
||||
configuration: {}
|
||||
}
|
||||
|
||||
// First interaction creates a session
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session1 as any)
|
||||
simulateStream([{ type: 'text-delta', delta: 'R1' }])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'msg1'
|
||||
})
|
||||
|
||||
// Clear session tracker
|
||||
channelMessageHandler.clearSessionTracker('agent-1')
|
||||
|
||||
// Next interaction should find existing session via channel's session_id
|
||||
vi.mocked(channelService.getChannel).mockResolvedValueOnce({
|
||||
id: 'channel-1',
|
||||
sessionId: 'session-1',
|
||||
permissionMode: null
|
||||
} as any)
|
||||
vi.mocked(sessionService.getById).mockResolvedValueOnce(session1 as any)
|
||||
simulateStream([{ type: 'text-delta', delta: 'R2' }])
|
||||
|
||||
await handleIncomingAndFlush(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'msg2'
|
||||
})
|
||||
|
||||
// After clearing tracker, should look up channel then getSession instead of creating new session
|
||||
expect(channelService.getChannel).toHaveBeenCalledWith('channel-1')
|
||||
// Only 1 createSession call (the first one), not 2
|
||||
expect(sessionService.createSession).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// channels-core-3: discarding a pending (un-flushed) batch must settle its callers'
|
||||
// handleIncoming promises instead of leaving them hanging forever, so .catch fires.
|
||||
it('clearSessionTracker rejects pending-batch handleIncoming promises', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
|
||||
// Start a batch but do NOT advance timers — it stays pending in pendingBatches.
|
||||
const pending = channelMessageHandler.handleIncoming(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
text: 'Hi'
|
||||
})
|
||||
const rejection = expect(pending).rejects.toThrow('Agent removed; batch discarded')
|
||||
|
||||
// Clearing the agent's tracker discards the pending batch.
|
||||
channelMessageHandler.clearSessionTracker('agent-1')
|
||||
|
||||
await rejection
|
||||
expect(mockStartAgentSessionRun).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// channels-core-2: a local AbortController only flips a listener's isAlive() — clearing
|
||||
// a tracked session must stop the upstream agent-session turn via the manager.
|
||||
it('clearSessionTracker aborts the upstream agent-session turn via the manager', async () => {
|
||||
const adapter = createMockAdapter()
|
||||
vi.mocked(sessionService.createSession).mockResolvedValueOnce({ id: 'sess-x' } as any)
|
||||
|
||||
await channelMessageHandler.handleCommand(adapter, {
|
||||
chatId: 'chat-1',
|
||||
userId: 'user-1',
|
||||
userName: 'User',
|
||||
command: 'new'
|
||||
})
|
||||
mockStreamAbort.mockClear()
|
||||
|
||||
channelMessageHandler.clearSessionTracker('agent-1')
|
||||
|
||||
expect(mockStreamAbort).toHaveBeenCalledWith(buildAgentSessionTopicId('sess-x'), 'agent-cleared')
|
||||
})
|
||||
})
|
||||
79
src/main/ai/channels/adapters/__tests__/QqAdapter.test.ts
Normal file
79
src/main/ai/channels/adapters/__tests__/QqAdapter.test.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() })
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../ChannelManager', () => ({
|
||||
registerAdapterFactory: vi.fn()
|
||||
}))
|
||||
|
||||
const mockNetFetch = vi.fn()
|
||||
vi.mock('electron', () => ({
|
||||
app: { getPath: () => '/mock/userData' },
|
||||
net: { fetch: (...args: unknown[]) => mockNetFetch(...args) }
|
||||
}))
|
||||
|
||||
vi.mock('ws', () => {
|
||||
const Ctor = vi.fn()
|
||||
Object.assign(Ctor, { OPEN: 1, CONNECTING: 0, CLOSED: 3, CLOSING: 2 })
|
||||
return { default: Ctor, WebSocket: Ctor }
|
||||
})
|
||||
|
||||
import '../qq/QqAdapter'
|
||||
|
||||
import { registerAdapterFactory } from '../../ChannelManager'
|
||||
|
||||
// Capture the factory at module load — `registerAdapterFactory('qq', …)` runs once on import,
|
||||
// and afterEach's restoreAllMocks would otherwise wipe that call history before later tests.
|
||||
const qqCall = vi.mocked(registerAdapterFactory).mock.calls.find((c) => c[0] === 'qq')
|
||||
if (!qqCall) throw new Error('registerAdapterFactory was not called for qq')
|
||||
const qqFactory = qqCall[1] as (channel: any, agentId: string) => any
|
||||
|
||||
function mockBinaryResponse(buf: Buffer, contentType = 'image/png'): Response {
|
||||
return {
|
||||
ok: true,
|
||||
status: 200,
|
||||
headers: new Headers({ 'content-type': contentType }),
|
||||
arrayBuffer: () => Promise.resolve(buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength))
|
||||
} as unknown as Response
|
||||
}
|
||||
|
||||
function createAdapter() {
|
||||
return qqFactory(
|
||||
{ id: 'ch-qq-1', type: 'qq', enabled: true, config: { app_id: 'app', client_secret: 'sec', allowed_chat_ids: [] } },
|
||||
'agent-1'
|
||||
)
|
||||
}
|
||||
|
||||
describe('QqAdapter.downloadAttachments', () => {
|
||||
beforeEach(() => mockNetFetch.mockReset())
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
it('rejects an SSRF target before any (token-bearing) fetch (C8)', async () => {
|
||||
const adapter = createAdapter()
|
||||
vi.spyOn(adapter, 'getAccessToken').mockResolvedValue('tok')
|
||||
|
||||
const result = await adapter.downloadAttachments([
|
||||
{ url: 'http://169.254.169.254/latest/meta-data/', content_type: 'image/png', filename: 'meta' }
|
||||
])
|
||||
|
||||
expect(result).toEqual({})
|
||||
expect(mockNetFetch).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('downloads a public attachment URL', async () => {
|
||||
const adapter = createAdapter()
|
||||
vi.spyOn(adapter, 'getAccessToken').mockResolvedValue('tok')
|
||||
mockNetFetch.mockResolvedValue(mockBinaryResponse(Buffer.from('img'), 'image/png'))
|
||||
|
||||
const result = await adapter.downloadAttachments([
|
||||
{ url: 'https://gchat.qpic.cn/a.png', content_type: 'image/png', filename: 'a.png' }
|
||||
])
|
||||
|
||||
expect(result.images).toHaveLength(1)
|
||||
expect(mockNetFetch).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -349,6 +349,50 @@ describe('SlackAdapter', () => {
|
||||
expect(messageSpy.mock.calls[0][0].images).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('skips an oversized image download (content-length over the cap)', async () => {
|
||||
const adapter = await connectAdapter()
|
||||
const messageSpy = vi.fn()
|
||||
adapter.on('message', messageSpy)
|
||||
|
||||
const oversize = {
|
||||
ok: true,
|
||||
status: 200,
|
||||
headers: new Headers({ 'content-type': 'image/png', 'content-length': String(500 * 1024 * 1024) }),
|
||||
json: () => Promise.reject(new Error('not json')),
|
||||
text: () => Promise.resolve(''),
|
||||
arrayBuffer: () => Promise.resolve(Buffer.from('x').buffer)
|
||||
} as unknown as Response
|
||||
|
||||
mockNetFetch.mockImplementation((url: string) => {
|
||||
if (typeof url === 'string' && url.includes('files.slack.com')) {
|
||||
return Promise.resolve(oversize)
|
||||
}
|
||||
if (url.includes('users.info')) {
|
||||
return Promise.resolve(mockJsonResponse({ ok: true, user: { real_name: 'Test User' } }))
|
||||
}
|
||||
return Promise.resolve(mockJsonResponse({ ok: true }))
|
||||
})
|
||||
|
||||
simulateMessageEvent({
|
||||
channel: 'C0ALLOWED',
|
||||
user: USER1_ID,
|
||||
text: 'see attached',
|
||||
subtype: 'file_share',
|
||||
files: [
|
||||
{
|
||||
id: 'F1',
|
||||
name: 'huge.png',
|
||||
mimetype: 'image/png',
|
||||
size: 1000,
|
||||
url_private: 'https://files.slack.com/huge.png'
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
await vi.waitFor(() => expect(messageSpy).toHaveBeenCalledTimes(1))
|
||||
expect(messageSpy.mock.calls[0][0].images).toBeUndefined()
|
||||
})
|
||||
|
||||
it('ignores messages from the bot itself', async () => {
|
||||
const adapter = await connectAdapter()
|
||||
const messageSpy = vi.fn()
|
||||
@@ -99,6 +99,55 @@ describe('TelegramAdapter', () => {
|
||||
expect(mockBot.stop).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// channel-adapters-2: grammY rethrows a fatal 409/Conflict out of bot.start(); the adapter
|
||||
// must reconnect with backoff instead of staying permanently down.
|
||||
it('reconnects with backoff when polling rejects (REGRESSION channel-adapters-2)', async () => {
|
||||
vi.useFakeTimers()
|
||||
const adapter = createAdapter()
|
||||
mockBot.start.mockReset()
|
||||
// First polling attempt fails (recoverable 409); the reconnect attempt succeeds.
|
||||
mockBot.start.mockRejectedValueOnce(new Error('409: Conflict')).mockResolvedValue(undefined)
|
||||
|
||||
await adapter.connect()
|
||||
await vi.advanceTimersByTimeAsync(0) // let the rejection handler schedule the reconnect
|
||||
expect(mockBot.start).toHaveBeenCalledTimes(1)
|
||||
|
||||
await vi.advanceTimersByTimeAsync(1000) // first backoff delay
|
||||
expect(mockBot.start).toHaveBeenCalledTimes(2) // reconnected
|
||||
})
|
||||
|
||||
it('resets the reconnect budget after a stable polling window (REGRESSION channel-adapters-2)', async () => {
|
||||
vi.useFakeTimers()
|
||||
const adapter = createAdapter()
|
||||
mockBot.start.mockReset()
|
||||
// One transient failure bumps the attempt counter, then the reconnect stays up.
|
||||
mockBot.start.mockRejectedValueOnce(new Error('409: Conflict')).mockResolvedValue(undefined)
|
||||
|
||||
await adapter.connect()
|
||||
await vi.advanceTimersByTimeAsync(1000) // reconnect fires and succeeds
|
||||
expect(adapter.reconnectAttempts).toBe(1)
|
||||
|
||||
// After the stability window the counter resets, so lifetime-cumulative transient
|
||||
// failures can't monotonically exhaust maxReconnectAttempts.
|
||||
await vi.advanceTimersByTimeAsync(60_000)
|
||||
expect(adapter.reconnectAttempts).toBe(0)
|
||||
})
|
||||
|
||||
it('does not reconnect after disconnect() (REGRESSION channel-adapters-2)', async () => {
|
||||
vi.useFakeTimers()
|
||||
const adapter = createAdapter()
|
||||
mockBot.start.mockReset()
|
||||
mockBot.start.mockRejectedValue(new Error('409: Conflict'))
|
||||
|
||||
await adapter.connect()
|
||||
await vi.advanceTimersByTimeAsync(0) // a reconnect is now pending
|
||||
await adapter.disconnect() // shouldStop + clear the pending reconnect timer
|
||||
|
||||
const callsAfterDisconnect = mockBot.start.mock.calls.length
|
||||
await vi.advanceTimersByTimeAsync(60_000)
|
||||
expect(mockBot.start.mock.calls.length).toBe(callsAfterDisconnect) // no further reconnect
|
||||
})
|
||||
|
||||
it('sendMessage() sends text with MarkdownV2 by default', async () => {
|
||||
const adapter = createAdapter()
|
||||
await adapter.connect()
|
||||
@@ -1,16 +1,14 @@
|
||||
import { net } from 'electron'
|
||||
import WebSocket from 'ws'
|
||||
|
||||
import {
|
||||
ChannelAdapter,
|
||||
type ChannelAdapterConfig,
|
||||
downloadFileAsBase64,
|
||||
downloadImageAsBase64,
|
||||
type FileAttachment,
|
||||
type ImageAttachment,
|
||||
MAX_FILE_SIZE_BYTES,
|
||||
type SendMessageOptions
|
||||
} from '../../ChannelAdapter'
|
||||
MAX_FILE_SIZE_BYTES
|
||||
} from '@main/utils/downloadAsBase64'
|
||||
import { net } from 'electron'
|
||||
import WebSocket from 'ws'
|
||||
|
||||
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
|
||||
import { registerAdapterFactory } from '../../ChannelManager'
|
||||
import { isSlashCommand, SLASH_COMMANDS } from '../../constants'
|
||||
import { FlushController } from '../../FlushController'
|
||||
@@ -234,12 +232,11 @@ class DiscordAdapter extends ChannelAdapter {
|
||||
/** Per-chat streaming controller. One stream at a time per chat. */
|
||||
private readonly streamingControllers = new Map<string, DiscordStreamingController>()
|
||||
|
||||
constructor(config: ChannelAdapterConfig) {
|
||||
constructor(config: ChannelAdapterConfig<'discord'>) {
|
||||
super(config)
|
||||
const { bot_token, allowed_channel_ids } = config.channelConfig
|
||||
this.botToken = (bot_token as string) ?? ''
|
||||
const rawIds = allowed_channel_ids as string[] | undefined
|
||||
this.allowedChannelIds = Array.isArray(rawIds) ? rawIds.map(String) : []
|
||||
this.botToken = bot_token
|
||||
this.allowedChannelIds = allowed_channel_ids ?? []
|
||||
this.notifyChatIds = [...this.allowedChannelIds]
|
||||
}
|
||||
|
||||
@@ -4,17 +4,11 @@ import type { ReadableStream as NodeReadableStream } from 'node:stream/web'
|
||||
import { application } from '@application'
|
||||
import * as Lark from '@larksuiteoapi/node-sdk'
|
||||
import { WindowType } from '@main/core/window/types'
|
||||
import { type FileAttachment, type ImageAttachment, MAX_FILE_SIZE_BYTES } from '@main/utils/downloadAsBase64'
|
||||
import type { FeishuDomain } from '@shared/data/types/channel'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
|
||||
import {
|
||||
ChannelAdapter,
|
||||
type ChannelAdapterConfig,
|
||||
type FileAttachment,
|
||||
type ImageAttachment,
|
||||
MAX_FILE_SIZE_BYTES,
|
||||
type SendMessageOptions
|
||||
} from '../../ChannelAdapter'
|
||||
import type { FeishuDomain } from '../../channelConfig'
|
||||
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
|
||||
import { registerAdapterFactory } from '../../ChannelManager'
|
||||
import { isSlashCommand } from '../../constants'
|
||||
import { FlushController } from '../../FlushController'
|
||||
@@ -476,16 +470,15 @@ class FeishuAdapter extends ChannelAdapter {
|
||||
/** Active status reaction per chat, so we can swap or remove it. */
|
||||
private readonly chatReactions = new Map<string, ChatReaction>()
|
||||
|
||||
constructor(config: ChannelAdapterConfig) {
|
||||
constructor(config: ChannelAdapterConfig<'feishu'>) {
|
||||
super(config)
|
||||
const { app_id, app_secret, encrypt_key, verification_token, allowed_chat_ids, domain } = config.channelConfig
|
||||
this.appId = (app_id as string) ?? ''
|
||||
this.appSecret = (app_secret as string) ?? ''
|
||||
this.encryptKey = (encrypt_key as string) ?? ''
|
||||
this.verificationToken = (verification_token as string) ?? ''
|
||||
const rawIds = allowed_chat_ids as string[] | undefined
|
||||
this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : []
|
||||
this.domain = ((domain as string) ?? 'feishu') as FeishuDomain
|
||||
this.appId = app_id
|
||||
this.appSecret = app_secret
|
||||
this.encryptKey = encrypt_key
|
||||
this.verificationToken = verification_token
|
||||
this.allowedChatIds = allowed_chat_ids ?? []
|
||||
this.domain = domain
|
||||
this.notifyChatIds = [...this.allowedChatIds]
|
||||
}
|
||||
|
||||
@@ -7,10 +7,9 @@
|
||||
* Flow: init -> begin (returns QR URL) -> poll (returns client_id + client_secret)
|
||||
*/
|
||||
import { loggerService } from '@logger'
|
||||
import type { FeishuDomain } from '@shared/data/types/channel'
|
||||
import { net } from 'electron'
|
||||
|
||||
import type { FeishuDomain } from '../../channelConfig'
|
||||
|
||||
const logger = loggerService.withContext('FeishuAppRegistration')
|
||||
|
||||
const BASE_URLS: Record<FeishuDomain, string> = {
|
||||
@@ -1,14 +1,9 @@
|
||||
import { type FileAttachment, type ImageAttachment, MAX_FILE_SIZE_BYTES } from '@main/utils/downloadAsBase64'
|
||||
import { sanitizeRemoteUrl } from '@main/utils/remoteUrlSafety'
|
||||
import { net } from 'electron'
|
||||
import WebSocket from 'ws'
|
||||
|
||||
import {
|
||||
ChannelAdapter,
|
||||
type ChannelAdapterConfig,
|
||||
type FileAttachment,
|
||||
type ImageAttachment,
|
||||
MAX_FILE_SIZE_BYTES,
|
||||
type SendMessageOptions
|
||||
} from '../../ChannelAdapter'
|
||||
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
|
||||
import { registerAdapterFactory } from '../../ChannelManager'
|
||||
import { isSlashCommand } from '../../constants'
|
||||
import { splitMessage } from '../../utils'
|
||||
@@ -88,13 +83,12 @@ class QqAdapter extends ChannelAdapter {
|
||||
/** Number of rapid disconnects before invalidating session */
|
||||
private readonly maxRapidDisconnects = 3
|
||||
|
||||
constructor(config: ChannelAdapterConfig) {
|
||||
constructor(config: ChannelAdapterConfig<'qq'>) {
|
||||
super(config)
|
||||
const { app_id, client_secret, allowed_chat_ids } = config.channelConfig
|
||||
this.appId = (app_id as string) ?? ''
|
||||
this.clientSecret = (client_secret as string) ?? ''
|
||||
const rawIds = allowed_chat_ids as string[] | undefined
|
||||
this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : []
|
||||
this.appId = app_id
|
||||
this.clientSecret = client_secret
|
||||
this.allowedChatIds = allowed_chat_ids ?? []
|
||||
this.notifyChatIds = [...this.allowedChatIds]
|
||||
}
|
||||
|
||||
@@ -240,7 +234,7 @@ class QqAdapter extends ChannelAdapter {
|
||||
return
|
||||
}
|
||||
|
||||
if (payload.s !== undefined) {
|
||||
if (payload.s !== undefined && payload.s !== null) {
|
||||
this.lastSeq = payload.s
|
||||
}
|
||||
|
||||
@@ -427,17 +421,23 @@ class QqAdapter extends ChannelAdapter {
|
||||
.map(async (att) => {
|
||||
try {
|
||||
const url = att.url.startsWith('http') ? att.url : `https://${att.url}`
|
||||
const response = await net.fetch(url, {
|
||||
// SSRF guard: reject local/private/credentialed/non-http(s) targets from the
|
||||
// inbound payload before we fetch with the bot token (and before the retry).
|
||||
const safeUrl = sanitizeRemoteUrl(url)
|
||||
const response = await net.fetch(safeUrl, {
|
||||
headers: { Authorization: `QQBot ${token}`, 'X-Union-Appid': this.appId }
|
||||
})
|
||||
if (!response.ok) {
|
||||
// Retry without auth header (some CDN URLs are public)
|
||||
const retry = await net.fetch(url)
|
||||
const retry = await net.fetch(safeUrl)
|
||||
if (!retry.ok) return
|
||||
const buffer = Buffer.from(await retry.arrayBuffer())
|
||||
// `att.size` is attacker-supplied metadata; cap on the real downloaded bytes.
|
||||
if (buffer.length > MAX_FILE_SIZE_BYTES) return
|
||||
this.pushAttachment(att, buffer, images, files)
|
||||
} else {
|
||||
const buffer = Buffer.from(await response.arrayBuffer())
|
||||
if (buffer.length > MAX_FILE_SIZE_BYTES) return
|
||||
this.pushAttachment(att, buffer, images, files)
|
||||
}
|
||||
} catch {
|
||||
@@ -1,14 +1,8 @@
|
||||
import { type FileAttachment, type ImageAttachment, MAX_FILE_SIZE_BYTES } from '@main/utils/downloadAsBase64'
|
||||
import { net } from 'electron'
|
||||
import WebSocket from 'ws'
|
||||
|
||||
import {
|
||||
ChannelAdapter,
|
||||
type ChannelAdapterConfig,
|
||||
type FileAttachment,
|
||||
type ImageAttachment,
|
||||
MAX_FILE_SIZE_BYTES,
|
||||
type SendMessageOptions
|
||||
} from '../../ChannelAdapter'
|
||||
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
|
||||
import { registerAdapterFactory } from '../../ChannelManager'
|
||||
import { isSlashCommand } from '../../constants'
|
||||
import { FlushController } from '../../FlushController'
|
||||
@@ -211,13 +205,12 @@ class SlackAdapter extends ChannelAdapter {
|
||||
/** Track the latest incoming message ts per chatId for reaction acknowledgment */
|
||||
private readonly pendingReactions = new Map<string, string>()
|
||||
|
||||
constructor(config: ChannelAdapterConfig) {
|
||||
constructor(config: ChannelAdapterConfig<'slack'>) {
|
||||
super(config)
|
||||
const { bot_token, app_token, allowed_channel_ids } = config.channelConfig
|
||||
this.botToken = (bot_token as string) ?? ''
|
||||
this.appToken = (app_token as string) ?? ''
|
||||
const rawIds = allowed_channel_ids as string[] | undefined
|
||||
this.allowedChannelIds = Array.isArray(rawIds) ? rawIds.map(String) : []
|
||||
this.botToken = bot_token
|
||||
this.appToken = app_token
|
||||
this.allowedChannelIds = allowed_channel_ids ?? []
|
||||
this.notifyChatIds = [...this.allowedChannelIds]
|
||||
}
|
||||
|
||||
@@ -483,9 +476,12 @@ class SlackAdapter extends ChannelAdapter {
|
||||
headers: { Authorization: `Bearer ${this.botToken}` }
|
||||
})
|
||||
if (!response.ok) return null
|
||||
const contentLength = response.headers.get('content-length')
|
||||
if (contentLength && Number.parseInt(contentLength, 10) > MAX_FILE_SIZE_BYTES) return null
|
||||
const contentType = response.headers.get('content-type') || 'image/png'
|
||||
const mediaType = contentType.split(';')[0].trim()
|
||||
const buffer = Buffer.from(await response.arrayBuffer())
|
||||
if (buffer.length > MAX_FILE_SIZE_BYTES) return null
|
||||
return { data: buffer.toString('base64'), media_type: mediaType }
|
||||
} catch {
|
||||
return null
|
||||
@@ -1,19 +1,26 @@
|
||||
import { Bot } from 'grammy'
|
||||
import { convert as toMarkdownV2 } from 'telegram-markdown-v2'
|
||||
|
||||
import {
|
||||
ChannelAdapter,
|
||||
type ChannelAdapterConfig,
|
||||
downloadFileAsBase64,
|
||||
downloadImageAsBase64,
|
||||
type FileAttachment,
|
||||
type ImageAttachment,
|
||||
MAX_FILE_SIZE_BYTES,
|
||||
type SendMessageOptions
|
||||
} from '../../ChannelAdapter'
|
||||
MAX_FILE_SIZE_BYTES
|
||||
} from '@main/utils/downloadAsBase64'
|
||||
import { Bot } from 'grammy'
|
||||
import { convert as toMarkdownV2 } from 'telegram-markdown-v2'
|
||||
|
||||
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
|
||||
import { registerAdapterFactory } from '../../ChannelManager'
|
||||
|
||||
const TELEGRAM_MAX_LENGTH = 4096
|
||||
/**
|
||||
* Plain-text chunk budget under MarkdownV2. We split the *plain* text (so each
|
||||
* chunk has an index-aligned plain fallback) and then escape it; escaping only
|
||||
* grows length, so this headroom keeps the formatted chunk within the 4096 hard
|
||||
* limit for normal prose. A pathological all-special-char chunk could still
|
||||
* overflow — Telegram then rejects it and the catch sends the plain chunk, which
|
||||
* is always within budget.
|
||||
*/
|
||||
const TELEGRAM_MARKDOWN_CHUNK_BUDGET = 3200
|
||||
|
||||
import { splitMessage } from '../../utils'
|
||||
|
||||
@@ -22,12 +29,26 @@ class TelegramAdapter extends ChannelAdapter {
|
||||
private readonly botToken: string
|
||||
private readonly allowedChatIds: string[]
|
||||
|
||||
constructor(config: ChannelAdapterConfig) {
|
||||
// Long-polling reconnect with backoff. grammY rethrows fatal polling errors (401/409) out of
|
||||
// `bot.start()`; without this a recoverable 409/Conflict left the bot permanently down.
|
||||
// Mirrors the WebSocket adapters (Slack/Discord/QQ).
|
||||
private shouldStop = false
|
||||
private reconnectAttempts = 0
|
||||
private reconnectTimer: ReturnType<typeof setTimeout> | null = null
|
||||
private stabilityTimer: ReturnType<typeof setTimeout> | null = null
|
||||
private readonly reconnectDelays = [1000, 2000, 5000, 10000, 30000, 60000]
|
||||
private readonly maxReconnectAttempts = 50
|
||||
// Long polling has no "ready" event (a fatal 409 surfaces *after* `markConnected`), so
|
||||
// resetting the backoff budget on connect would let a persistent failure loop forever and
|
||||
// never hit the cap. Instead reset only after the bot has polled cleanly for this window —
|
||||
// so transient failures spread over the adapter's lifetime don't monotonically exhaust it.
|
||||
private readonly stabilityResetMs = 60_000
|
||||
|
||||
constructor(config: ChannelAdapterConfig<'telegram'>) {
|
||||
super(config)
|
||||
const { bot_token, allowed_chat_ids } = config.channelConfig
|
||||
this.botToken = (bot_token as string) ?? ''
|
||||
const rawIds = allowed_chat_ids as string[] | undefined
|
||||
this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : []
|
||||
this.botToken = bot_token
|
||||
this.allowedChatIds = allowed_chat_ids ?? []
|
||||
this.notifyChatIds = [...this.allowedChatIds]
|
||||
}
|
||||
|
||||
@@ -39,7 +60,12 @@ class TelegramAdapter extends ChannelAdapter {
|
||||
if (!this.botToken) {
|
||||
throw new Error('Telegram bot token is required')
|
||||
}
|
||||
this.shouldStop = false
|
||||
this.reconnectAttempts = 0
|
||||
await this.startBot()
|
||||
}
|
||||
|
||||
private async startBot(): Promise<void> {
|
||||
const bot = new Bot(this.botToken)
|
||||
this.bot = bot
|
||||
|
||||
@@ -161,18 +187,70 @@ class TelegramAdapter extends ChannelAdapter {
|
||||
this.log.error(`Bot error: ${msg}`)
|
||||
})
|
||||
|
||||
// Start long polling (fire-and-forget)
|
||||
// Start long polling (fire-and-forget). `bot.start()` only resolves when the bot stops;
|
||||
// a fatal polling error (e.g. 409 Conflict) rejects here — schedule a backoff reconnect.
|
||||
bot.start().catch((err) => {
|
||||
const msg = err instanceof Error ? err.message : String(err)
|
||||
this.clearStabilityTimer()
|
||||
this.markDisconnected(msg)
|
||||
this.log.error(`Polling stopped: ${msg}`)
|
||||
this.scheduleReconnect()
|
||||
})
|
||||
|
||||
this.markConnected()
|
||||
this.log.info('Telegram bot polling started')
|
||||
|
||||
// Reset the reconnect budget once this connection has stayed up for the stability window.
|
||||
// The `this.bot === bot` guard ensures a stale timer from a superseded connection no-ops.
|
||||
this.clearStabilityTimer()
|
||||
this.stabilityTimer = setTimeout(() => {
|
||||
this.stabilityTimer = null
|
||||
if (!this.shouldStop && this.bot === bot) this.reconnectAttempts = 0
|
||||
}, this.stabilityResetMs)
|
||||
}
|
||||
|
||||
private clearStabilityTimer(): void {
|
||||
if (this.stabilityTimer) {
|
||||
clearTimeout(this.stabilityTimer)
|
||||
this.stabilityTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
private scheduleReconnect(): void {
|
||||
if (this.shouldStop || this.reconnectTimer) return
|
||||
|
||||
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
|
||||
this.log.error('Telegram max reconnect attempts reached, giving up')
|
||||
return
|
||||
}
|
||||
|
||||
const delay = this.reconnectDelays[Math.min(this.reconnectAttempts, this.reconnectDelays.length - 1)]
|
||||
this.reconnectAttempts++
|
||||
this.log.info(`Telegram reconnecting in ${delay}ms (attempt ${this.reconnectAttempts})`)
|
||||
|
||||
this.reconnectTimer = setTimeout(() => {
|
||||
this.reconnectTimer = null
|
||||
if (this.shouldStop) return
|
||||
this.startBot().catch((err) => {
|
||||
const msg = err instanceof Error ? err.message : String(err)
|
||||
this.markDisconnected(msg)
|
||||
this.log.error(`Telegram reconnect failed: ${msg}`)
|
||||
this.scheduleReconnect()
|
||||
})
|
||||
}, delay)
|
||||
}
|
||||
|
||||
private clearReconnectTimer(): void {
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer)
|
||||
this.reconnectTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
protected override async performDisconnect(): Promise<void> {
|
||||
this.shouldStop = true
|
||||
this.clearReconnectTimer()
|
||||
this.clearStabilityTimer()
|
||||
if (this.bot) {
|
||||
await this.bot.stop()
|
||||
this.bot = null
|
||||
@@ -228,33 +306,39 @@ class TelegramAdapter extends ChannelAdapter {
|
||||
}
|
||||
|
||||
const parseMode = opts?.parseMode ?? 'MarkdownV2'
|
||||
const formatted = parseMode === 'MarkdownV2' ? toMarkdownV2(text).trimEnd() : text
|
||||
const chunks = splitMessage(formatted, TELEGRAM_MAX_LENGTH)
|
||||
const isMarkdown = parseMode === 'MarkdownV2'
|
||||
// Split the PLAIN text first and escape each chunk, so the MarkdownV2 send and
|
||||
// its plain-text fallback share one chunk boundary. (Splitting the *formatted*
|
||||
// text and then re-splitting the *raw* text by the same index misaligns — escaping
|
||||
// changes lengths/boundaries — dropping, duplicating, or passing `undefined` chunks.)
|
||||
const plainChunks = splitMessage(text, isMarkdown ? TELEGRAM_MARKDOWN_CHUNK_BUDGET : TELEGRAM_MAX_LENGTH)
|
||||
|
||||
for (let i = 0; i < chunks.length; i++) {
|
||||
for (let i = 0; i < plainChunks.length; i++) {
|
||||
const plain = plainChunks[i]
|
||||
const formatted = isMarkdown ? toMarkdownV2(plain).trimEnd() : plain
|
||||
const replyParams =
|
||||
opts?.replyToMessageId && i === 0 ? { reply_parameters: { message_id: opts.replyToMessageId } } : {}
|
||||
|
||||
try {
|
||||
await this.bot.api.sendMessage(chatId, chunks[i], {
|
||||
await this.bot.api.sendMessage(chatId, formatted, {
|
||||
parse_mode: parseMode,
|
||||
...replyParams
|
||||
})
|
||||
} catch (error) {
|
||||
// Fallback to plain text if MarkdownV2 parsing fails
|
||||
if (parseMode === 'MarkdownV2') {
|
||||
// Fallback to plain text if MarkdownV2 parsing fails — same chunk content.
|
||||
if (isMarkdown) {
|
||||
this.log.warn('MarkdownV2 send failed, falling back to plain text', {
|
||||
chatId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
await this.bot.api.sendMessage(chatId, splitMessage(text, TELEGRAM_MAX_LENGTH)[i], replyParams)
|
||||
await this.bot.api.sendMessage(chatId, plain, replyParams)
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay between chunks to avoid rate limiting
|
||||
if (i < chunks.length - 1) {
|
||||
if (i < plainChunks.length - 1) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100))
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user