refactor(ai-service): consolidate AI runtime to main process (#14911)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: fullex <106392080+0xfullex@users.noreply.github.com>
Signed-off-by: suyao <sy20010504@gmail.com>
This commit is contained in:
SuYao
2026-06-05 00:06:51 +08:00
committed by GitHub
parent ad922067d4
commit 5706307451
1418 changed files with 76141 additions and 79724 deletions

View File

@@ -0,0 +1,5 @@
---
'@cherrystudio/ai-core': patch
---
Remove the prompt-based tool-use plugin end-to-end. Tool use now relies solely on native provider tool calling, so `promptToolUsePlugin` (with its `StreamEventManager`, `ToolExecutor`, and tag-extraction helpers) and the public exports `ToolUseRequestContext` and `AiRequestMetadata.isPromptToolUse` are gone. Also switch the provider cache from `lru-cache` to `quick-lru`.

1
.gitignore vendored
View File

@@ -85,3 +85,4 @@ test-results
YOUR_MEMORY_FILE_PATH
.sessions/
.devtools

View File

@@ -24,7 +24,10 @@
| Document | Description |
|----------|-------------|
| [AI Core Architecture](./references/ai-core-architecture.md) | Complete data flow and architecture from user input to LLM response |
| [AI Reference](./references/ai/README.md) | Main-process AI pipeline: stream manager, agent loop, providers, tools |
| [Core Architecture](./references/ai/core-architecture.md) | End-to-end call flow from user input to LLM response |
| [Stream Manager](./references/ai/stream-manager.md) | Active-stream registry, broker, reconnect, persistence |
| [Adapter Family](./references/ai/adapter-family.md) | How endpoint → `@ai-sdk/*` package routing is decided |
### Data System

BIN
docs/image-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

BIN
docs/image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 125 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,152 @@
# AI Reference
This is the entry point for the AI pipeline in Cherry Studio v2 — the
main-process service that owns every LLM call (chat streams, agent loops,
translate, summarisation) and the renderer-side transport that connects to it.
## Quick navigation
### Top-level architecture
| Document | What it covers |
|---|---|
| [Core Architecture](./core-architecture.md) | End-to-end call flow: `Ai_Stream_Open` IPC → context provider → AiStreamManager → Agent loop → `@ai-sdk/*` → broadcast / persist |
| [Stream Manager](./stream-manager.md) | Active-stream registry, listeners, reconnect, abort, abort-and-restart steering, persistence backends |
| [Agent Session Runtime](./agent-session-runtime.md) | Agent-session host/driver split, `pendingTurns` follow-up queue, resume token persistence, Claude Code driver fallback |
| [Adapter Family](./adapter-family.md) | How `provider.endpointConfigs[ep].adapterFamily` picks the right `@ai-sdk/*` package per request |
### Subsystems
| Document | What it covers |
|---|---|
| [Agent Loop](./agent-loop.md) | Main-process `Agent.stream()`: single-pass stream, hook composition, observer pattern, error/abort semantics |
| [Params Pipeline](./params-pipeline.md) | `buildAgentParams` + `RequestFeature` model: how capabilities, plugins, tools, and provider-specific quirks are composed |
| [Tool Registry](./tool-registry.md) | Built-in tools (knowledge / web search), MCP tools, meta-tools (`tool_search` / `tool_inspect` / `tool_invoke` / `tool_exec`), deferred exposition |
| [Provider Resolution](./provider-resolution.md) | `Provider.endpointConfigs` schema, endpoint resolution chain, variant suffixes, custom provider extensions (aihubmix, newapi) |
| [Observability (trace / telemetry)](./observability.md) | `AiSdkSpanAdapter`, root span propagation, OTel attribute shape, local span projection, sinks |
### Renderer-side glue
| Document | What it covers |
|---|---|
| [IPC Transport](./ipc-transport.md) | `useChat` + `IpcChatTransport`: `sendMessages` / `reconnectToStream`, dispatch coordinator, topic-status mirror |
| [Execution Overlay](./execution-overlay.md) | `TopicStreamSubscription` + `useExecutionOverlay`: ref-counted attach, per-execution demux, one-shot `readUIMessageStream` per turn (the renderer half of the same merge function Main uses) |
| [Tool Approval](./tool-approval.md) | Approval registry, Main-as-writer model, persistent decisions, `useToolApproval` hook |
## Where the code lives
> **Scope of the focused docs.** The reference documents in this folder map
> the **chat / stream pipeline** (dispatch → stream manager → runtime →
> tools → persistence → renderer transport). The `agents/`, `channels/`,
> `skills/`, and `mcp/` subsystems are mapped in the tree below but do not
> yet have dedicated deep-dive docs.
```
src/main/ai/
├── AiService.ts ← lifecycle owner, IPC handlers (generate / translate / approval)
├── runtime/ ← AI execution backends + runtime registry
│ ├── aiSdk/ ← Agent class, loop, observers, params/features, prompts/
│ └── claudeCode/ ← Claude Code driver, warm query, SDK adapter
├── agentSession/ ← agent-session topic host
│ └── AgentSessionRuntimeService.ts
├── agents/ ← AgentJobsService, AgentTaskJobHandler, runAgentTask, builtin/, cherryclaw/
├── channels/ ← ChannelManager + IM adapters (discord/feishu/qq/slack/telegram/wechat) + security/
├── streamManager/ ← AiStreamManager + listeners + persistence backends
│ ├── AiStreamManager.ts ← registers the stream IPC (Open/Attach/Detach/Abort)
│ ├── context/ ← ChatContextProvider implementations + dispatch
│ ├── lifecycle/ ← chat / prompt-only stream lifecycles
│ ├── listeners/ ← WebContents / Persistence / SSE / channel-adapter
│ ├── persistence/ ← MessageService / TemporaryChat / Translation backends
│ └── pipeStreamLoop.ts ← shared chunk-pipe primitive
├── provider/ ← provider config, endpoint resolution, custom providers
│ ├── custom/ ← aihubmix, newapi
│ ├── config.ts ← providerToAiSdkConfig (builder table)
│ ├── endpoint.ts ← resolveEffectiveEndpoint + adapterFamily routing
│ ├── extensions/ ← ProviderExtension registrations
│ └── listModels.ts ← per-provider model listing
├── mcp/ ← McpRuntimeService / McpCatalogService, oauth/, built-in servers
│ └── servers/ ← in-memory MCP server implementations (browser, filesystem)
├── skills/ ← SkillService, SkillInstaller
├── tools/ ← unified tool registry
│ └── adapters/
│ ├── aiSdk/ ← registry.ts, repair.ts; builtin/ (web__search/web__fetch/kb__*),
│ │ mcp/ (server → ToolEntry sync), meta/ (tool_search/inspect/invoke;
│ │ tool_exec defined but not injected), exposition/ (shouldDefer + applyDefer)
│ └── claudeCode/ ← agentTools.ts (registry → Claude Code runtime)
├── observability/ ← AI trace adapters (aiSdk / claudeCode), local projection, sinks
├── messages/ ← UI part → AI SDK part conversion
├── types/ ← AppProviderId, merged extension types, request types
└── utils/ ← reasoning / model parameters / options / websearch helpers
```
## How a chat turn flows
1. Renderer `useChat({ transport: IpcChatTransport })` calls `sendMessages`
IPC `Ai_Stream_Open` (`{ topicId, trigger, userMessageParts, parentAnchorId?, mentionedModelIds? }`).
2. `AiStreamManager.onInit` registered the `Ai_Stream_Open` handler; it
wraps the sender in a `WebContentsListener` and calls
`dispatchStreamRequest(manager, subscriber, req)`. (The stream IPC —
`Open`/`Attach`/`Detach`/`Abort` — lives on `AiStreamManager`, not
`AiService`.)
3. `dispatchStreamRequest` picks the first `ChatContextProvider` whose
`canHandle(topicId)` matches (persistent chat / temporary / agent
session) and calls `prepareDispatch` — that resolves models, persists
the user message, builds listeners, and returns a `PreparedDispatch`.
4. `AiStreamManager.send(input)` **starts** a turn (no active stream): creates
an `ActiveStream`, launches one `StreamExecution` per model. (A chat
resubmit on a live topic is restarted upstream — `dispatch` calls
`abortAndAwait` first; only an agent-session follow-up takes the
**inject** path, which just upserts listeners.)
5. Each execution's `runExecutionLoop` calls `AiService.streamText(request,
signal)`, which builds params (`buildAgentParams`) and constructs an `Agent`
composing hooks from `RequestFeature[]` (anthropic cache, gateway usage
normalisation, reasoning extraction, …), then calls `agent.stream(messages,
signal)` to open the AI SDK stream and yield `UIMessageChunk`s.
Agent-session runtime requests are the exception: `AiService.streamText`
routes them to `AgentSessionRuntimeService.openTurnStream()` so the
registered driver can own the concrete agent runtime.
6. `pipeStreamLoop` tees the chunk stream: one branch broadcasts to listeners
(WebContents / SSE / channel-adapter / persistence), one branch runs
`readUIMessageStream` to accumulate a `CherryUIMessage` snapshot.
7. On terminal (done / error / aborted / paused-for-approval), listeners get
a typed terminal callback. `PersistenceListener` writes the final
message via the appropriate `PersistenceBackend`.
8. Renderer reads the persisted row through `useQuery('/topics/:id/messages')`
and disposes its overlay.
## Key invariants
- **Topic-level addressing.** Every IPC and broadcast is keyed by `topicId`.
A topic has at most one active stream; subscribers are equal — there's no
"owner" window.
- **Main owns persistence.** Renderer closing or crashing does not abort the
stream and does not lose data — `PersistenceListener` writes on terminal
regardless of who is listening.
- **Tool approval is Main-authoritative.** The renderer never writes
`approved`/`denied` parts. It posts the decision over IPC and re-reads the
authoritative row. See [Tool Approval](./tool-approval.md).
- **Adapter family per endpoint, not per provider.** Multi-endpoint relays
(MiniMax, Silicon, AiHubMix, …) carry one `adapterFamily` per endpoint.
Picking the SDK package never reads `apiHost` or provider id heuristics
at request time. See [Adapter Family](./adapter-family.md).
## Related references
- [Service Lifecycle](../lifecycle/README.md) — `AiService` extends `BaseService`
- [Data Layer](../data/README.md) — `MessageService`, `ModelService`,
`ProviderService` (called from main-side AI code)
- [Messaging](../messaging/message-system.md) — `CherryMessagePart`,
`CherryUIMessage`, parts model
- [Window Manager](../window-manager/README.md) — `WebContentsListener`
attaches to whatever windows are open
## v2 refactor
The AI domain is the largest single area of the v2 refactor: the v1
renderer aiCore tree (formerly `src/renderer/src/aiCore/`, pre-v2 layout)
is fully deleted, with logic ported into `src/main/ai/`.
These reference docs are **self-contained** — they do not depend on the
throwaway `v2-refactor-temp/` tree. (The reviewer-facing change-cluster
narratives that live there are review logistics for the in-flight PR, and
are removed when the v2 AI refactor merges.)

View File

@@ -0,0 +1,115 @@
# Adapter Family
`adapterFamily` is the optional field on each `EndpointConfig` that picks
the `@ai-sdk/*` package implementing that endpoint's protocol. The runtime
resolver reads it; the catalog seeder and the v1→v2 migrator write it. The
schema declares it `optional`, and the resolver has a total fallback
(`openai-compatible`) for endpoints that omit it — so no write path is
obligated to set it.
## Identity stack
| Layer | Example | Role |
|---|---|---|
| `provider.id` | `minimax`, `silicon`, `my-relay` | User-facing identity, UI label, routing key |
| `endpointType` | `openai-chat-completions`, `anthropic-messages` | URL path template + protocol family |
| `adapterFamily` | `openai-compatible`, `anthropic`, `azure-responses` | Which `@ai-sdk/*` package implements this protocol |
Multi-endpoint relays (MiniMax, Silicon, AiHubMix) carry one
`adapterFamily` per endpoint under the same `provider.id` — different
endpoints on the same provider can route to different SDK packages.
## Runtime resolver
`src/main/ai/provider/endpoint.ts`:
```ts
export function resolveAiSdkProviderId(provider, endpointType) {
const adapterFamily = endpointType
? provider.endpointConfigs?.[endpointType]?.adapterFamily
: undefined
if (adapterFamily && adapterFamily in appProviderIds) {
return resolveProviderVariant(appProviderIds[adapterFamily], endpointType)
}
return appProviderIds['openai-compatible']
}
```
One signal, no heuristics. Tested with 54 cases in
`provider/__tests__/endpoint.test.ts`.
## Write paths
`adapterFamily` is a derived value computed at row-write time, never at
request time. One shared inference function lives at
`packages/provider-registry/src/registry-utils.ts`:
```ts
export function inferAdapterFamily(endpointType, catalogConfig?): string {
if (catalogConfig?.adapterFamily) return catalogConfig.adapterFamily
return ENDPOINT_TYPE_TO_DEFAULT_ADAPTER_FAMILY[endpointType] ?? 'openai-compatible'
}
```
### Endpoint-type defaults
| endpoint type | default adapter |
|---|---|
| `anthropic-messages` | `anthropic` |
| `google-generate-content` | `google` |
| `ollama-chat` / `ollama-generate` | `ollama` |
| `jina-rerank` | `jina-rerank` |
| `openai-responses` | `openai` |
| everything else | `openai-compatible` (terminal fallback) |
### Write paths
Only two paths write `adapterFamily`; both run in the **main** process at
row-write time:
1. **Catalog (new installs)**`packages/provider-registry/data/providers.json`
declares `adapterFamily` per endpoint per provider. The seeder copies
it through via `buildRuntimeEndpointConfigs`.
2. **v1 → v2 migration (existing users)**
`src/main/data/migration/v2/migrators/mappings/ProviderModelMappings.ts`
looks up the catalog by legacy id and, on a miss, calls
`inferAdapterFamily(endpointType)` for the endpoint-type default
(`ProviderModelMigrator.ts` carries a preset's `adapterFamily` forward
on merge). The `ANTHROPIC_MESSAGES` endpoint skips the legacy-type hint
because v1 custom anthropic relays carried `legacy.type='openai'` even
when the endpoint was anthropic-format.
The renderer's custom-provider form does **not** set `adapterFamily`:
`ProviderEditorDrawer.tsx` writes only `baseUrl` into the endpoint config,
leaving the field absent so the resolver's `openai-compatible` fallback
applies. `inferAdapterFamily` has **no renderer callers** — it is invoked
only by the migrator above.
## Schema
`src/shared/data/types/provider.ts::EndpointConfigSchema`:
```ts
EndpointConfigSchema = z.object({
baseUrl: z.string().optional(),
adapterFamily: z.string().optional(), // optional — resolver falls back to openai-compatible
// ... other endpoint-config fields
})
```
`packages/provider-registry/src/schemas/provider.ts::RegistryEndpointConfigSchema`
mirrors this for catalog entries.
## Tests
| Target | File | Cases |
|---|---|---|
| `inferAdapterFamily` | `packages/provider-registry/src/__tests__/registry-utils.test.ts` | 5 |
| Migrator backfill | `src/main/data/migration/v2/migrators/mappings/__tests__/ProviderModelMappings.test.ts` | 4 |
| Runtime resolver | `src/main/ai/provider/__tests__/endpoint.test.ts` | 54 |
| `buildRuntimeEndpointConfigs` | `packages/provider-registry/src/__tests__/registry-utils.test.ts` | 10 |
## Where to read more
- Runtime usage: [Provider Resolution](./provider-resolution.md)
- Catalog: `packages/provider-registry/data/providers.json`

View File

@@ -0,0 +1,116 @@
# Agent Loop
## What it is
`Agent` (`src/main/ai/runtime/aiSdk/Agent.ts`) wraps `@cherrystudio/ai-core`'s
`createAgent(...).stream()` (built on the AI SDK's `ToolLoopAgent`) with a
`composeHooks` pipeline that folds N
independent hook contributors (per-feature plugins, AiService analytics,
internal observers) into a single `AgentLoopHooks` object with deterministic
ordering, then bridges one streaming pass to a `ReadableStream<UIMessageChunk>`
with a stable id for the first emitted message.
The stream is **single-pass**: `Agent.stream` runs the AI SDK stream exactly
once and pipes it through. There is no mid-stream message injection — steering
a chat turn is handled upstream by abort-and-restart (see
[Stream Manager](./stream-manager.md#steering)).
`Agent` does not know about topics, IPC, persistence, or multi-model
fan-out. Those concerns live in the stream manager — see
[Stream Manager](./stream-manager.md).
## API
```ts
const agent = new Agent({
providerId, providerSettings, modelId,
plugins, tools, system, options,
hookParts, // RequestFeature contributions
messageId // stable id for the first emitted UIMessage
})
const stream: ReadableStream<UIMessageChunk> = agent.stream(initialMessages, signal)
// or (non-streaming; input is { prompt } | { messages })
const result = await agent.generate({ messages }, signal)
// internal observers can also register on the agent:
const dispose = agent.on('onStepFinish', step => { })
```
`stream()` and `generate()` share the underlying agent — only the AI SDK
call differs. Future `runToCompletion()` / `toTool()` are placeholders;
they don't ship in this PR.
## Hooks model
```ts
interface AgentLoopHooks {
onStart?: () => Promise<void> | void
prepareStep?: PrepareStepFunction // chained
onStepFinish?: (step) => Promise<void> | void // void-fan-out
onToolExecutionStart?: (event) => Promise<void> | void
onToolExecutionEnd?: (event) => Promise<void> | void
onFinish?: () => Promise<void> | void
onError?: (ctx) => 'retry' | 'abort'
}
```
Hook contributions come from three sources, all folded by `composeHooks`:
1. **Internal observers** (`Agent.on(key, fn)`) — `attachUsageObserver`
(injects `message-metadata` chunks carrying token usage).
2. **Feature contributions** (`hookParts` param) — each `RequestFeature`'s
`contributeHooks(scope)` (see [Params Pipeline](./params-pipeline.md)).
3. **Caller hooks**`AiService` adds the analytics hook only (token-usage
accounting via `onStepFinish` / `onFinish`). It does *not* contribute a
root-span/trace lifecycle hook — the OTel root span is owned by
`AiStreamManager.runExecutionLoop`.
Composition rules per hook key:
| key | rule |
|---|---|
| `onStart`, `onFinish`, `onStepFinish`, `onToolExecutionStart/End` | `chainVoid` — sequential `for`-loop await; per-hook throws logged and swallowed, chain continues |
| `prepareStep` | chained — each invocation receives the previous return value |
| `onError` | every handler invoked sequentially; any `'retry'` makes the result `'retry'`; default `abort` |
All void hooks share the same `chainVoid` helper in `composeHooks.ts`
there is no `Promise.allSettled` / parallel path.
Tool execution events (`onToolExecutionStart/End`) are emitted by a
wrapper around each tool's `execute`. No released AI SDK version brackets a
single tool's execution: v6 exposes call-level (`experimental_onToolCallStart`)
and input-level (`onInputStart` / `onInputDelta` / `onInputAvailable`) hooks, but
nothing around `execute` itself — so we wrap. A future SDK version may add
Agent-level execution hooks with the same shape, at which point the wrapper is
removed and hook signatures stay stable.
## Steering
There is no in-loop steering. `Agent.stream` makes a single AI SDK pass and
never folds a mid-flight follow-up into the running turn — doing so mutated
in-flight history and had no clean turn boundary. A new chat submission to a
live topic is handled one level up by the stream manager: the dispatcher
aborts the running turn, waits for it to persist as `paused`, and starts a
fresh one — see [Stream Manager → Steering](./stream-manager.md#steering).
Agent-session runtimes are different: they queue their own follow-ups on the
session's `pendingTurns` and interrupt between turns rather than restarting —
see [Agent Session Runtime](./agent-session-runtime.md#live-follow-up).
## Error and abort
- `signal.aborted` is honoured throughout; aborted streams settle with
the accumulated chunks already broadcast.
- Thrown errors are caught and routed through `onError`. Returning
`'retry'` is reserved for a future implementation — today the loop
logs and aborts.
- The writer is settled exactly once via the `then`/`catch` of the
internal IIFE — listeners never see a half-closed stream.
## Where to read more
- Code: `src/main/ai/runtime/aiSdk/`
- Tests: `src/main/ai/runtime/aiSdk/loop/__tests__/agentLoop.test.ts`
- Stream manager integration: [Stream Manager](./stream-manager.md)
- Hook contributors: [Params Pipeline](./params-pipeline.md)

View File

@@ -0,0 +1,194 @@
# Agent Session Runtime
## Purpose
Agent-session streams need a stable host for UI turns, persistence, live
follow-ups, interrupt, and recovery. The host must not know whether the
underlying agent uses a long-lived process, a websocket, one HTTP request
per turn, or Claude Code's SDK `query`.
The boundary is:
- `AgentSessionRuntimeService` owns Cherry's UI/session lifecycle.
- `AgentSessionRuntimeDriver` owns the concrete agent-session runtime lifecycle.
Claude Code is the first driver. Its `query`, warm query, SDK input
queue, and `resume` handling are driver internals.
## Ownership
| Owner | Responsibility |
|---|---|
| `AgentChatContextProvider` | Validates the agent session, persists the user row (plus a pending assistant row on a fresh turn), and either starts a turn or enqueues a follow-up through the runtime. |
| `AgentSessionRuntimeService` | Owns one runtime entry per session: current UI turn, pending UI queue, runtime connection, latest resume token, terminal listeners, persistence, and idle timer. |
| `AgentSessionRuntimeDriver` | Connects to one concrete agent implementation and exposes `send`, optional `interrupt`, `close`, and an event stream. |
| `AiStreamManager` | Keeps the normal topic stream contract: start a turn, attach a follow-up subscriber to a live turn, pause the current runtime turn, and start the next runtime turn. |
| `AiService.streamText()` | Routes `request.runtime.kind === 'agent-session'` to `AgentSessionRuntimeService.openTurnStream()` and rejects agent-session topics that do not carry runtime metadata. |
| `ClaudeCodeRuntimeDriver` | Converts Claude SDK messages into generic runtime events and maps opaque resume tokens to Claude SDK `resume`. |
## Fresh turn
1. Renderer sends `Ai_Stream_Open` for topic `agent-session:<sessionId>`.
2. `AgentChatContextProvider` validates the session:
- the session must have an agent and workspace;
- the workspace path must pass `assertClaudeCodeWorkspaceDirectory`;
- the agent type must have a registered runtime driver;
- the agent must have a model.
3. The provider atomically saves:
- a `user` message with the submitted parts;
- a pending `assistant` message with the selected model id.
4. The provider calls `AgentSessionRuntimeService.beginTurn(...)`.
5. `beginTurn()` returns:
- a runtime persistence listener;
- a runtime terminal listener;
- a trace flush listener for `agent-session:${sessionId}` history files;
- a `turnId`.
Follow-up messages are not queued here — they live on the session
entry's `pendingTurns`, appended by `enqueueUserMessage()`.
6. The prepared model request includes:
- `runtime: { kind: 'agent-session', sessionId, turnId }`;
- `messageId` set to the pending assistant row;
- seed `messages`: the user row plus the empty assistant row.
7. `AiStreamManager` starts the execution. `AiService.streamText()`
detects the runtime metadata and calls `openTurnStream()` instead of
building a generic `Agent`.
8. `openTurnStream()` ensures there is a runtime connection and admits
the turn by calling `connection.send({ message })`.
## Live follow-up
If the same topic already has a live stream, `AgentChatContextProvider`
does **not** create a new assistant placeholder and does **not** call
`beginTurn()` again. It persists the new user row, hands the message to
`AgentSessionRuntimeService.enqueueUserMessage(sessionId, message)`, and
returns a `PreparedDispatch` with `models: []` so `AiStreamManager.send()`
takes the **inject** path — which for agent sessions only upserts the new
subscriber onto the running stream (no message is injected into the
execution; chat's abort-and-restart does not apply here).
`enqueueUserMessage()` appends the message to the session entry's
`pendingTurns`, then acts on the current turn:
1. if the turn is already terminal (or absent) — schedules the next turn;
2. if the turn is mid-tool-call (`activeToolIds` non-empty) — leaves it
alone; the next turn is scheduled when the turn settles;
3. otherwise — requests an interrupt when safe (`connection.interrupt()`
if the driver supports it), which terminalizes the current UI turn.
Once the current turn is paused or terminal, `startNextTurn()` drains the
next message off `pendingTurns` and starts a fresh runtime turn (below).
This keeps the renderer protocol unchanged while each driver decides how
to interrupt its own runtime.
## Starting the next runtime turn
When a paused, aborted, or completed runtime turn still has queued
follow-ups, `AgentSessionRuntimeService.startNextTurn()`:
1. shifts the next user message off the session entry's `pendingTurns`;
2. saves a new pending assistant row;
3. creates a fresh `turnId`;
4. calls `AiStreamManager.startRuntimeTurn(...)` with:
- the same topic id and model id;
- `runtime: { kind: 'agent-session', sessionId, turnId }`;
- seed messages containing the user row and empty assistant row.
The runtime connection may stay on the entry. What that means is driver
specific: Claude Code keeps its SDK query/input queue, while another
driver could keep a websocket or reconnect per turn.
## Resume token persistence
Drivers may emit:
```ts
{ type: 'resume-token'; token: string }
```
The host treats the value as opaque. It stores it as
`entry.lastResumeToken` and passes `runtimeResumeToken` to
`AgentSessionMessageBackend`, so the final assistant row receives the
latest resume token at terminal time.
This also covers error turns: if a driver emitted a resume token and then
failed, the assistant error row still records that token so the next
connection can recover from the newest driver-known state.
User rows do not need a resume token. The durable recovery anchor is the
latest assistant row with `runtimeResumeToken`.
For Claude Code, the resume token is the SDK `session_id`. The driver
maps it to `options.resume`. This is separate from the SDK's file
checkpointing / `rewindFiles()` feature, which uses user-message UUIDs
to restore files.
## Claude Code driver
Normal multi-turn chat does not use `continue: true` and does not rely
on cwd-based session discovery.
When `ClaudeCodeRuntimeDriver.connect()` needs to create a query, it
asks `buildClaudeCodeQueryRequestForAgentSession(sessionId, resumeToken)`.
The builder uses the first available value:
1. explicit resume token from the host;
2. latest persisted agent-session resume token from
`agentSessionMessageService.getLastRuntimeResumeToken(session.id)`;
3. no resume id for a brand-new SDK session.
The query may come from `ClaudeCodeWarmQueryManager.consume(...)` if a
prewarmed query is available. Otherwise the driver starts a new SDK
query with `createClaudeQuery({ prompt: driverSdkInputQueue, options })`.
The driver converts Claude SDK messages into runtime events:
- `stream_event` / assistant/user messages -> `chunk`;
- `system/init` -> `resume-token`;
- `result` -> `resume-token` and `turn-complete`;
- thrown errors -> `error`.
## Idle and shutdown
After a turn reaches terminal state, the runtime entry becomes `idle`.
For a short idle window it keeps:
- the runtime connection, if it is still alive;
- `lastResumeToken`;
- the session entry's `pendingTurns`.
If a new turn arrives during that window, `beginTurn()` reuses the same
entry and only swaps the current UI turn plus the UI pending queue.
When the idle timer expires, the runtime closes the entry:
- clears `pendingTurns`;
- closes the runtime connection;
- prewarms Claude Code when a latest resume token is known.
Service stop and destroy close all runtime entries.
## Removed old path
Claude Code is not a normal provider extension anymore:
- no `createClaudeCode`;
- no `ClaudeCodeLanguageModel`;
- no `ClaudeCodeProviderSettings`;
- no `injectedMessageSource` in provider settings;
- no `providerToAiSdkConfig(..., { runtimeResumeToken })` branch.
Any `agent-session:*` stream that reaches `AiService.streamText()`
without runtime metadata is rejected. That fail-fast rule prevents a
regression back to one CLI process per turn without the long-lived SDK
input queue inside the Claude Code driver.
## Verification
Focused tests:
- `src/main/ai/streamManager/context/__tests__/AgentChatContextProvider.test.ts`
- `src/main/ai/agentSession/__tests__/AgentSessionRuntimeService.test.ts`
- `src/main/ai/runtime/claudeCode/__tests__/ClaudeCodeRuntimeDriver.test.ts`
- `src/main/ai/__tests__/AiService.test.ts`
- `src/main/ai/runtime/claudeCode/__tests__/streamAdapter.test.ts`
- `src/main/ai/runtime/claudeCode/__tests__/ClaudeCodeWarmQueryManager.test.ts`

View File

@@ -0,0 +1,180 @@
# Core Architecture
End-to-end view of how a Cherry chat turn moves from user input to LLM
response and back to UI, with pointers to the focused references for
each subsystem.
## Layered view
```
┌──────────────────────────────────────────────────────────────────────┐
│ Renderer │
│ │
│ useChat({ id: topicId, transport: IpcChatTransport }) │
│ ├─ sendMessages → window.api.ai.streamOpen │
│ ├─ reconnectToStream → window.api.ai.streamAttach │
│ └─ abort signal → window.api.ai.streamAbort │
│ │
│ History: useQuery('/topics/:id/messages') → DataApi │
│ Topic-level state: useTopicStreamStatus → shared cache │
│ Approval bridge: useToolApprovalBridge → window.api.ai.toolApproval│
└──────────────────────────────────────────────────────────────────────┘
↕ IPC (keyed by topicId)
┌──────────────────────────────────────────────────────────────────────┐
│ Main │
│ │
│ AiStreamManager (lifecycle service) — registers in onInit: │
│ ├─ ipcHandle('Ai_Stream_Open', → dispatchStreamRequest) │
│ ├─ ipcHandle('Ai_Stream_Attach', → this.attach) │
│ ├─ ipcHandle('Ai_Stream_Detach', → this.detach) │
│ └─ ipcHandle('Ai_Stream_Abort', → this.abort) │
│ │
│ AiService (lifecycle service) — registers: │
│ ├─ ipcHandle('Ai_ToolApproval_Respond', <inline handler>) │
│ └─ ipcHandle('Ai_GenerateText' / 'Ai_Translate_Open' / …) │
│ │
│ dispatch (src/main/ai/streamManager/context/dispatch.ts) │
│ pick ChatContextProvider → prepareDispatch → manager.send(...) │
│ │
│ AiStreamManager │
│ activeStreams: Map<topicId, ActiveStream> │
│ listeners + executions │
│ runs N StreamExecution loops, fan-out per chunk to listeners │
│ │
│ runExecutionLoop (AiStreamManager) → AiService.streamText(req,signal)│
│ buildAgentParams: registry.selectActive + applyDeferExposition │
│ new Agent({tools, hookParts}) — composeHooks runs inside Agent │
│ → agent.stream(messages, signal) │
│ pipeStreamLoop tees: │
│ • broadcast → WebContents / SSE / channel-adapter / persistence │
│ • readUIMessageStream → CherryUIMessage snapshot │
│ │
│ Terminal listeners: │
│ PersistenceListener → MessageService / TemporaryChat / Translation
│ WebContentsListener → wc.send(Ai_StreamDone) │
│ ChannelAdapterListener → adapter.onStreamComplete │
│ SseListener → res.write('[DONE]') │
└──────────────────────────────────────────────────────────────────────┘
@ai-sdk/* package
LLM provider API
```
## Sequence: a fresh chat turn
1. User hits send. `useChat.sendMessages` calls `IpcChatTransport.sendMessages`.
2. Transport packages `AiStreamOpenRequest`, dispatches via
`streamDispatchCoordinator` over IPC `Ai_Stream_Open`.
3. `AiStreamManager`'s `Ai_Stream_Open` handler (registered in `onInit`)
wraps the sender in a `WebContentsListener` and calls
`dispatchStreamRequest(manager, subscriber, request)`.
4. `dispatchStreamRequest` picks the first `ChatContextProvider` whose
`canHandle(topicId)` matches and asks it to `prepareDispatch`.
5. The provider resolves models, persists the user message (chat) or skips
persistence (temporary / translate), creates `PersistenceListener` per
execution, returns `PreparedDispatch`.
6. `dispatch` reconciles any live stream, then calls `manager.send(input)`:
- **chat resubmit** (topic already streaming): `manager.abortAndAwait(topicId)`
aborts the running executions and waits for their loops to settle (the
partial persists as `paused`) before `send()` **starts** a fresh turn —
steering is abort-and-restart, not mid-turn injection.
- **agent-session follow-up**: the stream is left running and `send()`
**injects** — it upserts `listeners` onto the running stream, `models`
ignored (the message was already enqueued on the session's `pendingTurns`).
- **no live stream**: `send()` **starts** — evict any grace-period stream,
create an `ActiveStream`, launch one `StreamExecution` per model.
7. For each `StreamExecution`, `AiStreamManager`'s private `runExecutionLoop`
calls `AiService.streamText(request, signal)`, which builds params
(`buildAgentParamsFor → buildAgentParams`: `registry.selectActive` +
`applyDeferExposition` + per-feature hooks), constructs an `Agent`
(`composeHooks` folds observers + caller + features inside `Agent`), and
calls `agent.stream(messages, signal)` — which opens AI SDK's stream and
yields `UIMessageChunk`s. Agent-session runtime requests skip the generic
agent loop here: `AiService.streamText()` calls
`AgentSessionRuntimeService.openTurnStream()` so the registered driver
can own the concrete agent runtime.
8. `pipeStreamLoop` reads the chunk stream once, tees: broadcast to
listeners, accumulate via `readUIMessageStream`.
9. On terminal (`done` / `error` / `aborted` / `awaiting-approval`):
- `PersistenceListener` writes the final assistant message.
- `WebContentsListener` broadcasts `Ai_StreamDone` to subscribed windows.
- Shared-cache `topic.stream.statuses.<topicId>` flips to the terminal status.
10. Renderer's `useQuery('/topics/:id/messages')` revalidates; the
optimistic overlay is disposed.
## Sequence: tool approval pause + resume
1. AI SDK calls `tool.execute(args, toolCallContext)`. The wrapper sees
`needsApproval(args)` returns true and the assistant's auto-approve
policy says "ask". It writes an `approval-requested` part on the
accumulated message and holds the promise.
2. Manager flips status to `awaiting-approval` on the shared cache.
3. Renderer's `useTopicAwaitingApproval(topicId)` returns true; the UI
shows the approval card.
4. User decides → `useToolApprovalBridge``Ai_ToolApproval_Respond`.
5. Main applies the decision to the anchor row, resumes the stream
(Claude-Agent: resolves the `canUseTool` promise; MCP: dispatches a
`continue-conversation` so the existing stream rebroadcasts).
6. Status flips back to `streaming`; UI hides the card.
See [Tool Approval](./tool-approval.md) for invariants and the
overlay-vs-persist conditional write.
## Key subsystems
| Subsystem | Reference |
|---|---|
| Active-stream registry, listeners, persistence backends, reconnect, abort, grace-period eviction | [Stream Manager](./stream-manager.md) |
| Claude Code agent-session long-lived runtime, SDK input queue, resume fallback | [Agent Session Runtime](./agent-session-runtime.md) |
| `Agent.stream` single-pass loop, hooks model, error/abort | [Agent Loop](./agent-loop.md) |
| `buildAgentParams`, `RequestFeature` composition, `INTERNAL_FEATURES` order | [Params Pipeline](./params-pipeline.md) |
| Tool registry, MCP sync, meta-tools (`tool_search` / `tool_inspect` / `tool_invoke` / `tool_exec`), defer exposition | [Tool Registry](./tool-registry.md) |
| `Provider.endpointConfigs`, `endpointType` resolution, variant suffixes, custom providers | [Provider Resolution](./provider-resolution.md) |
| `adapterFamily` field, runtime resolver, write paths (catalog / migrator) | [Adapter Family](./adapter-family.md) |
| OTel span tree, `AdapterTracer`, `AiSdkSpanAdapter`, dev-tools view | [Observability](./observability.md) |
| `IpcChatTransport`, dispatch coordinator, per-execution demux | [IPC Transport](./ipc-transport.md) |
| Approval flow, Main-as-writer invariant, persistent decisions | [Tool Approval](./tool-approval.md) |
## Invariants
- **Topic-level addressing.** Every IPC, broadcast, and shared-cache
entry is keyed by `topicId`. A topic has at most one active stream;
subscribers are equal — there is no "owner" window.
- **Main owns persistence.** Renderer closing or crashing does not abort
the stream or lose data. `PersistenceListener` writes on terminal
regardless of subscriber state.
- **Main owns approval state.** The renderer is never a writer.
- **Adapter family is per-endpoint.** Multi-endpoint relays may use
different `@ai-sdk/*` packages on different endpoints under the same
`provider.id`.
- **`tools/applies` predicates are pure.** They run on every
`selectActive` pass; side effects there break tool selection
determinism.
- **Features must not mutate `RequestScope`.** It is shared across all
features for a single request.
## Code map
```
src/main/ai/
├── AiService.ts ← lifecycle owner, IPC entry (generate / translate / approval)
├── runtime/ ← execution backends: runtime/aiSdk (Agent + params), runtime/claudeCode
├── agentSession/ ← agent-session topic host
├── agents/ ← AgentJobsService, AgentTaskJobHandler, runAgentTask, cherryclaw
├── channels/ ← ChannelManager + IM adapters (discord/feishu/qq/slack/telegram/wechat) + security/
├── streamManager/ ← AiStreamManager, listeners, persistence (registers the stream IPC)
├── provider/ ← provider config, endpoint resolution, custom providers
├── mcp/ ← McpRuntimeService / McpCatalogService, oauth, built-in servers
├── skills/ ← SkillService, SkillInstaller
├── tools/ ← unified tool registry (adapters/aiSdk + adapters/claudeCode)
├── observability/ ← AI trace adapters, local projection, sinks
├── messages/ ← UI part → AI SDK part conversion
├── types/ ← AppProviderId, merged types, request types
└── utils/ ← reasoning / model parameters / options / websearch
src/renderer/transport/ ← IpcChatTransport, dispatch coordinator
src/renderer/hooks/ ← useChatWithHistory, useToolApprovalBridge, useTopicStreamStatus
packages/aiCore/ ← @cherrystudio/ai-core (Agent + plugins + provider extensions)
packages/provider-registry/ ← provider catalog, registry-utils (adapterFamily inference)
```

View File

@@ -0,0 +1,223 @@
# Execution Overlay
The renderer-side counterpart of Main's `pipeStreamLoop`. Both sides
use the **same pure assembler**
[AI SDK's `readUIMessageStream`](https://ai-sdk.dev/docs/reference/ai-sdk-ui/read-ui-message-stream) —
to turn the chunk stream into a `CherryUIMessage`. Main writes the
result to disk; the renderer paints it onto the chat surface as an
overlay above the SWR-backed history.
## Why the same merge function on both sides
`UIMessageChunk` assembly is non-trivial: text deltas merge by `id`,
reasoning blocks have their own start/delta/end, tool calls go through
`tool-input-start` / `tool-input-delta` / `tool-input-available` /
`tool-output-available`, dynamic data parts merge by key, multi-step
turns carry step boundaries. Re-implementing any of this on the
renderer would mean a second source of truth that *had* to track AI SDK
upstream, with two ways to disagree about partial state.
Running the same `readUIMessageStream` on the same `UIMessageChunk`
stream — once on Main (writing to `exec.finalMessage`), once on the
renderer (driving the overlay) — guarantees structural agreement.
What persists is exactly what the user saw streaming.
```
Main: pipeStreamLoop(stream)
tee()
├─ branch A → broadcast to listeners → WebContentsListener → IPC chunks
└─ branch B → readUIMessageStream → exec.finalMessage (writes to DB)
│ (DB write)
Renderer: TopicStreamSubscription ┌──── readUIMessageStream → snapshot
│ │ │ ▲
│ ▼ │ │
│ routes chunks by │ fed by branch stream
│ executionId into │
│ per-execution branches ─────┘
branch ReadableStream → useExecutionOverlay (per execution)
```
## TopicStreamSubscription
`src/renderer/transport/TopicStreamSubscription.ts`. A renderer
class that owns:
- **One IPC attach per topic.** `attach` is ref-counted — every
execution that calls `register(executionId)` increments the count;
the last `unregister` triggers `detach` (deferred one microtask so a
transient `activeExecutions` flicker doesn't detach-then-reattach).
- **Per-execution demux.** Each `register(executionId)` returns a
`ReadableStream<UIMessageChunk>` that contains only the chunks tagged
with that `executionId` by Main. Multi-model parallel responses each
get their own branch.
- **Synchronous controller creation.** The branch's
`ReadableStreamDefaultController` is created during the
`new ReadableStream({ start })` call (synchronous), so chunks that
arrived between `register` and the reader's first `read()` are
already buffered in the stream's internal queue — late readers never
miss replayed chunks.
- **Terminal demux.** `Ai_StreamDone` / `Ai_StreamError` close the
matching branch and fan out an `ExecutionTerminal` (`{ isAbort,
isError }`) to listeners; if the payload carries `isTopicDone` or no
`executionId`, every branch terminates together.
### Cancellation layering — do not conflate
| Layer | Owner | Action |
|---|---|---|
| Renderer-local subscription | `TopicStreamSubscription.unregister` / `dispose` | Closes the branch reader, drops listener ref count; Main keeps generating |
| Generation abort | Main (via `useChatWithHistory.stop` → Chat → `Ai_Stream_Abort`) | Stops the LLM |
`TopicStreamSubscription` NEVER aborts the LLM. Closing all branches
is the renderer equivalent of `streamDetach` — Main keeps streaming,
other windows keep observing.
### Defensive routing
A chunk without `executionId` is unexpected — Main always tags chat
chunks. As a defensive fallback, if exactly one branch is registered
the chunk routes there; otherwise it's dropped with a warning.
## useExecutionOverlay
`src/renderer/hooks/useExecutionOverlay.ts`. The per-execution
overlay, built on `useTopicStreamSubscription`.
```ts
const { overlay, liveAssistants, disposeOverlay, reset } = useExecutionOverlay(
topicId,
activeExecutions, // ActiveExecution[] from useTopicStreamStatus
uiMessages, // current DB snapshot
{ onFinish }
)
```
### One reader per turn, zero cross-turn state
Each execution gets a **one-shot `readUIMessageStream` reader** per
turn, not a stateful AI SDK `Chat`. A `Chat` carries
`state.messages` across turns; reusing it made a new turn resume from
the previous turn's finished assistant ("previous answer + new
stream"). A fresh reader per turn structurally cannot pollute.
### The seed rule (continue-safe)
```ts
function pickSeed(uiMessages, anchorMessageId): CherryUIMessage | undefined {
if (!anchorMessageId) return undefined
const found = uiMessages.find((m) => m.id === anchorMessageId)
if (!found) return { id: anchorMessageId, role: 'assistant', parts: [] }
// `readUIMessageStream` mutates `message.parts` in place, and `found` is the live
// SWR-derived row — clone the parts so the reader only ever writes to a throwaway.
return { ...found, parts: structuredClone(found.parts ?? []) }
}
```
The reader is seeded with the message whose id is the execution's
`anchorMessageId`, taken from the **current DB truth** at reader-start
time. Two cases:
- **Fresh placeholder** — the SQLite row has empty parts; the seed is
effectively empty and the reader builds the message from scratch.
- **Tool-approval / continue-conversation** — the row already carries
the prior assistant parts (including the unresolved `tool-input` part
the approval was on). A streamed `tool-output` chunk then merges
cleanly onto its matching `tool-input` because they share the same
`toolCallId`.
The seed is re-derived from DB on every reader start; it never carries
across turns, and its `parts` are cloned so the reader's in-place mutation
never touches the SWR row. Combined with the fresh reader, this is the
**structural** anti-pollution guarantee — not "force empty parts" or "diff
against last frame".
### Lifecycle
1. **Topic switch** — every reader is cancelled, every branch
unregistered, `snapshots` cleared. `prevTopicRef` is checked in the
render body so the cleanup runs synchronously before the new topic's
readers start.
2. **`activeExecutions` change** — diff against the current reader
map: cancel + unregister executions no longer in the active list;
for newly-active executions, register a branch, clear any retained
prior snapshot, kick a new reader.
3. **Terminal** — the branch is closed by `TopicStreamSubscription`;
the reader's `for await` exits. The `onFinish(executionId, event)`
callback fires with the final snapshot + `{ isAbort, isError }`.
4. **Unmount** — every reader is cancelled, every branch unregistered.
### Overlay teardown is monotonic
`disposeOverlay(messageId)` drops exactly one snapshot entry. The chat
shell wires this so the overlay is released **only after** the DB
refresh promise resolves (see `.finally(() => disposeOverlay(...))` in
`V2ChatContent`). That ordering eliminates the visible flash between
"streaming overlay" and "persisted parts": the SWR cache holds the
authoritative row before the overlay disappears.
The renderer never writes streamed parts to SWR — writing them would
race the DB-authoritative refresh and cause flicker.
### Why retained snapshots after terminal
The hook keeps the final snapshot in `snapshots` until one of:
- the same execution restarts (next turn clears it),
- the caller calls `disposeOverlay(messageId)` (post-persist handoff),
- the caller calls `reset()` (e.g. quick-assistant clear),
- the topic switches (effect clears all snapshots).
That retention lets consumers read the final frame for the brief window
between stream-end and DB-refresh-complete without going through SWR.
## React binding
`useTopicStreamSubscription(topicId)` is the React wrapper:
- Lazy-init per `topicId` (idiom mirrors `useState(() => ...)`).
- Disposed on unmount or topic switch — drops the Main listener and
closes every branch.
Each mounted topic gets one `TopicStreamSubscription` instance, shared
by every consumer in that React tree (today: `useExecutionOverlay`).
## Code map
```
src/renderer/transport/TopicStreamSubscription.ts ← class
src/renderer/hooks/useTopicStreamSubscription.ts ← React binding
src/renderer/hooks/useExecutionOverlay.ts ← per-execution readers + overlay
src/renderer/pages/home/V2ChatContent.tsx ← consumer + dispose-after-refresh
```
## Invariants reviewers should check
1. **Same merge function on both sides.** Any code that re-implements
chunk → message assembly on the renderer (instead of feeding
`readUIMessageStream`) is wrong — that's where Main and renderer
will diverge first.
2. **One reader per turn.** No reader should be reused across
`activeExecutions` transitions. Reusing one is what the v1 `Chat`
bug was; the structural fix is structural.
3. **Seed from current DB.** `pickSeed` reads `uiMessagesRef.current`
at reader-start time. Stashing the seed on first mount and reusing
it across turns would defeat the continue-conversation case.
4. **Overlay disposed after DB refresh.** Any
`disposeOverlay(messageId)` call that runs **before** the DB
revalidation promise resolves is a flicker bug.
5. **`TopicStreamSubscription` never aborts.** It only detaches.
Anything in this layer that calls `Ai_Stream_Abort` is in the
wrong place — abort belongs to `useChatWithHistory.stop`.
6. **Ref-counted attach.** A new attach must NOT fire when another
execution is already registered for the same topic. A new detach
must NOT fire while any execution still has a branch.
## Where to read more
- Main-side accumulator: [Stream Manager — `pipeStreamLoop`](./stream-manager.md#execution-loop--runexecutionloop--pipestreamloop)
- IPC envelope: [IPC Transport](./ipc-transport.md)
- Topic status / approval-anchor surfacing: [Tool Approval](./tool-approval.md)
- AI SDK upstream: [`readUIMessageStream` reference](https://ai-sdk.dev/docs/reference/ai-sdk-ui/read-ui-message-stream)

View File

@@ -143,7 +143,7 @@ Empty / `undefined` / `null` values are dropped here — the server applies its
own default; no client-side defaults are invented. The `'auto'` sentinel is
**not** dropped at this stage: it's carried through and resolved to "omit the
field" one stage later by the emitters (e.g. `toDashScopeSize` /
`AiProvider.resolveImageSize`).
`resolveSizeParameter` in `dashscopeTransport.ts`).
### 2. Transport routing hint (`paintingPipeline`)
@@ -154,7 +154,7 @@ field" one stage later by the emitters (e.g. `toDashScopeSize` /
### 3. providerOptions emitters (`buildImageProviderOptions`)
[`src/renderer/aiCore/utils/imageOptions.ts`](../../../src/renderer/aiCore/utils/imageOptions.ts) is a **table of per-provider emitters** that map canonical params to each vendor's real wire field names and bag key:
[`src/main/ai/utils/imageOptions.ts`](../../../src/main/ai/utils/imageOptions.ts) is a **table of per-provider emitters** that map canonical params to each vendor's real wire field names and bag key:
```
EMITTERS: Record<providerId, Emitter> // unlisted ids → diffusion fallback
@@ -170,10 +170,10 @@ nesting, enum casing. Nowhere else.
### 4. The model itself
[`AiProvider.modernGeneratePaintingImage`](../../../src/renderer/aiCore/AiProvider.ts) hands `aiSdkParams` + `providerOptions` to the resolved image model. The model is one of two kinds, decided by the provider factory:
[`AiService.generateImage`](../../../src/main/ai/AiService.ts) (reached via the `Ai_GenerateImage` IPC) hands `aiSdkParams` + `providerOptions` to the resolved image model. The model is one of two kinds, decided by the provider factory:
- **Native AI SDK image model** — `OpenAIImageModel`, `@ai-sdk/google` `.image()`, `OpenAICompatibleImageModel`. Spreads `providerOptions[key]` into the request body.
- **Custom `ImageGenerationTransport`** — for async submit→poll vendors or non-OpenAI wire shapes (DashScope, PPIO, DMXAPI's Doubao/Wan/async-Qwen families). See [`src/renderer/aiCore/provider/custom/imageGenerationModel.ts`](../../../src/renderer/aiCore/provider/custom/imageGenerationModel.ts); each vendor's transport lives beside its provider in a per-vendor folder (e.g. [`dmxapi/dmxapiTransport.ts`](../../../src/renderer/aiCore/provider/custom/dmxapi/dmxapiTransport.ts)), with shared helpers in [`transportUtils.ts`](../../../src/renderer/aiCore/provider/custom/transportUtils.ts). Multi-backend gateways (DMXAPI) dispatch by a `{match, family}` table on the model id — see [`dmxapi/dmxapiProvider.ts`](../../../src/renderer/aiCore/provider/custom/dmxapi/dmxapiProvider.ts).
- **Custom `ImageGenerationTransport`** — for async submit→poll vendors or non-OpenAI wire shapes (DashScope, PPIO, DMXAPI's Doubao/Wan/async-Qwen families). See [`src/main/ai/provider/custom/imageGenerationModel.ts`](../../../src/main/ai/provider/custom/imageGenerationModel.ts); each vendor's transport lives beside its provider in a per-vendor folder (e.g. [`dmxapi/dmxapiTransport.ts`](../../../src/main/ai/provider/custom/dmxapi/dmxapiTransport.ts)), with shared helpers in [`transportUtils.ts`](../../../src/main/ai/provider/custom/transportUtils.ts). Multi-backend gateways (DMXAPI) dispatch by a `{match, family}` table on the model id — see [`dmxapi/dmxapiProvider.ts`](../../../src/main/ai/provider/custom/dmxapi/dmxapiProvider.ts).
---
@@ -213,8 +213,8 @@ nesting, enum casing. Nowhere else.
| Default population on switch | `src/renderer/pages/paintings/utils/computeModelFieldReset.ts` |
| Param partition | `src/renderer/pages/paintings/model/canonicalGenerate.ts` |
| Transport hint + requirePrompt | `src/renderer/pages/paintings/model/paintingPipeline.ts` |
| providerOptions emitters | `src/renderer/aiCore/utils/imageOptions.ts` |
| Custom transport wrapper | `src/renderer/aiCore/provider/custom/imageGenerationModel.ts` |
| Vendor provider + transport | `src/renderer/aiCore/provider/custom/<vendor>/{<vendor>Provider,<vendor>Transport}.ts` |
| Shared transport helpers | `src/renderer/aiCore/provider/custom/transportUtils.ts` |
| providerOptions emitters | `src/main/ai/utils/imageOptions.ts` |
| Custom transport wrapper | `src/main/ai/provider/custom/imageGenerationModel.ts` |
| Vendor provider + transport | `src/main/ai/provider/custom/<vendor>/{<vendor>Provider,<vendor>Transport}.ts` |
| Shared transport helpers | `src/main/ai/provider/custom/transportUtils.ts` |
```

View File

@@ -0,0 +1,98 @@
# IPC Transport
## What it is
`IpcChatTransport`
(`src/renderer/transport/IpcChatTransport.ts`) implements AI SDK's
`ChatTransport<CherryUIMessage>` over Electron IPC. The renderer feeds
it into `useChat({ id: topicId, transport: ... })`. The `ChatTransport`
interface has only two methods — `sendMessages` / `reconnectToStream`;
the transport relays each over `window.api.ai.stream*` to Main's
`AiStreamManager`. `cancel` is **not** a transport method: it is the
`cancel` callback of the `ReadableStream` that `sendMessages` returns
(AI SDK invokes it on unmount/disposal), and abort is driven by the
request's `abortSignal`.
```
useChat({ id: topicId, transport: new IpcChatTransport(defaultBody) })
│ transport methods
├─ sendMessages → window.api.ai.streamOpen (Ai_Stream_Open)
├─ reconnectToStream → window.api.ai.streamAttach (Ai_Stream_Attach)
│ returned-stream / signal callbacks
├─ stream cancel() → window.api.ai.streamDetach (Ai_Stream_Detach)
└─ request abort signal → window.api.ai.streamAbort (Ai_Stream_Abort)
```
**Detach ≠ abort.** `cancel()` (e.g. unmount/disposal) calls `streamDetach`:
it drops *this* subscriber while Main keeps generating and persists the
result. Stopping generation is a separate path — the request's `abortSignal`
firing calls `streamAbort`. Conflating the two would resurrect the v1
"unmount → cancel → upstream abort → lost reply" bug class.
Per-topic chunks arrive via `onStreamChunk` listeners filtered by
`topicId`.
## Triggers
`sendMessages` distinguishes two triggers:
| Trigger | What it does |
|---|---|
| `submit-message` | Includes `userMessageParts` (the latest message) so Main persists it |
| `regenerate-message` | Sends `parentAnchorId` only; Main re-runs from the existing parent |
Cherry's transport never derives `continue-conversation` from
message-state introspection. Approval-driven resumption goes through the
explicit `Ai_ToolApproval_Respond` IPC handled by
[`useToolApprovalBridge`](./tool-approval.md).
## Dispatch coordinator
`streamDispatchCoordinator` (`src/renderer/transport/streamDispatchCoordinator.ts`)
sits between the transport and the IPC call so the `Ai_Stream_Open` ack
(`userMessageId`, placeholder ids, executionIds) is observable to callers
that need to join optimistic UI bubbles, rather than being thrown away by
AI SDK's transport interface.
It does **not** serialize sends — there is no single-in-flight guard in the
coordinator. Concurrency for a topic is arbitrated on the Main side: a chat
resubmit to a live topic is aborted-and-restarted by `dispatch`
(`AiStreamManager.abortAndAwait`), while an agent-session follow-up attaches
to the running stream.
## Per-execution demux
The chunk stream from Main is keyed by `(topicId, executionId)`.
`TopicStreamSubscription`
(`src/renderer/transport/TopicStreamSubscription.ts`) owns the
topic-level `streamAttach` / `streamDetach` with ref-counted lifecycle
and demuxes chunks into per-execution branch `ReadableStream`s, so
multi-model parallel responses render as separate AI SDK messages on
the same topic. `useExecutionOverlay` consumes each branch through
`readUIMessageStream` — the same accumulator Main runs in
`pipeStreamLoop`, so the renderer overlay and the persisted message
are structurally identical.
See [Execution Overlay](./execution-overlay.md) for the merge-function
symmetry, seed rule, cancellation layering, and lifecycle.
## Topic-level subscription
`useTopicStreamStatus(topicId)` reads
`topic.stream.statuses.<topicId>` from the shared cache. The cache is
the cross-window source of truth for:
- `pending` / `streaming` / `awaiting-approval` / `done` / `error` / `aborted`
- broadcast-completion anchor ids
`classifyTurn(status)` decodes the status into the `TurnStateFlags`
predicates the UI consumes (`isStreamLive`, `isTurnActive`,
`isAwaitingApproval`, `isTerminal`).
## Where to read more
- Code: `src/renderer/transport/`
- Hook glue: `src/renderer/hooks/useChatWithHistory.ts`
- Per-execution overlay (renderer assembler): [Execution Overlay](./execution-overlay.md)
- Approval bridge: [Tool Approval](./tool-approval.md)
- Main side: [Stream Manager](./stream-manager.md)

View File

@@ -0,0 +1,120 @@
# Observability
The `src/main/ai/observability/` subsystem: OTel tracing, the local span
projection (trace viewer), and the sink registry. "Trace / telemetry" is the
user-facing surface; this doc covers the whole subsystem.
## What's instrumented
Every AI SDK call run through Cherry produces an OpenTelemetry span
tree:
```
chat.turn (root, created by context provider)
├── ai.streamText (AI SDK auto)
│ ├── ai.streamText.doStream (AI SDK auto)
│ ├── ai.toolCall (per tool invocation) (AI SDK auto)
│ └── ai.streamText.<step> (AI SDK auto)
└── attributes: topicId, modelName, … (set by AiTurnTrace / AdapterTracer)
```
AI SDK's `experimental_telemetry` produces the inner spans; Cherry owns
the root span through `AiTurnTrace` so it lands in the same observability
path without going through the AI SDK adapter.
The main-process observability boundary is `src/main/ai/observability`:
- `core/` creates Cherry-owned turn roots and common `cs.*` attributes.
- `adapters/aiSdk/` interprets AI SDK child spans.
- `adapters/claudeCode/` interprets Claude Code OTLP spans and logs.
- `cache/` keeps the trace-window projection and JSONL-compatible history.
- `sinks/` defines the extension point for local and future external export.
## Local history flush
`Message.traceId` is persisted with the assistant message row, but the span
tree is first collected in the main-process `SpanCacheService` memory store.
The durable history file is written by the stream terminal path, not by the
trace UI:
- `PersistentChatContextProvider` attaches a `TraceFlushListener` to normal
chat turns.
- `AgentSessionRuntimeService` attaches the same listener to
`agent-session:${sessionId}` turns, including queued follow-up turns.
- On the topic-level terminal event (`done`, `paused`, or `error`),
`TraceFlushListener` calls `SpanCacheService.saveSpans(topicId)`.
- Flush errors are logged as warnings and do not affect message completion.
The trace pane and trace window are viewers only. They read spans through
`SpanCacheService.getSpans(...)`, which falls back to the JSONL history file
when the in-memory store has already been cleared.
## AdapterTracer
`src/main/ai/observability/adapters/aiSdk/adapterTracer.ts` wraps the OTel `Tracer` returned
by the global provider. On every `startSpan` / `startActiveSpan` it:
1. Patches `span.end()` to also call `AiSdkSpanAdapter.convertToSpanEntity(...)`
and hand the result to the observability sink registry.
2. Stamps `trace.topicId` and `trace.modelName` so spans are queryable
per topic in the dev-tools UI.
`AdapterTracer` is intentionally only for AI SDK child spans:
- `buildTelemetry` (`runtime/aiSdk/params/buildTelemetry.ts`) — passed to AI
SDK as `experimental_telemetry.tracer`. Captures every AI SDK auto-span.
Returns `undefined` (no telemetry, no tracer) when there is no `topicId`
or developer mode is off — see below.
## AiSdkSpanAdapter
`src/main/ai/observability/adapters/aiSdk/aiSdkSpanAdapter.ts` converts an OTel span into the
shape the dev-tools UI consumes:
- Reads span name, attributes, events, status, links.
- Recovers AI SDK's hierarchical attribute conventions:
`ai.xxx` is a level, `ai.xxx.yyy` is a sub-level under it.
- Normalises usage attributes across the base and LLM spans: input from
`ai.usage.promptTokens` / `gen_ai.usage.input_tokens`, output from
`ai.usage.completionTokens` / `gen_ai.usage.output_tokens`. (There is no
reasoning-token extraction.)
Claude Code Agent SDK spans do not go through `AiSdkSpanAdapter`; they are
converted by `src/main/ai/observability/adapters/claudeCode/ClaudeCodeOtlpAdapter.ts`.
## Sensitive data capture & redaction
> Cross-referenced from `ClaudeCodeTraceBridgeService.prepareTrace`.
The Claude Code OTLP bridge runs **only when developer mode is enabled**. When
it does, it intentionally turns on verbose Claude Code telemetry:
- `OTEL_LOG_USER_PROMPTS` — user prompt text
- `OTEL_LOG_TOOL_DETAILS` / `OTEL_LOG_TOOL_CONTENT` — tool calls and their content
- `OTEL_LOG_RAW_API_BODIES` — raw API request/response bodies
These payloads land in span attributes that `SpanCacheService` persists as
**plaintext JSONL trace files on disk**, so a trace can contain secrets
(authorization headers, API keys embedded in raw bodies) alongside the prompt
and tool content.
**Redaction is deliberately not done.** Stripping secrets would mean parsing
arbitrary OTLP attribute structures across the ingest path and would risk
dropping legitimate trace data. The accepted tradeoff is that capture is
**local-only and developer-gated**; turning that into a redaction/threat-model
guarantee is a deferred decision. Treat exported trace files as sensitive.
## Where it shows up in the UI
Dev mode only. The dev-tools span viewer reads from the local observability
projection (`SpanCacheService`) and renders the per-topic tree. Outside
developer mode `buildTelemetry` returns `undefined`, so **no tracer is
attached at all** and the AI SDK emits no spans — there is nothing to
project. Viewer mount, tab close, and window close are not part of the
trace persistence lifecycle.
## Where to read more
- Code: `src/main/ai/observability/`
- Span projection: `src/main/ai/observability/cache/SpanCacheService.ts`
- AI SDK telemetry docs: https://ai-sdk.dev/docs/reference/ai-sdk-core/telemetry

View File

@@ -0,0 +1,123 @@
# Params Pipeline
## What it is
`buildAgentParams` (`src/main/ai/runtime/aiSdk/params/buildAgentParams.ts`) is the
single function that turns a (request, provider, model, assistant) tuple
into everything `Agent.stream()` needs:
```ts
interface BuiltAgentParams {
sdkConfig: SdkConfig // providerId + providerSettings + modelId
tools: ToolSet | undefined // active + meta-tools after defer
plugins: AiPlugin<any, any>[] // model-adapter plugins (ordered)
system: string | undefined // assembled system prompt
options: AgentOptions // headers, providerOptions, stopWhen, repair, telemetry
hookParts: ReadonlyArray<Partial<AgentLoopHooks>>
}
```
It is a pure async function — no class, no shared state. Callers (chat,
agent session, translate, prompt-only) shape their own `AiBaseRequest` and
hand it in.
## RequestFeature
The composition unit is `RequestFeature`
(`src/main/ai/runtime/aiSdk/params/feature.ts`):
```ts
interface RequestFeature {
readonly name: string
applies?(scope: RequestScope): boolean
contributeModelAdapters?(scope: RequestScope): AiPlugin<any, any>[]
contributeHooks?(scope: RequestScope): Partial<AgentLoopHooks>
}
```
`collectFromFeatures(scope, features)` calls each feature's `applies`
(default `true`), then collects its model adapters and hook parts. The
result feeds `plugins` and `hookParts` in `BuiltAgentParams`.
Order matters because AI SDK plugin order is significant. The list lives
in `src/main/ai/runtime/aiSdk/params/features/index.ts`:
```ts
export const INTERNAL_FEATURES = [
devtoolsFeature,
gatewayUsageNormalizeFeature,
modelParamsFeature,
pdfCompatibilityFeature, // must run before anthropicCacheFeature
reasoningExtractionFeature, // must run before simulateStreamingFeature
simulateStreamingFeature,
anthropicCacheFeature,
anthropicHeadersFeature,
openrouterReasoningFeature,
noThinkFeature,
qwenThinkingFeature,
skipGeminiThoughtSignatureFeature,
providerWebSearchFeature,
providerUrlContextFeature
]
```
Callers can append per-request `extraFeatures`; those run after the
internal set. (AiService's analytics is *not* one of these — it is injected
separately as a `hookParts` entry, not a `RequestFeature`.)
## RequestScope
All features receive the same read-only scope object built in
`buildAgentParams`:
```ts
interface RequestScope extends ToolApplyScope {
request, signal, registry, assistant, model, provider,
capabilities, // resolveCapabilities — see capabilities.ts
sdkConfig, endpointType, aiSdkProviderId,
requestContext, // RequestContext for tool execute()
mcpToolIds
}
```
Features must never mutate the scope. The scope IS shared across all
features for a single request, so any added field becomes part of the
contract — keep it minimal.
## Pipeline order
```
buildAgentParams(input)
├─ resolveSdkConfig → providerToAiSdkConfig + modelId
├─ canModelConsumeTools? → resolveTools (registry sync + defer)
│ └─ syncMcpToolsToRegistry (only servers owning a selected tool)
│ └─ registry.selectActive (per-entry applies)
│ └─ applyDeferExposition (defer pool → meta-tools + system section)
├─ resolveCapabilities → enableWebSearch / enableUrlContext / …
├─ resolveEffectiveEndpoint → endpointType (model > provider default)
├─ resolveAiSdkProviderId → adapter-family routing (see adapter-family.md)
├─ collectFromFeatures → plugins + hookParts
├─ assembleSystemPrompt → assistant prompt + deferred-tools header
└─ buildAgentOptions → providerOptions + customParameters split
+ headers + stopWhen + repair + telemetry
```
## customParameters split
User-supplied `assistant.customParameters` may contain AI-SDK standard
params (temperature, topP, etc.) **and** provider-scoped overrides.
`extractAiSdkStandardParams` separates them; standard params land on the
top-level `AgentOptions` (AI SDK forwards them to the model), provider
params merge into `providerOptions[aiSdkProviderId]` (after a
`mergeCustomProviderParameters` pass that respects existing capability
options).
## Where to read more
- Code: `src/main/ai/runtime/aiSdk/params/`
- Tests: `src/main/ai/runtime/aiSdk/params/__tests__/` (`assembleSystemPrompt`,
`collectFromFeatures`, `composeHooks`),
`src/main/ai/runtime/aiSdk/params/features/__tests__/`
- Tool defer: [Tool Registry](./tool-registry.md)
- Endpoint routing: [Adapter Family](./adapter-family.md)
- Hooks: [Agent Loop](./agent-loop.md)

View File

@@ -0,0 +1,137 @@
# Provider Resolution
## The problem this solves
A request needs to know which `@ai-sdk/*` package to import, with which
settings, hitting which URL. Three pieces of state determine that:
| Field | Lives on | Example |
|---|---|---|
| `provider.id` | `Provider` row | `minimax`, `silicon`, `my-relay` |
| `endpointType` | `model.endpointTypes[0]` or `provider.defaultChatEndpoint` | `openai-chat-completions`, `anthropic-messages` |
| `adapterFamily` | `provider.endpointConfigs[endpointType].adapterFamily` | `openai-compatible`, `anthropic`, `azure-responses` |
`adapterFamily` is the actual SDK selector. `provider.id` is the user-facing
identity. `endpointType` is the protocol family. The mapping is written
once at provider-creation time; runtime resolution is read-only.
See [Adapter Family](./adapter-family.md) for the full design.
## Resolver
`src/main/ai/provider/endpoint.ts` exposes three pure helpers:
```ts
resolveEffectiveEndpoint(provider, model): { endpointType, baseUrl }
resolveProviderVariant(baseProviderId, endpointType): AppProviderId
resolveAiSdkProviderId(provider, endpointType): AppProviderId
```
`resolveAiSdkProviderId` is the runtime hot-path entry. It reads
`provider.endpointConfigs[endpointType].adapterFamily`, applies the
variant suffix if the endpoint type has one, falls back to
`openai-compatible` when no family is set.
```ts
// Full resolver — 6 lines
export function resolveAiSdkProviderId(provider, endpointType) {
const adapterFamily = endpointType
? provider.endpointConfigs?.[endpointType]?.adapterFamily
: undefined
if (adapterFamily && adapterFamily in appProviderIds) {
return resolveProviderVariant(appProviderIds[adapterFamily], endpointType)
}
return appProviderIds['openai-compatible']
}
```
## Variants
Some bases expose variant ids (a different endpoint on the same base).
`resolveProviderVariant` knows two suffix rules and applies one only when
the resulting `<base>-<suffix>` id is actually registered — otherwise it
returns the base unchanged:
| Endpoint type | Suffix tried |
|---|---|
| `openai-chat-completions`, `ollama-chat` | `-chat` |
| `openai-responses` | `-responses` |
Variants registered today (declared in each provider extension's
`variants` array, `packages/aiCore/src/core/providers/core/initialization.ts`):
| Base | Variant id(s) |
|---|---|
| `openai` | `openai-chat` (the base `openai` is itself the Responses API) |
| `azure` | `azure-responses`, `azure-anthropic` |
| `xai` | `xai-responses` |
| `cherryin` | `cherryin-chat` |
`ollama` has no registered variant, so an `ollama-chat` endpoint resolves
to the base `ollama`. Likewise there is **no `openai-responses` variant**
(the base already is). `azure-anthropic` is not reached through the suffix
rule — it is selected inside `buildAzureConfig` when the model is a Claude
model (see below). `resolveProviderVariant(baseId, endpointType)` is
idempotent when the base id is already a variant.
## Provider config
`providerToAiSdkConfig(provider, model)`
(`src/main/ai/provider/config.ts`) returns
`{ providerId: AppProviderId, providerSettings: AppProviderSettingsMap[id] }`.
It calls `resolveAiSdkProviderId` internally, then dispatches through an
ordered `{ match, build }` table to build the provider-specific settings
object (apiKey, baseURL, organization, headers, ...). There is **no
"gateway" branch**.
The builder table (`config.ts`, first match wins):
| Match | Builder | Notes |
|---|---|---|
| `id === copilot` | `buildCopilotConfig` | async — fetches a Copilot token |
| `id === 'cherryai'` | `buildCherryAIConfig` | |
| `isOllamaProvider` | `buildOllamaConfig` | |
| `isAzureOpenAIProvider` | `buildAzureConfig` | returns `azure` / `azure-responses` / `azure-anthropic` (Claude on Azure) |
| `id === 'bedrock'` | `buildBedrockConfig` | |
| `id === 'google-vertex'` | `buildVertexConfig` | returns `google-vertex` or `google-vertex-anthropic` for Claude; leaves `baseURL` undefined when no host is configured so the SDK derives the aiplatform host |
| `provider.id === 'cherryin'` | `buildCherryinConfig` | matches the **provider id**, not the resolved variant — the default chat endpoint resolves to `cherryin-chat`, so an `id === 'cherryin'` check never fires; async — resolves relay base URLs |
| `id === 'newapi'` | `buildNewApiConfig` | |
| `id === 'aihubmix'` | `buildAiHubMixConfig` | |
| _(no match)_ | `buildGenericProviderConfig` / `buildOpenAICompatibleConfig` | generic fallback |
Several builders are `async` (Copilot token, CherryIN relay URLs), which is
why `providerToAiSdkConfig` returns a promise.
## Custom providers
`src/main/ai/provider/custom/`:
- **aihubmix** — multi-vendor relay. `provider.id='aihubmix'` but each
model carries `model.provider='aihubmix.<vendor>'`; the registry's
aggregator fallback uses the suffix to pick the right `toolFactory`.
- **newapi** — same shape, different relay.
Both register through `ProviderExtension.create(...)` with their own
`AppProviderSettings` shape.
## Provider extensions
`src/main/ai/provider/extensions/index.ts` registers every
`@ai-sdk/*` package Cherry uses with `ProviderExtension.create`. Each
extension declares:
- `name` (the `AppProviderId` for the base)
- `aliases` (alternate ids that normalize to `name`)
- `variants` (suffix entries — see above)
- `create` (the SDK's factory)
- `toolFactories` (per-capability factory functions for `webSearch` /
`urlContext` etc.)
- `supportsImageGeneration` (boolean flag)
## Where to read more
- Code: `src/main/ai/provider/`
- Tests: `provider/__tests__/endpoint.test.ts` (54 cases)
- Migration of legacy provider rows: `src/main/data/migration/v2/migrators/mappings/ProviderModelMappings.ts`
- Catalog (new installs): `packages/provider-registry/data/providers.json`
- Design: [Adapter Family](./adapter-family.md)

View File

@@ -0,0 +1,838 @@
# AiStreamManager
## What it is
`AiStreamManager` is the Main-process **active-stream registry** and the
broker for every stream event. It owns the full life cycle of an AI
streaming reply — from `sendMessages` until the assistant turn finishes
persisting — including multicast fan-out, reconnect, abort, abort-and-restart
steering, and persistence triggering.
The renderer no longer holds a direct reference to the stream. Closing a
window does not abort the stream; it continues on Main and persists
normally. When the user returns, `attach` re-subscribes and the
manager replays any chunks that landed in between.
**Key: `topicId`.** A topic has at most one active stream at a time;
"streaming" is one phase of a topic, and every subscriber on a topic is
equal — there is no "owner" window.
## Why it exists
v1 ran the stream lifecycle, fan-out, and persistence on the **renderer**,
which produced three structural bug classes:
- **Window-bound lifecycle** — unmounting the chat (topic switch, window
close, route change) cancelled the transport stream, which aborted the
upstream request and dropped the in-flight reply.
- **No reconnect** — `reconnectToStream()` always returned `null`, so
returning to a topic lost live progress until the row hit the DB.
- **Renderer-owned persistence** — the DB write lived in the renderer, so a
crash/close between stream-end and commit lost the reply.
**Goal:** move stream lifecycle, multicast fan-out, and persistence to Main;
the renderer's only job is rendering chunks. The sections below are the
reference for that Main-side design.
## Architecture
```
┌──────────────── Renderer ────────────────────────────────────┐
│ │
│ useChat({ id: topicId, transport: IpcChatTransport }) │
│ ├─ sendMessages → Ai_Stream_Open (topicId, trigger, userMessageParts, …)
│ ├─ reconnectToStream → Ai_Stream_Attach ({ topicId }) │
│ └─ abort signal → Ai_Stream_Abort ({ topicId }) │
│ │
│ History: useQuery('/topics/:id/messages') │
│ Topic-level state: useTopicStreamStatus → shared cache │
└──────────────────────────────────────────────────────────────┘
↕ IPC (all keyed by topicId)
┌──────────────── Main ────────────────────────────────────────┐
│ │
│ dispatchStreamRequest(manager, subscriber, req) │
│ │ pick first ChatContextProvider whose canHandle matches │
│ │ provider.prepareDispatch(subscriber, req, ctx) │
│ └ manager.send(prepared) │
│ │
│ AiStreamManager │
│ ┌────────────────────────────────────────────────────────┐ │
│ │ activeStreams: Map<topicId, ActiveStream> │ │
│ │ listeners: Map<listenerId, StreamListener> │ │
│ │ executions: Map<modelId, StreamExecution> │ │
│ │ ├─ abortController / status │ │
│ │ └─ buffer (ring) + droppedChunks │ │
│ │ lifecycle: StreamLifecycle (chat or prompt) │ │
│ └────────────────────────────────────────────────────────┘ │
│ ↓ createAndLaunchExecution → runExecutionLoop │
│ AiService.streamText(request) → ReadableStream<UIMessageChunk> │
│ ↓ pipeStreamLoop (tees: broadcast + readUIMessageStream) │
│ │
│ terminal → dispatchToListeners → every StreamListener: │
│ WebContentsListener → wc.send(Ai_StreamDone) │
│ PersistenceListener → PersistenceBackend.persistAssistant
│ • MessageServiceBackend (SQLite tree) │
│ • TemporaryChatBackend (in-memory) │
│ • AgentSessionMessageBackend (agent-session DB) │
│ • TranslationBackend (translate row) │
│ TraceFlushListener → SpanCacheService.saveSpans(topicId)
│ ChannelAdapterListener → adapter.onStreamComplete │
│ SseListener → res.write('[DONE]') │
└──────────────────────────────────────────────────────────────┘
```
## Pub/sub model
The manager is a broker: one set of producers feeds it, one set of
consumers subscribes. The system uses the observer pattern, and splits
dispatch into two semantically distinct channels based on **payload
volume × audience width**.
### Producers
| Producer | Events | Source |
|---|---|---|
| `StreamExecution` loop | `UIMessageChunk` (per-chunk delta) | `AiService.streamText`'s `ReadableStream` |
| `AiStreamManager` (state machine) | topic-level status transitions | `send()``pending`, first chunk → `streaming`, three terminal handlers → `done` / `error` / `aborted`, `awaiting-approval` on `tool-approval-request` |
### Consumers
| Consumer | Events | Subscription |
|---|---|---|
| `WebContentsListener` | chunk + terminal | explicit `attach``ActiveStream.listeners` |
| `PersistenceListener` | terminal | built by the provider and added in `send()` |
| `TraceFlushListener` | terminal | built by chat / agent-session turn owners and added in `send()` |
| `ChannelAdapterListener` / `SseListener` | chunk + terminal | caller injects into `send()`'s `listeners` |
| UI indirect consumers (sidebar indicators, …) | topic status | `useSharedCache('topic.stream.statuses.${topicId}')` |
### Two channels: targeted listener dispatch vs SharedCache mirror
| | Targeted listener dispatch | SharedCache mirror |
|---|---|---|
| Transport | `Ai_StreamChunk` / `Ai_StreamDone` / `Ai_StreamError` | `cacheService.setShared('topic.stream.statuses.${topicId}', …)` → built-in `Cache_Sync` broadcast |
| Main-side registry | `ActiveStream.listeners: Map<listenerId, StreamListener>` | none — uses the generic `CacheService` infra |
| Subscriber API | `attach` to register, explicit `detach` | `useSharedCache('topic.stream.statuses.${topicId}')` by topicId |
| Per-event size | tens of bytes to KBs (10s/s) | tens of bytes (≤ 5 transitions per stream) |
| Audience | narrow (one window per listener typically) | wide (every sidebar / indicator across all windows) |
| Cost of irrelevant pushes | high (bandwidth + deserialization) | negligible |
### Channel selection rule
Choose by **consumer / producer fanout**:
- chunk stream: one execution produces it, only the window rendering
that topic needs it → **targeted listener dispatch**, no irrelevant
pushes.
- topic status: one transition, every UI mirror wants it → **SharedCache**,
reuse generic cache sync, no bespoke IPC.
### Rules that follow from the channel split
- **`Ai_Stream_Attach` is required.** The listener channel requires
explicit consumer registration; `attach` is the entry point and also
returns a compact replay to fill the "before I subscribed" gap.
- **Bootstrap needs no extra IPC.** A new window pulls all shared cache
entries via `Cache_GetAllShared` on mount; every
`topic.stream.statuses.${topicId}` entry comes through without a
bespoke snapshot IPC.
- **Snapshot vs delta race.** Handled by the shared cache sync layer
itself — initial pull and `Cache_Sync` delta share the Main-side
source of truth; late arrivals overwrite stale state.
- **Grace-period cleanup does NOT clear the SharedCache entry.** Terminal
values (`done` / `aborted` / `error`) stay so renderer-side consumers
(`useTopicDbRefreshOnTerminal`, `useChatWithHistory`, awaiting-approval
indicators, sidebar badges) can observe them. The fulfilled-badge gate
is a read-receipt: the entry's `lastCompletedAt` (bumped only on
`done`) compared against `topic.stream.last_seen_completion.${topicId}`
(cross-window shared cache, written when the user acknowledges).
Memory tier — both reset on app restart.
- **`PersistenceListener` placement.** Terminal-only consumer — doesn't
need chunk bandwidth → not added via `attach`; the provider includes
it in the `listeners` array passed to `send()`.
- **`TraceFlushListener` placement.** Terminal-only consumer that flushes
`SpanCacheService.saveSpans(topicId)` after a chat / agent turn completes.
It belongs with the turn owner (`PersistentChatContextProvider` or
`AgentSessionRuntimeService`), not inside `AiStreamManager` and not in
trace viewer UI.
## File layout
```
src/main/ai/
├── AiService.ts lifecycle service: streamText + non-streaming IPC gateway
└── runtime/aiSdk/
└── Agent.ts single-pass `Agent.stream` wrapper (see Agent Loop)
src/main/ai/streamManager/
├── AiStreamManager.ts the registry + execution loop + multicast
├── pipeStreamLoop.ts shared chunk-pipe primitive (used by AiStreamManager.runExecutionLoop)
├── buildCompactReplay.ts attach-time chunk compaction (merge text-delta / reasoning-delta)
├── types.ts ActiveStream / StreamExecution / StreamListener / timings
├── index.ts barrel
├── context/ per-topicId namespace dispatch
│ ├── ChatContextProvider.ts interface + PreparedDispatch
│ ├── dispatch.ts single manager.send entry; MainContinueConversationRequest
│ ├── PersistentChatContextProvider.ts uuid topics → SQLite tree
│ ├── TemporaryChatContextProvider.ts in-memory (TemporaryChatService)
│ ├── AgentChatContextProvider.ts `agent-session:` → agents DB
│ └── modelResolution.ts resolveModels / siblingsGroupId
├── lifecycle/ strategy: chat vs ad-hoc prompt
│ ├── StreamLifecycle.ts interface
│ ├── ChatStreamLifecycle.ts cross-window broadcast + 30 s grace period + attach
│ ├── PromptStreamLifecycle.ts silent, no attach, immediate eviction
│ └── index.ts barrel
├── listeners/
│ ├── WebContentsListener.ts chunks → renderer windows
│ ├── PersistenceListener.ts observer protocol + delegates to PersistenceBackend
│ ├── TraceFlushListener.ts terminal trace-cache flush to local history
│ ├── ChannelAdapterListener.ts text → Discord / Slack / Feishu
│ └── SseListener.ts UIMessageChunk → SSE response (API server)
└── persistence/
├── PersistenceBackend.ts strategy interface + statsFromTerminal projection
└── backends/
├── MessageServiceBackend.ts finalize a SQLite pending placeholder
├── TemporaryChatBackend.ts append to in-memory topic
└── TranslationBackend.ts attach `data-translation` part to a target message
```
Agent session persistence is implemented under `agentSession/persistence`
because it writes the agent-session domain tables.
## StreamListener interface
The manager treats every consumer through one interface; it dispatches
each event by calling these methods uniformly:
```typescript
interface StreamListener {
readonly id: string
onChunk(chunk: UIMessageChunk, sourceModelId?: UniqueModelId): void
onDone(result: StreamDoneResult): void | Promise<void> // { finalMessage?, status: 'success', ... }
onPaused(result: StreamPausedResult): void | Promise<void> // { finalMessage?, status: 'paused', ... }
onError(result: StreamErrorResult): void | Promise<void> // { finalMessage?, error, status: 'error', ... }
isAlive(): boolean
}
```
All three terminal shapes share the same `finalMessage?` field — the
`UIMessage` accumulated by `readUIMessageStream` in the execution loop.
Whether the stream ended naturally, was aborted, or errored, it's the
same variable, only the stop point differs. Earlier designs called the
error-path partial a `partialMessage`; this turned out to be just a
`finalMessage` that ended early. Unifying the shape means
`PersistenceBackend` needs one `persistAssistant` method, not separate
write paths per status.
### Built-in implementations
| Listener | Role | id | isAlive |
|---|---|---|---|
| **WebContentsListener** | chunks → renderer window | `wc:${wc.id}:${topicId}` | `!wc.isDestroyed()` |
| **PersistenceListener** | terminal write via strategy | `persistence:${backendKind}:${topicId}:${modelId ?? 'default'}` | always `true` |
| **TraceFlushListener** | terminal trace-cache flush | `persistence:trace:${topicId}` | always `true` |
| **ChannelAdapterListener** | text → IM platform | `channel:${channelId}:${chatId}` | `adapter.connected` |
| **SseListener** | API-server SSE passthrough | `sse:${uuid}` | `!res.writableEnded` |
### Unified liveness policy
`AiStreamManager.dispatchToListeners` is the single funnel for terminal
events (`onDone` / `onPaused` / `onError`). Per listener it:
- Calls `listener.isAlive()` before each broadcast — `false` removes the
listener from `stream.listeners` (cleans up dead consumers).
- Wraps each call in try/catch — one bad listener can't starve the rest.
- Logs by event name + listener id for easy triage.
`onChunk` keeps a synchronous contract (the execution loop can't `await`
a listener) so it inlines the loop instead of going through
`dispatchToListeners`, but the dead-listener cleanup is the same.
### PersistenceListener — strategy pattern
One listener + four backends:
```typescript
interface PersistenceBackend {
readonly kind: string // "sqlite" | "temp" | "agents-db" | "translation"
persistAssistant(input: {
finalMessage?: CherryUIMessage
status: 'success' | 'paused' | 'error'
modelId?: UniqueModelId
stats?: MessageStats
}): Promise<void>
afterPersist?(finalMessage: CherryUIMessage): Promise<void>
}
```
Backends expose **one** write method; the three statuses share its
shape. On the `error` branch, `PersistenceListener` folds the
`SerializedError` into a trailing `data-error` part on `finalMessage.parts`
and then calls `persistAssistant({ status: 'error' })`, so backends never
have to know how to encode an error into a UIMessage — they just write.
The listener owns the observer protocol: filter by `modelId`
(multi-model topics have one listener per execution), merge the error
part exactly once, swallow exceptions so they don't break downstream
dispatch, fire `afterPersist` only when `status === 'success'` and
`finalMessage` is present (best-effort). Adding a fifth storage path
(e.g. an outbox) is a 60-line backend, no listener boilerplate to copy.
## ActiveStream & StreamExecution
```typescript
interface ActiveStream {
topicId: string
executions: Map<UniqueModelId, StreamExecution> // 1 entry single-model, N multi-model
listeners: Map<string, StreamListener> // shared across executions
// 'pending' on creation; flips to 'streaming' on first chunk; derived
// from executions on terminal (done / aborted / error /
// awaiting-approval).
status: TopicStreamStatus
isMultiModel: boolean // fixed at create; tags onChunk's sourceModelId
lifecycle: StreamLifecycle // chat or prompt strategy
expiresAt?: number
cleanupTimer?: ReturnType<typeof setTimeout>
}
interface StreamExecution {
modelId: UniqueModelId
anchorMessageId?: string // placeholder id for submit/regen, anchor id for continue
abortController: AbortController
status: 'streaming' | 'done' | 'error' | 'aborted'
// Per-execution ring buffer for reconnect replay. Hitting
// `maxBufferChunks` drops the oldest entry and bumps `droppedChunks`.
// Independent buffers prevent a chatty model from evicting a slower
// model's replay (a shared buffer would).
buffer: StreamChunkPayload[]
droppedChunks: number
finalMessage?: CherryUIMessage
// Set the moment a `tool-approval-request` chunk arrives, cleared on
// response. Read by `resolveTerminalStatus` to surface
// `awaiting-approval` on the topic.
awaitingApproval?: boolean
error?: SerializedError
siblingsGroupId?: number
loopPromise: Promise<void> // awaited by onStop for graceful shutdown
// Transport-side timings owned by the execution loop — chunk-shape-agnostic.
// Semantic timings (firstTextAt / reasoning*) live on the listener
// that cares; see "Stats composition" below.
timings: TransportTimings
// OTel root span set as active context around runExecutionLoop so
// AI SDK spans become children. Created by the context provider.
rootSpan?: Span
}
interface TransportTimings {
readonly startedAt: number // execution loop entry
completedAt?: number // execution loop exit (both try and catch paths)
}
interface SemanticTimings {
firstTextAt?: number // first text-delta chunk (TTFT endpoint)
reasoningStartedAt?: number // first reasoning-* chunk
reasoningEndedAt?: number // first non-reasoning chunk after reasoning
}
```
Topic-level status is derived from executions, with `'pending'` as the
initial pre-first-chunk window:
- Created (`send()` returned) → `'pending'`
- Any execution emits its first chunk → `'streaming'`
- All terminal, all `done``'done'`
- All terminal, all `aborted``'aborted'`
- Has `error`, none `streaming``'error'`
- Any execution still has `awaitingApproval` true on a terminal topic → `'awaiting-approval'`
`pending → streaming` is a one-time transition (first chunk anywhere).
The terminal status is derived once when the last execution terminates.
### Stats composition — tokens + timings → MessageStats
**Ownership** (key invariant: manager does not peek at chunk payloads):
| Source field | Owner | Collected at |
|---|---|---|
| `TransportTimings.startedAt` | `AiStreamManager` | `createAndLaunchExecution` |
| `TransportTimings.completedAt` | `AiStreamManager` | `pipeStreamLoop`'s `broadcastCompletedAt` |
| `SemanticTimings.firstTextAt` | `PersistenceListener` | own `onChunk`, first `text-delta` |
| `SemanticTimings.reasoning*` | `PersistenceListener` | own `onChunk`, observing `reasoning-*` boundaries |
| Token metadata | `agentLoop` usage observer | `finish` chunk projects AI SDK `LanguageModelUsage``CherryUIMessageMetadata` |
The manager is chunk-shape-agnostic — multicast, reconnect, abort,
abort-and-restart steering, persistence-triggering, never "what is text /
what is reasoning". AI SDK chunk type changes (vNext renames) only touch
`PersistenceListener`; the manager stays stable.
**Final projection.** `statsFromTerminal(finalMessage, mergedTimings)`
is one function; the listener merges its `SemanticTimings` with
`result.timings` (transport) before calling it:
```typescript
// inside PersistenceListener
const mergedTimings = { ...result.timings, ...this.semanticTimings }
const stats = statsFromTerminal(finalMessage, mergedTimings)
await this.opts.backend.persistAssistant({ finalMessage, status, modelId, stats })
```
Projected `MessageStats` fields:
| Field | Source |
|---|---|
| `totalTokens / promptTokens / completionTokens / thoughtsTokens` | `finalMessage.metadata.*` |
| `timeFirstTokenMs` | `round(firstTextAt - startedAt)` |
| `timeCompletionMs` | `round(completedAt - startedAt)` |
| `timeThinkingMs` | **not projected** — wall-clock `reasoningEndedAt - reasoningStartedAt` can include interleaved tool exec; see the `TODO(message-stats-redesign)` note in `PersistenceBackend.ts` |
Backends never derive stats themselves; they just write `input.stats`.
One projection path, four backends, no duplication.
## Public API
```typescript
class AiStreamManager {
// Lifecycle container invokes with no args (DEFAULT_CONFIG); tests can
// override `gracePeriodMs`, `backgroundMode`, `maxBufferChunks`.
constructor(config?: Partial<AiStreamManagerConfig>)
readonly chatLifecycle: StreamLifecycle
// ── Single dispatch entry ─────────────────────────────────────────
// Live topic → inject (agent-session only: upsert listeners onto the
// running stream, models ignored). Otherwise → start (evict any
// grace-period stream, launch one execution per `models` entry). A live
// chat topic never reaches the inject branch — `dispatch` restarts it via
// `abortAndAwait` first. Multi-model is detected from `models.length > 1`.
send(input: SendInput): SendResult
// ── Ad-hoc prompt stream (translate / topic-naming / model probes)
// Bypasses the chat dispatcher; uses promptStreamLifecycle (silent, no
// attach, immediate eviction).
streamPrompt(input: {
streamId: string // doubles as topicId
uniqueModelId: UniqueModelId
prompt?: string
messages?: CherryUIMessage[]
listener: StreamListener | StreamListener[]
}): SendResult
// ── Subscription management ───────────────────────────────────────
attach(sender: WebContents, req: { topicId }): AiStreamAttachResponse
detach(sender: WebContents, req: { topicId }): void
addListener(topicId: string, listener: StreamListener): boolean
removeListener(topicId: string, listenerId: string): void
// ── Control ───────────────────────────────────────────────────────
abort(topicId: string, reason: string): void
// Abort a live turn and await its executions settling (partial persists as
// `paused`) before evicting — used by `dispatch` to restart a chat turn.
abortAndAwait(topicId: string, reason: string): Promise<void>
hasLiveStream(topicId: string): boolean
// ── Execution-loop callbacks (driven internally; public for tests) ─
onChunk(topicId, modelId, chunk): void
onExecutionDone(topicId, modelId): Promise<void>
onExecutionPaused(topicId, modelId): Promise<void>
onExecutionError(topicId, modelId, error): Promise<void>
// ── Inspection (read-only snapshot) ───────────────────────────────
inspect(topicId: string): TopicSnapshot | undefined
}
```
### `send` contract
```typescript
interface SendInput {
topicId: string
models: ReadonlyArray<{ modelId: UniqueModelId; request: AiStreamRequest; rootSpan?: Span }>
listeners: StreamListener[]
userMessage?: Message // persisted user row; not consumed by send() — callers' bookkeeping
siblingsGroupId?: number
lifecycle?: StreamLifecycle // omit → chatLifecycle; streamPrompt passes promptStreamLifecycle
}
interface SendResult {
mode: 'started' | 'injected'
executionIds: UniqueModelId[] // started → fresh ids; injected → already running
}
```
- **injected**: topic has a live stream (`pending` or `streaming`) →
`models` is ignored and `listeners` upsert by id; **no message is
injected**. Only agent-session topics reach this branch (the dispatcher
aborts+restarts a live chat topic before calling `send()`); the
agent-session follow-up was already enqueued on the session's
`pendingTurns` by its provider.
- **started**: topic is idle or grace-period (terminal) → any leftover
grace-period stream is evicted, a new `ActiveStream` is created with
`isMultiModel = models.length > 1`, one execution launched per model.
`isMultiModel` is not an input — it's derived from `models.length`.
### Execution loop — `runExecutionLoop` + `pipeStreamLoop`
Each execution runs an independent loop that bridges "the single
`ReadableStream` from AI SDK" to "what the manager has to do":
broadcast to listeners, buffer for reconnect, and accumulate a
persistable `finalMessage`.
**Step 1 — get the raw chunk stream.**
```typescript
const stream: ReadableStream<UIMessageChunk> = await aiService.streamText({
...request,
requestOptions: { ...request.requestOptions, signal }
})
```
`streamText` returns AI SDK's raw chunk stream. `signal` comes from
`StreamExecution.abortController`; `abort()` triggers it.
**Step 2 — wrap with `withIdleTimeout`.** Resets per chunk; on idle
timeout it aborts `exec.abortController`, which the upstream request is
already wired to.
**Step 3 — `pipeStreamLoop` tees the chunk stream.**
`pipeStreamLoop` is the shared chunk-pipe primitive (the one
`AiStreamManager.runExecutionLoop` uses). It `tee()`s the stream into two
independent branches:
| Branch | Consumer | Purpose |
|---|---|---|
| Broadcast | `onChunk(topicId, modelId, chunk)` per chunk | Buffer into `exec.buffer` (ring), fan out to every listener |
| Accumulator | `readUIMessageStream` | Each yielded snapshot is written to `exec.finalMessage`; at stream end it's the final message |
The accumulator reader is **not** cancelled directly on abort —
`Agent.stream` honours the same signal upstream and propagates `done`
through `tee()`, so the accumulator drains naturally. Cancelling the
accumulator reader directly would race AI SDK's internal
`controller.close()` and produce an `ERR_INVALID_STATE`
unhandledRejection.
**Step 4 — terminal dispatch.**
| Exit path | Handler | Behaviour |
|---|---|---|
| Normal end | `onExecutionDone` | `exec.status = 'done'`, finalMessage persisted as `success` |
| `signal.aborted` + `exec.status === 'aborted'` | `onExecutionPaused` | (Possibly partial) finalMessage persisted as `paused` |
| `streamErrorText` (in-stream `error` chunk) | `onExecutionError` | Error part folded into finalMessage, persisted as `error` |
| Pre-stream or broadcast throw | `onExecutionError` | Same — error part folded, persisted |
## Lifecycle strategy — chat vs prompt
The manager stays policy-free. Behaviour that differs between chat
streams and one-shot ad-hoc prompts (translate, topic-naming, model
probes) lives in `StreamLifecycle`:
```typescript
interface StreamLifecycle {
readonly name: string
onCreated(stream): void // freshly registered
onPromotedToStreaming(stream): void // first chunk
onTerminal(stream): void // every isTopicDone
canAttach(stream): boolean // gate for `attach`
cleanup(stream, evict: () => void): void // when to remove from activeStreams
}
```
| | `ChatStreamLifecycle` | `PromptStreamLifecycle` |
|---|---|---|
| Status broadcast | writes `topic.stream.statuses.<topicId>` on `pending → streaming → terminal` (with `awaitingApprovalAnchors` derived from `exec.awaitingApproval`) | none |
| `canAttach` | `true` | `false` |
| `cleanup` | sets a `setTimeout(evict, gracePeriodMs)`; chat reconnects within 30 s | calls `evict()` immediately |
`send()` defaults to `chatLifecycle`; `streamPrompt()` passes
`promptStreamLifecycle`.
## Multi-model
User mentions multiple models for one turn:
```
User: "Explain quantum mechanics" @gpt-4o @claude-sonnet
PersistentChatContextProvider.prepareDispatch
├─ persist user message (tree node)
├─ resolveModels → [gpt-4o, claude-sonnet]
├─ siblingsGroupId = (monotonic counter)
├─ create one pending assistant placeholder per model (SQLite)
├─ build listeners: subscriber + 2 PersistenceListener (one per backend)
├─ build models: 2 × { modelId, request, rootSpan }
└─ return PreparedDispatch
dispatchStreamRequest → manager.send({ models, listeners, siblingsGroupId })
├─ create ActiveStream (isMultiModel = true, 2 executions)
├─ launch one execution loop per model, each with its own
│ ring buffer
└─ return { mode: 'started', executionIds: [gpt-4o, claude-sonnet] }
```
## Steering
Steering a chat turn is **abort-and-restart**, not mid-turn injection. When a
new `Ai_Stream_Open` arrives for a topic that is still streaming,
`dispatchStreamRequest` calls `manager.abortAndAwait(topicId, 'steer-restart')`:
it aborts every execution, awaits their loops settling (each partial persists
as `paused` via the normal terminal path), and evicts the stream — then the
following `send()` starts a fresh turn. Awaiting settlement before re-dispatch
is what avoids an orphaned `pending` row and the same-`(topic, model)` race a
synchronous abort+restart would create.
Agent-session topics are the exception: they are **not** aborted. The follow-up
is enqueued on the session's `pendingTurns` and the running turn is interrupted
between tool calls; `send()` only upserts the new subscriber. See
[Agent Session Runtime → Live follow-up](./agent-session-runtime.md#live-follow-up).
## End-to-end flows
One row per flow. The two with dedicated docs are cross-linked rather than
duplicated; the rest are stream-manager-specific.
| Flow | Trigger | Mechanism | Terminal / result |
|---|---|---|---|
| Submit (standard) | `Ai_Stream_Open` | `dispatchStreamRequest``prepareDispatch` (persist user msg, reserve placeholders, build listeners + models) → `manager.send` → N × `runExecutionLoop` | `Ai_StreamDone`; `PersistenceListener.persistAssistant`; chat lifecycle `scheduleCleanup(30 s)` |
| Steering — chat resubmit | `Ai_Stream_Open` on a live chat topic | `dispatch``manager.abortAndAwait` (abort execs, await settle as `paused`, evict) → `manager.send` starts a fresh turn | prior partial persisted as **`paused`**; new turn streams — see [Steering](#steering) |
| Agent-session follow-up | `Ai_Stream_Open` on a live `agent-session:*` topic | provider persists the user row, `enqueueUserMessage``pendingTurns`, interrupt-when-safe; `manager.send` upserts the subscriber → `{ mode: 'injected' }` | next turn starts from `pendingTurns` — see [Agent Session Runtime](./agent-session-runtime.md#live-follow-up) |
| Tool-approval pause+resume | approval-request chunk → `awaiting-approval` | decision via `Ai_ToolApproval_Respond`; Claude-Agent unblocks `canUseTool`, MCP dispatches `continue-conversation` | card clears when the resumed stream broadcasts `pending` — see [Tool Approval](./tool-approval.md) |
| Reconnect | `Ai_Stream_Attach` on mount | `manager.attach`: `not-found` / streaming (register listener + compact replay) / done-paused (`finalMessage(s)`) / error | live chunks resume, or the final row is returned |
| Abort — user stop | `Ai_Stream_Abort` | per exec: `abortController.abort` → loop `signal` aborts → broadcast reader `cancel` → read loop `done` | partial persisted as **`paused`**; topic status → `aborted` (or `awaiting-approval` if an exec had it set) |
| Abort — no subscribers | last `WebContentsListener` dies + `backgroundMode === 'abort'` | `onChunk` prunes dead listeners; `listeners.size === 0` → auto `abort(topicId, 'no-subscribers')` | partial persisted as **`paused`** — never silently `success` or leaked |
| Multi-window | window B opens a live topic | B sends `Ai_Stream_Attach` → compact replay + its own `WebContentsListener`; each chunk fans out to A and B | both windows render the same chunks in sync |
| Channel / Agent | `AiStreamManager.send` in-process (no IPC) | scenario differs only by listener composition (table below) | per-listener effect |
**Topic status needs no `attach`.** Observers that only care "is this topic
live?" (sidebar loading indicators, topic-list status dots) don't register a
`WebContentsListener`. Every status transition writes the SharedCache key
`topic.stream.statuses.${topicId}`; observers read it via `useSharedCache`
directly. `Ai_Stream_Attach` is only needed when a window wants live chunks.
### Channel / Agent listener composition
Channel adapters and the agent scheduler call `AiStreamManager.send`
directly inside Main — no IPC. The scenario differences are entirely in the
listener composition:
| Scenario | Listeners | Effect |
|---|---|---|
| Renderer user message | `WebContentsListener` + `PersistenceListener` | live UI + persist |
| Channel bot reply | `ChannelAdapterListener` + agent-session persistence listener | IM send + agents DB |
| Channel + user both watching | above + `WebContentsListener(B)` | parallel fan-out |
| API server SSE | `SseListener` + `PersistenceListener` | SSE push + persist |
| Translate | `WebContentsListener` + `PersistenceListener(TranslationBackend)` | live overlay + writes `data-translation` part on success |
## IPC contract
### Request channels (Renderer → Main)
| Channel | Payload | Response | Semantics |
|---|---|---|---|
| `Ai_Stream_Open` | `AiStreamOpenRequest` (`submit-message` \| `regenerate-message`) | `{ mode, executionIds?, userMessageId?, placeholderIds? }` | Open / inject; provider routes by topicId |
| `Ai_Stream_Attach` | `{ topicId }` | `AiStreamAttachResponse` | Subscribe; returns compact replay when streaming |
| `Ai_Stream_Detach` | `{ topicId }` | void | Unsubscribe (stream continues) |
| `Ai_Stream_Abort` | `{ topicId }` | void | Stop current generation |
> Topic status snapshots need no dedicated IPC: a new window pulls every
> `topic.stream.statuses.${topicId}` entry via `Cache_GetAllShared` on
> mount, and `useSharedCache` subscribes by topicId.
### Push channels (Main → Renderer)
| Channel | Payload | Notes |
|---|---|---|
| `Ai_StreamChunk` | `{ topicId, executionId?, chunk }` | Multi-model carries `executionId`; **only sent to attached windows** |
| `Ai_StreamDone` | `{ topicId, executionId?, status, isTopicDone }` | `status ∈ { 'success', 'paused' }` — natural completion vs user abort; **only sent to attached windows** |
| `Ai_StreamError` | `{ topicId, executionId?, isTopicDone, error }` | `SerializedError`; **only sent to attached windows** |
Topic-level status transitions are NOT a bespoke IPC — they live in the
SharedCache key `topic.stream.statuses.${topicId}` (Main `setShared`
built-in `Cache_Sync` broadcast). The entry shape is
`TopicStatusSnapshotEntry`:
```typescript
{
status: 'pending' | 'streaming' | 'done' | 'aborted' | 'awaiting-approval' | 'error'
activeExecutions: ActiveExecution[] // execs currently `streaming`
awaitingApprovalAnchors: ActiveExecution[] // execs with awaitingApproval = true
lastCompletedAt?: number // bumped only on `done`; the fulfilled-badge read-receipt gate
}
```
`pending` doubles as the "new stream just created" signal — the old
`Ai_StreamStarted` IPC is gone. Grace-period cleanup does NOT clear the
entry — terminal values (`done` / `aborted` / `error`) stay so renderer
consumers (DB-refresh trigger, awaiting-approval indicators, sidebar
badges) can observe them. The badge "should I show this?" gate is a
read-receipt: `entry.lastCompletedAt` (authoritative, bumped only on
`done`) compared against `topic.stream.last_seen_completion.${topicId}`
(cross-window shared cache, written by the renderer when the user
acknowledges).
**All traffic is keyed by topicId**; multi-model uses `executionId` to
demux chunks per model.
**Topic status vs message status.** Don't conflate:
- **Topic stream status** (SharedCache `topic.stream.statuses.${topicId}`):
one entry per topic, source of truth is `ActiveStream.status`, valid
only while the `ActiveStream` exists (+ grace period).
- **Assistant message status** (`AssistantMessageStatus`: `PENDING` /
`PROCESSING` / `SUCCESS` / `ERROR`): one per assistant message,
persisted in SQLite, written by `PersistenceListener.onDone/onError`.
In multi-model, a single topic-level transition corresponds to N
separate message rows.
## ChatContextProvider — per-topicId namespace dispatch
`Ai_Stream_Open` is handled in Main by `dispatchStreamRequest`
(`context/dispatch.ts`):
```
dispatchStreamRequest(manager, subscriber, req)
→ provider = providers.find(p => p.canHandle(req.topicId))
→ prepared = await provider.prepareDispatch(subscriber, req, { hasLiveStream })
→ result = manager.send(prepared) // ← the only manager.send call
→ return { mode, executionIds?, userMessageId?, placeholderIds? }
```
Providers only "prepare" — they never call `manager.send` directly. Two
benefits:
- Provider unit tests assert on `PreparedDispatch` shape without mocking
the manager.
- The restart / start / multi-model fan-out routing lives in exactly one
place.
### Provider interface
```typescript
interface ChatContextProvider {
readonly name: string
canHandle(topicId: string): boolean
prepareDispatch(
subscriber: StreamListener,
req: MainDispatchRequest,
ctx: { hasLiveStream: boolean }
): Promise<PreparedDispatch>
}
interface PreparedDispatch {
topicId: string
models: ReadonlyArray<{ modelId: UniqueModelId; request: AiStreamRequest; rootSpan?: Span }>
listeners: StreamListener[] // subscriber + per-execution PersistenceListener(s)
userMessage?: Message
userMessageId?: string
siblingsGroupId?: number
isMultiModel: boolean
lifecycle?: StreamLifecycle
}
// dispatch.ts also accepts a Main-internal `continue-conversation`
// variant synthesised by the tool-approval IPC handler — not exposed
// over the renderer ↔ main contract.
type MainDispatchRequest = AiStreamOpenRequest | MainContinueConversationRequest
```
### Built-in providers
| Provider | `canHandle` | Data layer | User message | Assistant message |
|---|---|---|---|---|
| **AgentChatContextProvider** | `topicId.startsWith('agent-session:')` | `agentMessageRepository` | written upfront | runtime provides `PersistenceListener(AgentSessionMessageBackend)` |
| **TemporaryChatContextProvider** | `temporaryChatService.hasTopic(topicId)` | `TemporaryChatService` (in-memory) | appended upfront | `PersistenceListener(TemporaryChatBackend)` appends on done |
| **PersistentChatContextProvider** | `true` (catch-all) | `messageService` + SQLite | transactional create | `PersistenceListener(MessageServiceBackend)` updates pending on done |
Order: Agent → Temporary → Persistent (first `canHandle === true`
wins).
### Persistence path comparison
| | Persistent | Temporary | Agent |
|---|---|---|---|
| User message timing | before stream (tree node) | before stream (append) | before stream (agents DB) |
| Assistant placeholder | created pending before stream | none | created pending before stream (atomic with user msg) |
| Terminal write | `update` placeholder | `append` new row | `update` placeholder (`persistAssistant`) |
| Backend | `MessageServiceBackend` | `TemporaryChatBackend` | `AgentSessionMessageBackend` |
| Multi-model | ✓ | ✗ (single-model) | ✗ (single-model) |
| Regenerate | ✓ | ✗ | ✗ |
### One PersistenceListener across all topic kinds
Persistent / Temporary / Agent / Translation all share the same
`PersistenceListener` class — only the injected `PersistenceBackend`
differs. The observer protocol (`modelId` filter, error part folding,
skip-when-no-finalMessage, swallow errors) is implemented once.
## AiService integration
`AiService` is a lifecycle service:
- **Streaming.** `streamText(request)` returns
`Promise<ReadableStream<UIMessageChunk>>`, consumed by
`AiStreamManager.runExecutionLoop`.
- **Non-streaming IPC gateway.** `generateText` / `checkModel` /
`embedMany` / `generateImage` / `listModels`, registered as IPC
handlers in `onInit`.
`AiStreamManager` calls `await application.get('AiService').streamText(...)`.
Pre-stream errors (provider / model resolution, agent param build)
reject the returned Promise; mid-stream errors come through the returned
stream's error path — the two error paths never overlap.
## Grace period & reconnect
After a stream terminates, `ActiveStream` stays in memory for 30 s
(`config.gracePeriodMs`). During that window a returning user can
`attach` and pull `finalMessage` without a DB read. After expiry the
entry is evicted; subsequent `attach` returns `not-found` and the
renderer reads from the DB through `useQuery` (PersistenceListener has
already written by then).
If the user stops and immediately retries on the same topic, `send`
takes the start branch: `evictStream` first clears the grace-period
remnant (cancels the cleanup timer and drops the entry from
`activeStreams`), then the new stream is created — the old never blocks
the new.
## Edge case cheat sheet
| Case | Handling |
|---|---|
| User sends again on the same topic mid-stream (chat) | `dispatch` calls `abortAndAwait` (prior turn persists as `paused`), then `send` starts a fresh turn |
| Retry immediately after stream ends | `send` takes start; `evictStream` clears the grace-period entry first |
| Window closes mid-stream | Next broadcast sees `WebContentsListener.isAlive() === false` and removes it; `PersistenceListener` doesn't depend on a window |
| All windows closed + `backgroundMode='continue'` | Stream continues; `PersistenceListener` persists when done |
| All windows closed + `backgroundMode='abort'` | `onChunk` finds `stream.listeners.size === 0``abort(topicId, 'no-subscribers')`; partial persisted as `paused` |
| Multi-window on same topic | Each window has its own `WebContentsListener`; chunks fan out to all alive listeners |
| Same window re-attaches | Listener id is stable (`wc:${wc.id}:${topicId}`); `addListener` upserts by id |
| Attach mid-stream | `attach` returns compact replay per execution (each buffer compacted independently); observer fills in the gap |
| Ring buffer overflow | At `maxBufferChunks` the oldest chunk drops and `droppedChunks++`; subsequent attach logs the total dropped — replay is no longer lossless |
| Multi-model + resubmit | `abortAndAwait` settles all executions before restart; no orphaned `pending` row |
| Stream emits `tool-approval-request` | `exec.awaitingApproval = true`; on stream end the topic surfaces `awaiting-approval` via the shared cache |
| Main process restart | `activeStreams` clears; in-flight streams are lost; the renderer re-reads from the DB |
## Design notes
### Testing strategy
- **Manager tests.** `new AiStreamManager({ maxBufferChunks: 3 })` via
the optional config arg; state assertions go through `mgr.inspect(topicId)`;
listener upsert / abort / backgroundMode are tested via behaviour
(drive a chunk, assert which listeners received it).
- **Provider tests.** Assert on the returned `PreparedDispatch` shape
directly — no manager mock.
- **PersistenceListener tests.** `TemporaryChatBackend` as the test
vehicle covers the observer protocol once for every backend.
- All internal state has a public inspection API; production and tests
share the same contract.

View File

@@ -0,0 +1,75 @@
# Tool Approval
## Model
Main is the single writer of approval state. The renderer surfaces an
`approval-requested` ToolUIPart, takes the user's decision, and posts it
to Main. Main applies the decision to the DB-authoritative anchor parts,
persists, and resumes the stream.
## End-to-end flow
1. **Tool needs approval** — at `execute` time, the wrapper checks
`tool.needsApproval` and the assistant's auto-approve policy. If
approval is required, the wrapper writes an `approval-requested` part
and resolves the tool's promise into a held state (Claude-Agent: holds
`canUseTool`; MCP: stream pauses on the approval part).
2. **Stream pauses**`AiStreamManager` transitions the topic to
`awaiting-approval`. The `topic.stream.statuses.<topicId>` shared-cache
entry carries the status; every renderer window reading that key sees
the pause atomically.
3. **User decides** — the approval card renders from the part. On click,
`useToolApprovalBridge` (`src/renderer/hooks/useToolApprovalBridge.ts`)
calls `window.api.ai.toolApproval.respond(...)` with `approvalId`,
`approved`, optional `reason` / `updatedInput`, `topicId`, `anchorId`.
4. **Main applies**`AiService`'s `Ai_ToolApproval_Respond` handler
branches on transport **before** touching the DB:
- **Claude-Agent fast-path** (`AiService.ts:191-197`): hands the
decision to `AgentSessionRuntimeService.respondToolApproval`, which
resolves the live `canUseTool` promise so the existing stream
proceeds. When a live registry entry handles it, the handler
**early-returns — no DB read happens** (and `topicId` / `anchorId`
are not required).
- **MCP path** (reached only when no live entry matched; requires
`topicId` + `anchorId`): reads the anchor message's current `parts`
from DB, applies the decision, and **writes only when the target
`approval-requested` part is present on the DB row** — guarding the
overlay-only case (approval received before the part has persisted).
When all approvals on the turn are decided it dispatches a synthetic
`continue-conversation` request through `dispatchStreamRequest`; the
provider applies the decision when it reads parts.
5. **Awaiting-approval clears** — the moment the continue stream
broadcasts `pending`, the shared-cache entry flips back. Every window
sees the approval card disappear in the same tick.
## Persistent decisions
`useToolApproval` (`src/renderer/pages/home/Messages/Tools/hooks/useToolApproval.ts`)
exposes an `autoApprove` action **only for MCP tools** — when an `mcpTool`
descriptor is passed. It persists the opt-out by PATCHing the server's
`disabledAutoApproveTools`, so the MCP settings page reflects it and
subsequent calls of that tool skip the approval card. There is no generic
per-tool default for non-MCP (e.g. Claude-Agent) tools.
## Why this design
- **No renderer writes** — the renderer cannot PATCH approval state. If
it did, it would race Main's authoritative re-read and cause the
approval card to reappear on every click.
- **Cross-window consistency** — the shared-cache `awaiting-approval`
status is the single source of truth for "this topic is paused".
- **Overlay/persist gap** — the renderer sometimes sees the
`approval-requested` part via overlay before it lands in the DB row.
Writing unconditionally would clobber the (concurrent) Main-side
persistence; the conditional write + continue-dispatch covers that case.
## Where to read more
- Main IPC handler: `src/main/ai/AiService.ts` (`Ai_ToolApproval_Respond`)
- Renderer bridge: `src/renderer/hooks/useToolApprovalBridge.ts`
- Persistent decisions: `src/renderer/pages/home/Messages/Tools/hooks/useToolApproval.ts`
- Status broadcast: [Stream Manager](./stream-manager.md)

View File

@@ -0,0 +1,142 @@
# Tool Registry
## Model
```ts
interface ToolEntry {
name: string // wire-name, what the LLM emits in tool_calls
namespace: string // grouping for `tool_search` (web, kb, mcp:<id>, meta)
description: string // one-line summary for `tool_search`
defer: 'never' | 'always' | 'auto'
tool: Tool // AI SDK Tool (schema + execute + needsApproval + toModelOutput)
applies?(scope): boolean
}
```
`registry` (`src/main/ai/tools/adapters/aiSdk/registry.ts`) is a
process-wide singleton. Tool files register at module-import time; the
registry is read at request time by `buildAgentParams`. The Claude Code
runtime has a *separate* tool system — `tools/adapters/claudeCode/agentTools.ts`
builds its descriptors from MCP servers and built-in descriptors directly;
it does not consume this aiSdk `ToolRegistry`.
Tests construct their own `new ToolRegistry()` to avoid singleton pollution.
## Wire-name convention
Double underscore is the segment separator (so internal single `_` stays
unambiguous):
| Source | Name pattern | Example |
|---|---|---|
| Built-in | `<namespace>__<verb>` | `web__search`, `kb__search` |
| MCP | `mcp__<camelCase(server)>__<camelCase(tool)>` | `mcp__gmail__sendMessage` |
| Meta | `tool_<verb>` | `tool_search`, `tool_invoke`, `tool_inspect` (`tool_exec` is defined but not injected — see below) |
## Built-in tools
`src/main/ai/tools/adapters/aiSdk/builtin/` registers **four** entries:
- `web__search` (`WebSearchTool.ts``createWebSearchToolEntry`) — namespace
`web`. Talks to the configured web-search provider via the
renderer-shared search service.
- `web__fetch` (`WebSearchTool.ts``createWebFetchToolEntry`) — namespace
`web`. Fetches a URL's content.
- `kb__search` (`KnowledgeSearchTool.ts`) — semantic search over the active
knowledge base.
- `kb__list` (`KnowledgeListTool.ts`) — enumerate available knowledge bases /
documents.
Registration happens in `builtin/index.ts` (`registerBuiltinTools`). Each
tool's `applies` gates on the relevant `assistant.settings.*` flag (e.g.
`enableWebSearch`).
## MCP tools
`src/main/ai/tools/adapters/aiSdk/mcp/`:
- `resolveAssistantMcpToolIds` — assistant's enabled MCP servers + per-tool
disable list → set of tool ids.
- `mcpTools.syncMcpToolsToRegistry({ selectedToolIds })` — calls
`listTools` on each MCP server that owns at least one selected tool,
registers each as a `ToolEntry` whose `tool.execute` proxies through
the MCP transport. **Scope:** only servers owning a selected tool are
hit — avoids paying the per-server round-trip when only one MCP tool
is in use for this request.
The sync is idempotent; a stale entry is overwritten on the next sync.
## Meta-tools
`src/main/ai/tools/adapters/aiSdk/meta/` defines four tools that turn the
registry into a search-then-call interface for the model. Only the first
three are injected:
| Tool | Injected? | Use |
|---|---|---|
| `tool_search` | yes | Browse the deferred pool by namespace + query, returns brief descriptions |
| `tool_inspect` | yes | Emit a JSDoc stub for one tool — enough to call it correctly |
| `tool_invoke` | yes | Invoke any registry tool by name with a JSON arg blob |
| `tool_exec` | **no** | Sandboxed JS exec with the full registry as a global API (`meta/exec/runtime.ts`, `meta/exec/worker.ts`) — defined but intentionally not injected |
The injected three are added to the tool set by `applyDeferExposition` when
(and only when) the request actually defers tools. See below.
## Defer exposition
`src/main/ai/tools/adapters/aiSdk/exposition/`:
- `shouldDefer(entries, contextWindow)` — returns the set of names to
defer. Two gates above the simple threshold:
- **MIN_AUTO_DEFER_COUNT** — the auto pool must be large enough that
search-then-invoke beats inlining.
- **META_TOOLS_OVERHEAD_TOKENS** — estimated savings must exceed the
meta-tools' static prompt cost. Without these gates, small tool sets
+ small-context models trigger defer and pay net-negative tokens.
- `applyDeferExposition(tools, registry, contextWindow)` — strips the
deferred names out of `tools`, injects `tool_search` / `tool_inspect` /
`tool_invoke`, and returns the entries the system-prompt's
`<DEFERRED_TOOLS>` section needs to enumerate (so the model knows what
namespaces exist).
**Approval-gated tools are never deferred.** A force-prompt MCP tool is registered
with `defer: 'never'``mcp/mcpTools.ts` reads `isMcpToolForcePromptBySource` once
to drive both `defer` and `needsApproval` — so it stays inline and the SDK's native
approval gate fires on it. Deferring it would drop it from the SDK tool-set, so the
gate would never fire and it would be reachable only through `tool_invoke` with no
approval card. As a runtime backstop the `tool_invoke` / `tool_exec` meta-tools also
call `isApprovalGated` at execution time and refuse a gated tool (covering the
`registry.getByName(any-name)` vector), steering the model to call it inline. See
[Tool Approval](./tool-approval.md).
`tool_exec` is **not injected** by `applyDeferExposition` — there is no
`metaTools.exec` flag. The injection site (`applyDeferExposition.ts:50-53`)
deliberately leaves it out: its `worker_threads` + `new Function` sandbox
runs model-authored code with full Node privileges, a privilege-escalation
surface vs the renderer's prior restrictions. It is meant to be re-enabled
behind an explicit Preference key once there is a concrete need.
## `applies` and tool-call repair
- `applies(scope: ToolApplyScope)` — per-entry predicate consulted at
`registry.selectActive`. Throws are caught and treated as "inactive"
with a warning log.
- `createAiRepair(...)` (`tools/adapters/aiSdk/repair.ts`) — passed to AI SDK as
`experimental_repairToolCall`. When the model emits **malformed args**
(`InvalidToolInputError`), the repair function gets one chance to fix it via a
follow-up LLM call. Other failures (e.g. an unknown tool name) are
returned unrepaired.
## Where to read more
- Code: `src/main/ai/tools/adapters/aiSdk/` (Claude Code adapter:
`src/main/ai/tools/adapters/claudeCode/`)
- Tests: `tools/adapters/aiSdk/__tests__/`,
`tools/adapters/aiSdk/builtin/__tests__/`,
`tools/adapters/aiSdk/exposition/__tests__/`,
`tools/adapters/aiSdk/mcp/__tests__/`,
`tools/adapters/aiSdk/meta/__tests__/`
- Defer rationale, gate thresholds:
`tools/adapters/aiSdk/exposition/shouldDefer.ts` (header doc + tests)
- Approval flow: [Tool Approval](./tool-approval.md)

View File

@@ -130,7 +130,7 @@ User Message
└── Response Pipeline ──→ Message blocks (text, code, image, tool-call)
```
See [AI Core Architecture](./ai-core-architecture.md) for the complete data flow.
See [AI Reference](./ai/README.md) for the complete data flow.
## Monorepo Structure
@@ -142,7 +142,7 @@ cherry-studio
│ │ ├── data/ # Data layer (DB, Cache, Preference, DataApi)
│ │ ├── services/ # 27 lifecycle-managed services
│ │ ├── knowledge/ # RAG / knowledge base
│ │ ├── mcpServers/ # Built-in MCP servers
│ │ ├── ai/mcp/servers/ # Built-in MCP servers
│ │ ├── apiServer/ # Local REST API (Express)
│ │ └── integration/ # External integrations
│ │
@@ -179,8 +179,8 @@ cherry-studio
|-----------|----------|---------------|
| Service Lifecycle | `src/main/core/lifecycle/` | [Lifecycle Reference](./lifecycle/README.md) |
| Data Layer | `src/main/data/` | [Data Reference](./data/README.md) |
| AI Core | `src/renderer/aiCore/` | [AI Core Architecture](./ai-core-architecture.md) |
| MCP (Tool Use) | `src/main/services/mcp/` | — |
| AI Core | `src/main/ai/` | [AI Reference](./ai/README.md) |
| MCP (Tool Use) | `src/main/ai/mcp/` | — |
| Knowledge (RAG) | `src/main/knowledge/` | [KnowledgeService](./knowledge/knowledge-service.md) |
| Message System | `src/renderer/store/` | [Message System](./messaging/message-system.md) |
| CherryClaw (Agent) | `src/main/services/agents/` | [CherryClaw Overview](./cherryclaw/overview.md) |

View File

@@ -1,6 +1,6 @@
# Claw MCP Server
The Claw MCP server is a built-in MCP (Model Context Protocol) server automatically injected into every CherryClaw session. It provides four self-management tools for the agent: `cron` (task scheduling), `notify` (notifications), `skills` (skill management), and `memory` (memory management).
The Claw MCP server is a built-in MCP (Model Context Protocol) server automatically injected into every CherryClaw (Soul Mode) session. It provides three self-management tools for the agent: `cron` (task scheduling), `notify` (notifications), and `config` (agent/channel self-configuration). Skill and memory management used to live here too but were extracted into their own standalone MCP servers (see [Related servers](#related-servers-formerly-claw-tools)).
## Architecture
@@ -10,7 +10,7 @@ CherryClawService.invoke()
→ Inject as in-memory MCP server:
_internalMcpServers = { claw: { type: 'inmem', instance: clawServer.mcpServer } }
→ ClaudeCodeService merges into SDK options.mcpServers
→ SDK auto-discovers tools: mcp__claw__cron, mcp__claw__notify, mcp__claw__skills, mcp__claw__memory
→ SDK auto-discovers tools: mcp__claw__cron, mcp__claw__notify, mcp__claw__config
```
ClawServer uses the `@modelcontextprotocol/sdk` `McpServer` class, running in memory mode (no HTTP transport). A new instance is created per CherryClaw session invocation, bound to the current agent's ID.
@@ -83,90 +83,44 @@ Returns an informational message (not an error) if no notification channels are
---
## skills Tool
## config Tool
Manage Claude skills in the agent workspace. Supports searching from the marketplace, installing, uninstalling, and listing installed skills.
Inspect and manage the agent's own configuration — identity, model, and IM channel connections — and drive the onboarding ("bootstrap") ritual.
### Parameters
| Parameter | Type | Required | Description |
|---|---|---|---|
| `action` | string | Yes | One of the actions below |
| `type` | string | For `add_channel` | Channel adapter type: `telegram` / `feishu` / `qq` / `wechat` / `discord` / `slack` |
| `name` | string | For `rename` / `add_channel` | New display name (`rename`) or human-readable channel name (`add_channel`) |
| `channel_id` | string | For `update_channel` / `remove_channel` / `reconnect_channel` | Target channel id |
| `config` | object | For `add_channel` | Adapter-specific configuration (optional for `update_channel`) |
| `enabled` | boolean | No | Enable/disable the channel (defaults to true) |
### Actions
#### `search` — Search Skills
| Parameter | Type | Required | Description |
|---|---|---|---|
| `query` | string | Yes | Search keywords |
Queries the public marketplace API (`claude-plugins.dev/api/skills`), returns matching skills with `name`, `description`, `author`, `identifier` (for installation), and `installs` count. Hyphens and underscores in search terms are replaced with spaces to improve matching.
#### `install` — Install Skill
| Parameter | Type | Required | Description |
|---|---|---|---|
| `identifier` | string | Yes | Marketplace skill identifier, format `owner/repo/skill-name` |
Constructs `marketplace:skill:{identifier}` path internally, delegates to `PluginService.install()`.
#### `remove` — Uninstall Skill
| Parameter | Type | Required | Description |
|---|---|---|---|
| `name` | string | Yes | Skill folder name (from list results) |
Delegates to `PluginService.uninstall()`.
#### `list` — List Installed Skills
No parameters. Returns all installed skills for the current agent, including `name`, `folder`, and `description`.
| Action | Description |
|---|---|
| `status` | Current channels, model, and supported adapter types |
| `rename` | Change the agent's display name |
| `add_channel` / `update_channel` / `remove_channel` | Manage IM channel connections |
| `reconnect_channel` | Re-scan a QR code for a WeChat/Feishu channel (e.g. expired session or failed initial setup) |
| `complete_bootstrap` | Mark the onboarding ritual as done |
| `reset_bootstrap` | Re-run onboarding in the next session |
---
## memory Tool
## Related servers (formerly claw tools)
Manage persistent cross-session memory. This is the write interface for CherryClaw's memory system (reading is done via inline content in the system prompt).
Skill and memory management were extracted out of claw into their own standalone MCP servers:
### Design Principle
The tool description encodes the memory decision logic:
> Before writing to FACT.md, ask yourself: will this information still matter in 6 months? If not, use append instead.
### Actions
#### `update` — Update FACT.md
| Parameter | Type | Required | Description |
| Capability | Server | Tool | File |
|---|---|---|---|
| `content` | string | Yes | Complete markdown content of FACT.md |
| Skills (search / install / uninstall / list) | `skills` | `mcp__skills__skills` | `src/main/ai/mcp/servers/skills.ts` |
| Persistent memory (update / append / search) | `agent-memory` | `mcp__agent-memory__memory` | `src/main/ai/mcp/servers/workspaceMemory.ts` |
Atomic write: writes to a temp file first, then replaces via `rename`. Ensures no file corruption from mid-write crashes.
File path supports case-insensitive matching. The `memory/` directory is auto-created if it doesn't exist.
**Note**: This is a full overwrite, not an incremental edit. The agent needs to read existing content first, modify it, then write back the complete content.
#### `append` — Append Log Entry
| Parameter | Type | Required | Description |
|---|---|---|---|
| `text` | string | Yes | Log entry text |
| `tags` | string[] | No | Tag list |
Appends a JSON line to `memory/JOURNAL.jsonl`:
```json
{"ts":"2026-03-10T12:00:00.000Z","tags":["deploy","production"],"text":"Deployed v2.1 to production"}
```
Timestamp is auto-generated. Suitable for one-off events, completed tasks, session summaries, and other short-term information.
#### `search` — Search Logs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `query` | string | No | Case-insensitive substring match |
| `tag` | string | No | Filter by tag |
| `limit` | integer | No | Max results (default 20) |
Returns matching log entries in reverse chronological order. `query` and `tag` can be combined.
> The CherryClaw system prompt and the workspace bootstrap reference memory as `mcp__agent-memory__memory` — **not** `mcp__claw__memory`.
---
@@ -178,6 +132,6 @@ All tool calls execute within an internal try-catch. On error, returns an `{ isE
| File | Description |
|---|---|
| `src/main/mcpServers/claw.ts` | ClawServer complete implementation (4 tools + helpers) |
| `src/main/mcpServers/__tests__/claw.test.ts` | 37 unit tests |
| `src/main/services/agents/services/cherryclaw/index.ts` | MCP server injection logic |
| `src/main/ai/mcp/servers/claw.ts` | ClawServer implementation (`cron` / `notify` / `config` + helpers) |
| `src/main/ai/mcp/servers/__tests__/claw.test.ts` | Unit tests |
| `src/main/ai/runtime/claudeCode/settingsBuilder.ts` | `buildMcpServers` — injects the claw server in Soul Mode |

View File

@@ -121,6 +121,6 @@ Both tables are associated with the agents table via foreign key cascades.
| `src/main/services/agents/services/AgentServiceRegistry.ts` | Agent service registry |
| `src/main/services/agents/services/TaskService.ts` | Task CRUD + scheduling calculation |
| `src/main/services/agents/services/SchedulerService.ts` | Polling scheduler |
| `src/main/mcpServers/claw.ts` | Claw MCP server |
| `src/main/ai/mcp/servers/claw.ts` | Claw MCP server |
| `src/main/services/agents/services/channels/` | Channel abstraction layer |
| `src/main/services/agents/database/schema/tasks.schema.ts` | Task table schema |

View File

@@ -18,7 +18,7 @@ Value type is inferred from the schema. Hooks pin the cache entry (refcounted)
import { useCache, useSharedCache, usePersistCache } from '@data/hooks/useCache'
// Memory — single renderer
const [generating, setGenerating] = useCache('chat.generating', false)
const [generating, setGenerating] = useCache('chat.web_search.searching', false)
// Shared — all windows
const [activeSearches, setActive] = useSharedCache('chat.web_search.active_searches')
@@ -42,12 +42,12 @@ import { cacheService } from '@data/CacheService'
```typescript
// Schema keys (Fixed or Template) — type-inferred
cacheService.set('chat.generating', true)
cacheService.set('chat.generating', true, 30_000) // with TTL (ms)
cacheService.get('chat.generating') // boolean
cacheService.has('chat.generating')
cacheService.hasTTL('chat.generating')
cacheService.delete('chat.generating')
cacheService.set('chat.web_search.searching', true)
cacheService.set('chat.web_search.searching', true, 30_000) // with TTL (ms)
cacheService.get('chat.web_search.searching') // boolean
cacheService.has('chat.web_search.searching')
cacheService.hasTTL('chat.web_search.searching')
cacheService.delete('chat.web_search.searching')
// Casual (Memory tier only, no schema match allowed)
cacheService.setCasual<TopicCache>(`topic:${id}`, data, 30_000)

View File

@@ -70,10 +70,10 @@ files:
- "!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters,Makefile,*.Makefile}" # filter .node build files
- "resources/**/*" # include all files in resources, and unpack them from asar for runtime access
asarUnpack:
- out/proxy/**
- resources/**
- "**/*.{metal,exp,lib}"
- "node_modules/@img/sharp-*/**"
- "node_modules/@anthropic-ai/claude-agent-sdk-*/**" # native CLI binary must live on disk to be spawned
extraResources:
- from: "migrations/sqlite-drizzle"
to: "migrations/sqlite-drizzle"

View File

@@ -8,7 +8,6 @@ import { visualizer } from 'rollup-plugin-visualizer'
// assert not supported by biome
// import pkg from './package.json' assert { type: 'json' }
import pkg from './package.json'
import { buildProxyBootstrapPlugin } from './scripts/buildProxyBootstrapPlugin'
const visualizerPlugin = (type: 'renderer' | 'main') => {
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
@@ -23,14 +22,7 @@ const mainExternalDependencies = Object.keys(pkg.dependencies).filter(
export default defineConfig({
main: {
plugins: [
...visualizerPlugin('main'),
buildProxyBootstrapPlugin({
dependencies: Object.keys(pkg.dependencies),
isProd,
rootDir: __dirname
})
],
plugins: [...visualizerPlugin('main')],
resolve: {
alias: {
'@main': resolve('src/main'),
@@ -41,6 +33,10 @@ export default defineConfig({
'@logger': resolve('src/main/core/logger/LoggerService'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
'@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src'),
'@vectorstores/libsql': resolve('packages/vectorstores/libsql/src/index.ts'),
'@cherrystudio/provider-registry/node': resolve('packages/provider-registry/src/registry-loader'),
'@cherrystudio/provider-registry': resolve('packages/provider-registry/src'),

View File

@@ -0,0 +1,95 @@
CREATE TABLE `agent_workspace` (
`id` text PRIMARY KEY NOT NULL,
`name` text NOT NULL,
`path` text NOT NULL,
`order_key` text NOT NULL,
`created_at` integer NOT NULL,
`updated_at` integer NOT NULL
);
--> statement-breakpoint
CREATE UNIQUE INDEX `agent_workspace_path_unique_idx` ON `agent_workspace` (`path`);--> statement-breakpoint
CREATE INDEX `agent_workspace_order_key_idx` ON `agent_workspace` (`order_key`);--> statement-breakpoint
DROP TABLE `agent_task_run_log`;--> statement-breakpoint
DROP TABLE `agent_task`;--> statement-breakpoint
PRAGMA foreign_keys=OFF;--> statement-breakpoint
CREATE TABLE `__new_agent_channel_task` (
`channel_id` text NOT NULL,
`task_id` text NOT NULL,
PRIMARY KEY(`channel_id`, `task_id`),
FOREIGN KEY (`channel_id`) REFERENCES `agent_channel`(`id`) ON UPDATE no action ON DELETE cascade,
FOREIGN KEY (`task_id`) REFERENCES `job_schedule`(`id`) ON UPDATE no action ON DELETE cascade
);
--> statement-breakpoint
INSERT INTO `__new_agent_channel_task`("channel_id", "task_id") SELECT "channel_id", "task_id" FROM `agent_channel_task`;--> statement-breakpoint
DROP TABLE `agent_channel_task`;--> statement-breakpoint
ALTER TABLE `__new_agent_channel_task` RENAME TO `agent_channel_task`;--> statement-breakpoint
PRAGMA foreign_keys=ON;--> statement-breakpoint
CREATE INDEX `agent_channel_task_channel_id_idx` ON `agent_channel_task` (`channel_id`);--> statement-breakpoint
CREATE INDEX `agent_channel_task_task_id_idx` ON `agent_channel_task` (`task_id`);--> statement-breakpoint
CREATE TABLE `__new_agent_session` (
`id` text PRIMARY KEY NOT NULL,
`agent_id` text,
`name` text NOT NULL,
`description` text DEFAULT '' NOT NULL,
`workspace_id` text,
`order_key` text NOT NULL,
`created_at` integer NOT NULL,
`updated_at` integer NOT NULL,
FOREIGN KEY (`agent_id`) REFERENCES `agent`(`id`) ON UPDATE no action ON DELETE set null,
FOREIGN KEY (`workspace_id`) REFERENCES `agent_workspace`(`id`) ON UPDATE no action ON DELETE set null
);
--> statement-breakpoint
DROP TABLE `agent_session`;--> statement-breakpoint
ALTER TABLE `__new_agent_session` RENAME TO `agent_session`;--> statement-breakpoint
CREATE INDEX `agent_session_order_key_idx` ON `agent_session` (`order_key`);--> statement-breakpoint
CREATE TABLE `__new_agent` (
`id` text PRIMARY KEY NOT NULL,
`type` text NOT NULL,
`name` text NOT NULL,
`description` text DEFAULT '' NOT NULL,
`instructions` text NOT NULL,
`model` text,
`plan_model` text,
`small_model` text,
`mcps` text DEFAULT '[]' NOT NULL,
`allowed_tools` text DEFAULT '[]' NOT NULL,
`configuration` text DEFAULT '{}' NOT NULL,
`order_key` text NOT NULL,
`created_at` integer NOT NULL,
`updated_at` integer NOT NULL,
`deleted_at` integer,
FOREIGN KEY (`model`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null,
FOREIGN KEY (`plan_model`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null,
FOREIGN KEY (`small_model`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null
);
--> statement-breakpoint
DROP TABLE `agent`;--> statement-breakpoint
ALTER TABLE `__new_agent` RENAME TO `agent`;--> statement-breakpoint
CREATE INDEX `agent_name_idx` ON `agent` (`name`);--> statement-breakpoint
CREATE INDEX `agent_type_idx` ON `agent` (`type`);--> statement-breakpoint
CREATE INDEX `agent_order_key_idx` ON `agent` (`order_key`);--> statement-breakpoint
CREATE TABLE `__new_agent_session_message` (
`id` text PRIMARY KEY NOT NULL,
`session_id` text NOT NULL,
`role` text NOT NULL,
`data` text NOT NULL,
`searchable_text` text DEFAULT '' NOT NULL,
`status` text NOT NULL,
`model_id` text,
`model_snapshot` text,
`trace_id` text,
`stats` text,
`runtime_resume_token` text,
`created_at` integer NOT NULL,
`updated_at` integer NOT NULL,
FOREIGN KEY (`session_id`) REFERENCES `agent_session`(`id`) ON UPDATE no action ON DELETE cascade,
FOREIGN KEY (`model_id`) REFERENCES `user_model`(`id`) ON UPDATE no action ON DELETE set null,
CONSTRAINT "agent_session_message_role_check" CHECK("__new_agent_session_message"."role" IN ('user', 'assistant', 'system')),
CONSTRAINT "agent_session_message_status_check" CHECK("__new_agent_session_message"."status" IN ('pending', 'success', 'error', 'paused'))
);
--> statement-breakpoint
DROP TABLE `agent_session_message`;--> statement-breakpoint
ALTER TABLE `__new_agent_session_message` RENAME TO `agent_session_message`;--> statement-breakpoint
CREATE INDEX `agent_session_message_session_created_id_idx` ON `agent_session_message` (`session_id`,`created_at`,`id`);--> statement-breakpoint
ALTER TABLE `assistant` ADD `order_key` text DEFAULT '' NOT NULL;--> statement-breakpoint
CREATE INDEX `assistant_order_key_idx` ON `assistant` (`order_key`);

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,13 @@
"when": 1780386366672,
"tag": "0002_clever_skin",
"breakpoints": true
},
{
"idx": 3,
"version": "6",
"when": 1780586439674,
"tag": "0003_burly_tomorrow_man",
"breakpoints": true
}
]
}

View File

@@ -90,7 +90,8 @@
"ci": "pnpm ci:basic-check && pnpm ci:test-check"
},
"dependencies": {
"@anthropic-ai/claude-agent-sdk": "0.2.112",
"@anthropic-ai/claude-agent-sdk": "0.3.145",
"@cherrystudio/ripgrep": "^1.0.0",
"@expo/sudo-prompt": "^9.3.2",
"@larksuiteoapi/node-sdk": "^1.59.0",
"@libsql/client": "^0.15.15",
@@ -100,7 +101,6 @@
"@vectorstores/core": "^0.1.8",
"@vectorstores/libsql": "workspace:*",
"@vectorstores/readers": "^0.1.8",
"cron-parser": "^5.0.8",
"express": "5.1.0",
"font-list": "2.0.0",
"graceful-fs": "4.2.11",
@@ -128,6 +128,7 @@
"@ai-sdk/azure": "^3.0.54",
"@ai-sdk/cerebras": "^2.0.45",
"@ai-sdk/cohere": "^3.0.30",
"@ai-sdk/devtools": "^0.0.17",
"@ai-sdk/gateway": "^3.0.104",
"@ai-sdk/google": "3.0.64",
"@ai-sdk/google-vertex": "^4.0.112",
@@ -139,13 +140,14 @@
"@ai-sdk/perplexity": "^3.0.29",
"@ai-sdk/provider": "^3.0.8",
"@ai-sdk/provider-utils": "^4.0.23",
"@ai-sdk/react": "^3.0.147",
"@ai-sdk/test-server": "^1.0.3",
"@ai-sdk/togetherai": "^2.0.45",
"@ai-sdk/xai": "^3.0.83",
"@ant-design/cssinjs": "1.23.0",
"@ant-design/icons": "5.6.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/sdk": "^0.81.0",
"@aws-sdk/client-s3": "^3.998.0",
"@biomejs/biome": "2.2.4",
"@changesets/changelog-github": "^0.5.2",
@@ -221,7 +223,7 @@
"@tailwindcss/vite": "^4.1.13",
"@tanstack/react-query": "^5.85.5",
"@tanstack/react-router": "^1.139.3",
"@tanstack/react-virtual": "^3.13.12",
"@tanstack/react-virtual": "^3.13.24",
"@tanstack/router-plugin": "^1.139.3",
"@testing-library/dom": "^10.4.0",
"@testing-library/jest-dom": "^6.6.3",
@@ -409,7 +411,6 @@
"react-hook-form": "^7.55.0",
"react-hotkeys-hook": "^4.6.1",
"react-i18next": "^14.1.2",
"react-infinite-scroll-component": "^6.1.0",
"react-json-view": "^1.21.3",
"react-markdown": "^10.1.0",
"react-player": "^3.3.1",
@@ -544,6 +545,14 @@
},
"packageManager": "pnpm@10.27.0",
"optionalDependencies": {
"@anthropic-ai/claude-agent-sdk-darwin-arm64": "0.3.145",
"@anthropic-ai/claude-agent-sdk-darwin-x64": "0.3.145",
"@anthropic-ai/claude-agent-sdk-linux-arm64": "0.3.145",
"@anthropic-ai/claude-agent-sdk-linux-arm64-musl": "0.3.145",
"@anthropic-ai/claude-agent-sdk-linux-x64": "0.3.145",
"@anthropic-ai/claude-agent-sdk-linux-x64-musl": "0.3.145",
"@anthropic-ai/claude-agent-sdk-win32-arm64": "0.3.145",
"@anthropic-ai/claude-agent-sdk-win32-x64": "0.3.145",
"@img/sharp-darwin-arm64": "0.34.5",
"@img/sharp-darwin-x64": "0.34.5",
"@img/sharp-libvips-darwin-arm64": "1.2.4",

View File

@@ -227,32 +227,6 @@ const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
])
```
#### promptToolUsePlugin - 提示工具使用插件
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
```typescript
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
// 对于不支持 function call 的模型
const executor = AiCore.create(
'providerId',
{
apiKey: 'your-key',
baseURL: 'https://your-model-endpoint'
},
[
createPromptToolUsePlugin({
enabled: true,
// 可选:自定义系统提示符构建
buildSystemPrompt: (userPrompt, tools) => {
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
}
})
]
)
```
### 自定义插件
创建自定义插件非常简单:

View File

@@ -45,7 +45,7 @@
"@ai-sdk/xai": "^3.0.83",
"@cherrystudio/ai-sdk-provider": "workspace:*",
"@openrouter/ai-sdk-provider": "^2.3.3",
"lru-cache": "^11.2.4",
"quick-lru": "^5.1.1",
"zod": "^4.1.5"
},
"devDependencies": {

View File

@@ -6,6 +6,5 @@
export type { ModelConfig as ModelConfigType } from './models/types'
// 执行管理
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
export type { RuntimeConfig } from './runtime/types'

View File

@@ -1,4 +1,2 @@
export * from './providerToolPlugin'
export * from './toolUsePlugin/promptToolUsePlugin'
export * from './toolUsePlugin/type'
export * from './webSearchPlugin'

View File

@@ -1,251 +0,0 @@
/**
* 流事件管理器
*
* 负责处理 AI SDK 流事件的发送和管理
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
*/
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
import type { EmbeddingModelUsage, ImageModelUsage, LanguageModelUsage, ModelMessage } from 'ai'
import type { AiSdkUsage } from '../../../providers/types'
import type { AiRequestContext, StreamTextParams, StreamTextResult } from '../../types'
import type { StreamController } from './ToolExecutor'
/**
* 类型守卫:检查对象是否是有效的流结果(包含 ReadableStream 类型的 fullStream
*/
function hasFullStream(obj: unknown): obj is StreamTextResult & { fullStream: ReadableStream } {
return typeof obj === 'object' && obj !== null && 'fullStream' in obj && obj.fullStream instanceof ReadableStream
}
/**
* 类型守卫:检查 usage 是否是 LanguageModelUsage
* LanguageModelUsage 包含 totalTokens, inputTokens, outputTokens 等字段
*/
function isLanguageModelUsage(usage: unknown): usage is LanguageModelUsage {
return (
typeof usage === 'object' &&
usage !== null &&
('totalTokens' in usage || 'inputTokens' in usage || 'outputTokens' in usage)
)
}
/**
* 类型守卫:检查 usage 是否是 ImageModelUsage
* ImageModelUsage 包含 inputTokens, outputTokens, totalTokens 字段
* but lacks inputTokenDetails/outputTokenDetails which are present in LanguageModelUsage
*/
function isImageModelUsage(usage: unknown): usage is ImageModelUsage {
return (
typeof usage === 'object' &&
usage !== null &&
'inputTokens' in usage &&
'outputTokens' in usage &&
!('inputTokenDetails' in usage) &&
!('outputTokenDetails' in usage)
)
}
/**
* 类型守卫:检查 usage 是否是 EmbeddingModelUsage
* EmbeddingModelUsage 只包含 tokens 字段
*/
function isEmbeddingModelUsage(usage: unknown): usage is EmbeddingModelUsage {
return (
typeof usage === 'object' &&
usage !== null &&
'tokens' in usage &&
// 确保只有 tokens 字段(没有 inputTokens, outputTokens 等)
!('inputTokens' in usage) &&
!('outputTokens' in usage)
)
}
/**
* 流事件管理器类
*/
export class StreamEventManager {
/**
* 发送工具调用步骤开始事件
*/
sendStepStartEvent(controller: StreamController): void {
controller.enqueue({
type: 'start-step',
request: {},
warnings: []
})
}
/**
* 发送步骤完成事件
*/
sendStepFinishEvent(
controller: StreamController,
chunk: {
usage?: Partial<AiSdkUsage>
response?: { id: string; [key: string]: unknown }
providerMetadata?: SharedV3ProviderMetadata
},
context: AiRequestContext,
finishReason: string = 'stop'
): void {
// 累加当前步骤的 usage
if (chunk.usage && context.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
}
controller.enqueue({
type: 'finish-step',
finishReason,
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
})
}
/**
* 处理递归调用并将结果流接入当前流
*/
async handleRecursiveCall<TParams extends StreamTextParams>(
controller: StreamController,
recursiveParams: Partial<TParams>,
context: AiRequestContext<TParams, StreamTextResult>
): Promise<void> {
// try {
// 重置工具执行状态,准备处理新的步骤
context.hasExecutedToolsInCurrentStep = false
const recursiveResult = await context.recursiveCall(recursiveParams)
if (hasFullStream(recursiveResult)) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
// } catch (error) {
// this.handleRecursiveCallError(controller, error, stepId)
// }
}
/**
* 将递归流的数据传递到当前流
*/
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
const reader = recursiveStream.getReader()
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
if (value.type === 'start') {
continue
}
if (value.type === 'finish') {
break
}
controller.enqueue(value)
}
} finally {
reader.releaseLock()
}
}
/**
* 构建递归调用的参数
*/
buildRecursiveParams<TParams extends StreamTextParams>(
context: AiRequestContext<TParams, StreamTextResult>,
textBuffer: string,
toolResultsText: string,
tools: Record<string, unknown>
): Partial<TParams> {
const params = context.originalParams
// 构建新的对话消息
const newMessages: ModelMessage[] = [
...(params.messages || []),
// 只有当 textBuffer 有内容时才添加 assistant 消息,避免空消息导致 API 错误
...(textBuffer ? [{ role: 'assistant' as const, content: textBuffer }] : []),
{
role: 'user',
content: toolResultsText
}
]
// 递归调用,继续对话,重新传递 tools
const recursiveParams = {
...params,
messages: newMessages,
tools: tools
} as Partial<TParams>
return recursiveParams
}
/**
* 累加 usage 数据
*
* 使用类型守卫来处理不同类型的 usageLanguageModelUsage, ImageModelUsage, EmbeddingModelUsage
* - LanguageModelUsage: inputTokens, outputTokens, totalTokens
* - ImageModelUsage: inputTokens, outputTokens, totalTokens
* - EmbeddingModelUsage: tokens
*/
accumulateUsage(target: Partial<AiSdkUsage>, source: Partial<AiSdkUsage>): void {
if (!target || !source) return
if (isLanguageModelUsage(target) && isLanguageModelUsage(source)) {
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
// Accumulate inputTokenDetails (cacheReadTokens, cacheWriteTokens, noCacheTokens)
if (source.inputTokenDetails) {
if (!target.inputTokenDetails) {
target.inputTokenDetails = {
noCacheTokens: undefined,
cacheReadTokens: undefined,
cacheWriteTokens: undefined
}
}
target.inputTokenDetails.cacheReadTokens =
(target.inputTokenDetails.cacheReadTokens || 0) + (source.inputTokenDetails.cacheReadTokens || 0)
target.inputTokenDetails.cacheWriteTokens =
(target.inputTokenDetails.cacheWriteTokens || 0) + (source.inputTokenDetails.cacheWriteTokens || 0)
target.inputTokenDetails.noCacheTokens =
(target.inputTokenDetails.noCacheTokens || 0) + (source.inputTokenDetails.noCacheTokens || 0)
}
// Accumulate outputTokenDetails (reasoningTokens, textTokens)
if (source.outputTokenDetails) {
if (!target.outputTokenDetails) {
target.outputTokenDetails = { textTokens: undefined, reasoningTokens: undefined }
}
target.outputTokenDetails.reasoningTokens =
(target.outputTokenDetails.reasoningTokens || 0) + (source.outputTokenDetails.reasoningTokens || 0)
target.outputTokenDetails.textTokens =
(target.outputTokenDetails.textTokens || 0) + (source.outputTokenDetails.textTokens || 0)
}
return
}
if (isImageModelUsage(target) && isImageModelUsage(source)) {
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
return
}
if (isEmbeddingModelUsage(target) && isEmbeddingModelUsage(source)) {
target.tokens = (target.tokens || 0) + (source.tokens || 0)
return
}
// ⚠️ 未知类型或类型不匹配,不进行累加
console.warn('[StreamEventManager] Unable to accumulate usage - type mismatch or unknown type', {
target,
source
})
}
}

View File

@@ -1,152 +0,0 @@
/**
* 工具执行器
*
* 负责工具的执行、结果格式化和相关事件发送
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
*/
import type { ToolSet, TypedToolError } from 'ai'
import type { ToolUseResult } from './type'
/**
* 工具执行结果
*/
export interface ExecutedResult {
toolCallId: string
toolName: string
result: unknown
isError?: boolean
}
/**
* 流控制器类型(从 AI SDK 提取)
* Generic type parameter allows for type-safe chunk enqueuing
*/
export interface StreamController<TChunk = unknown> {
enqueue(chunk: TChunk): void
}
/**
* 工具执行器类
*/
export class ToolExecutor {
/**
* 执行多个工具调用
*/
async executeTools(
toolUses: ToolUseResult[],
tools: ToolSet,
controller: StreamController
): Promise<ExecutedResult[]> {
const executedResults: ExecutedResult[] = []
for (const toolUse of toolUses) {
try {
const tool = tools[toolUse.toolName]
if (!tool || typeof tool.execute !== 'function') {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
}
// 发送 tool-call 事件
controller.enqueue({
type: 'tool-call',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments
})
const result = await tool.execute(toolUse.arguments, {
toolCallId: toolUse.id,
messages: [],
abortSignal: new AbortController().signal
})
// 发送 tool-result 事件
controller.enqueue({
type: 'tool-result',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
output: result
})
executedResults.push({
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result,
isError: false
})
} catch (error) {
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
// 处理错误情况
const errorResult = this.handleToolError(toolUse, error, controller)
executedResults.push(errorResult)
}
}
return executedResults
}
/**
* 格式化工具结果为 Cherry Studio 标准格式
*/
formatToolResults(executedResults: ExecutedResult[]): string {
return executedResults
.map((tr) => {
if (!tr.isError) {
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
} else {
const error = tr.result || 'Unknown error'
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
}
})
.join('\n\n')
}
/**
* 发送工具调用开始相关事件
*/
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// // 发送 tool-input-start 事件
// controller.enqueue({
// type: 'tool-input-start',
// id: toolUse.id,
// toolName: toolUse.toolName
// })
// }
/**
* 处理工具执行错误
*/
private handleToolError<T extends ToolSet>(
toolUse: ToolUseResult,
error: unknown,
controller: StreamController
): ExecutedResult {
// 使用 AI SDK 标准错误格式
const toolError: TypedToolError<T> = {
type: 'tool-error',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
error
}
controller.enqueue(toolError)
// 发送标准错误事件
// controller.enqueue({
// type: 'tool-error',
// toolCallId: toolUse.id,
// error: error instanceof Error ? error.message : String(error),
// input: toolUse.arguments
// })
return {
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result: error,
isError: true
}
}
}

View File

@@ -1,566 +0,0 @@
import type { SharedV3ProviderMetadata } from '@ai-sdk/provider'
import { createMockContext, createMockTool } from '@test-utils'
import type {
EmbeddingModelUsage,
ImageModelUsage,
LanguageModelUsage as AiSdkUsage,
LanguageModelUsage,
TextStreamPart,
ToolSet
} from 'ai'
import { simulateReadableStream } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { StreamEventManager } from '../StreamEventManager'
import type { StreamController } from '../ToolExecutor'
/**
* Type alias for empty toolset (no tools)
* Using Record<string, never> ensures type safety for tests without tools
*/
type EmptyToolSet = Record<string, never>
/**
* Mock StreamController for testing
* Provides type-safe enqueue function that accepts TextStreamPart chunks
*/
interface MockStreamController<TOOLS extends ToolSet = EmptyToolSet> extends StreamController {
enqueue: ReturnType<typeof vi.fn<(chunk: TextStreamPart<TOOLS>) => void>>
}
/**
* Create a type-safe mock stream controller
*/
function createMockStreamController<TOOLS extends ToolSet = EmptyToolSet>(): MockStreamController<TOOLS> {
return {
enqueue: vi.fn()
}
}
/**
* Type for chunk data in finish-step events
*/
interface FinishStepChunk {
usage?: Partial<AiSdkUsage>
response?: { id: string; [key: string]: unknown }
providerMetadata?: SharedV3ProviderMetadata
}
describe('StreamEventManager', () => {
let manager: StreamEventManager
beforeEach(() => {
manager = new StreamEventManager()
})
describe('accumulateUsage', () => {
describe('LanguageModelUsage', () => {
it('should accumulate language model usage correctly', () => {
const target: Partial<LanguageModelUsage> = {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30
}
const source: Partial<LanguageModelUsage> = {
inputTokens: 5,
outputTokens: 10,
totalTokens: 15
}
manager.accumulateUsage(target, source)
expect(target.inputTokens).toBe(15)
expect(target.outputTokens).toBe(30)
expect(target.totalTokens).toBe(45)
})
it('should handle undefined values in target', () => {
const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
const source: Partial<LanguageModelUsage> = {
inputTokens: 5,
outputTokens: 10,
totalTokens: 15
}
manager.accumulateUsage(target, source)
expect(target.inputTokens).toBe(15)
expect(target.outputTokens).toBe(10)
expect(target.totalTokens).toBe(15)
})
it('should handle undefined values in source', () => {
const target: Partial<LanguageModelUsage> = {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30
}
const source: Partial<LanguageModelUsage> = { inputTokens: 5 }
manager.accumulateUsage(target, source)
expect(target.inputTokens).toBe(15)
expect(target.outputTokens).toBe(20)
expect(target.totalTokens).toBe(30)
})
it('should handle zero values correctly', () => {
const target: Partial<LanguageModelUsage> = {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0
}
const source: Partial<LanguageModelUsage> = {
inputTokens: 5,
outputTokens: 10,
totalTokens: 15
}
manager.accumulateUsage(target, source)
expect(target.inputTokens).toBe(5)
expect(target.outputTokens).toBe(10)
expect(target.totalTokens).toBe(15)
})
})
describe('ImageModelUsage', () => {
it('should accumulate image model usage correctly', () => {
const target: Partial<ImageModelUsage> = {
inputTokens: 100,
outputTokens: 50,
totalTokens: 150
}
const source: Partial<ImageModelUsage> = {
inputTokens: 50,
outputTokens: 25,
totalTokens: 75
}
manager.accumulateUsage(target, source)
expect(target.inputTokens).toBe(150)
expect(target.outputTokens).toBe(75)
expect(target.totalTokens).toBe(225)
})
it('should handle undefined values', () => {
const target: Partial<ImageModelUsage> = { inputTokens: 100 }
const source: Partial<ImageModelUsage> = {
outputTokens: 50,
totalTokens: 50
}
manager.accumulateUsage(target, source)
expect(target.inputTokens).toBe(100)
expect(target.outputTokens).toBe(50)
expect(target.totalTokens).toBe(50)
})
})
describe('EmbeddingModelUsage', () => {
it('should accumulate embedding model usage correctly', () => {
const target: Partial<EmbeddingModelUsage> = { tokens: 100 }
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
manager.accumulateUsage(target, source)
expect(target.tokens).toBe(150)
})
it('should handle zero to non-zero accumulation', () => {
const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
manager.accumulateUsage(target, source)
expect(target.tokens).toBe(50)
})
it('should handle zero values', () => {
const target: Partial<EmbeddingModelUsage> = { tokens: 0 }
const source: Partial<EmbeddingModelUsage> = { tokens: 100 }
manager.accumulateUsage(target, source)
expect(target.tokens).toBe(100)
})
})
describe('Type Guard Validation', () => {
it('should warn on type mismatch between LanguageModelUsage and EmbeddingModelUsage', () => {
const warnSpy = vi.spyOn(console, 'warn')
const target: Partial<LanguageModelUsage> = { inputTokens: 10 }
const source: Partial<EmbeddingModelUsage> = { tokens: 5 }
manager.accumulateUsage(target, source)
expect(warnSpy).toHaveBeenCalledWith(
expect.stringContaining('Unable to accumulate usage'),
expect.objectContaining({
target,
source
})
)
warnSpy.mockRestore()
})
it('should warn on type mismatch between ImageModelUsage and EmbeddingModelUsage', () => {
const warnSpy = vi.spyOn(console, 'warn')
const target: Partial<ImageModelUsage> = { inputTokens: 100 }
const source: Partial<EmbeddingModelUsage> = { tokens: 50 }
manager.accumulateUsage(target, source)
expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('Unable to accumulate usage'), expect.any(Object))
warnSpy.mockRestore()
})
})
})
describe('buildRecursiveParams', () => {
it('should include textBuffer in assistant message when not empty', () => {
const context = createMockContext()
const textBuffer = 'test response'
const toolResultsText = '<tool_use_result>...</tool_use_result>'
const tools = {
test_tool: createMockTool('test_tool')
}
const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
expect(params.messages).toHaveLength(3)
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
expect(params.messages?.[1]).toEqual({ role: 'assistant', content: textBuffer })
expect(params.messages?.[2]).toEqual({
role: 'user',
content: toolResultsText
})
expect(params.tools).toBe(tools)
})
it('should skip empty textBuffer in messages', () => {
const context = createMockContext()
const textBuffer = ''
const toolResultsText = '<tool_use_result>...</tool_use_result>'
const tools = {}
const params = manager.buildRecursiveParams(context, textBuffer, toolResultsText, tools)
// Should only have original user message and new user message with tool results
expect(params.messages).toHaveLength(2)
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'Test message' })
expect(params.messages?.[1]).toEqual({
role: 'user',
content: toolResultsText
})
const assistantMessages = params.messages?.filter((m) => m.role === 'assistant')
expect(assistantMessages).toHaveLength(0)
})
it('should preserve all original messages', () => {
const context = createMockContext({
originalParams: {
messages: [
{ role: 'user', content: 'First message' },
{ role: 'assistant', content: 'First response' },
{ role: 'user', content: 'Second message' }
]
}
})
const params = manager.buildRecursiveParams(context, 'New response', 'Tool results', {})
expect(params.messages).toHaveLength(5)
expect(params.messages?.[0]).toEqual({ role: 'user', content: 'First message' })
expect(params.messages?.[1]).toEqual({
role: 'assistant',
content: 'First response'
})
expect(params.messages?.[2]).toEqual({ role: 'user', content: 'Second message' })
expect(params.messages?.[3]).toEqual({ role: 'assistant', content: 'New response' })
expect(params.messages?.[4]).toEqual({ role: 'user', content: 'Tool results' })
})
it('should pass through tools parameter', () => {
const context = createMockContext()
const tools = {
tool1: createMockTool('tool1'),
tool2: createMockTool('tool2')
}
const params = manager.buildRecursiveParams(context, 'response', 'results', tools)
expect(params.tools).toBe(tools)
expect(Object.keys(params.tools!)).toHaveLength(2)
})
})
describe('sendStepStartEvent', () => {
it('should enqueue start-step event with correct structure', () => {
const controller = createMockStreamController()
manager.sendStepStartEvent(controller)
expect(controller.enqueue).toHaveBeenCalledWith({
type: 'start-step',
request: {},
warnings: []
})
})
})
describe('sendStepFinishEvent', () => {
it('should enqueue finish-step event with provided finishReason', () => {
const controller = createMockStreamController()
const chunk: FinishStepChunk = {
usage: {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30
},
response: { id: 'test-response' },
providerMetadata: { 'test-provider': {} }
}
const context = createMockContext({
accumulatedUsage: {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0
}
})
manager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
expect(controller.enqueue).toHaveBeenCalledWith({
type: 'finish-step',
finishReason: 'tool-calls',
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
})
})
it('should accumulate usage when provided', () => {
const controller = createMockStreamController()
const chunk: FinishStepChunk = {
usage: {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30
}
}
const context = createMockContext({
accumulatedUsage: {
inputTokens: 5,
outputTokens: 10,
totalTokens: 15
}
})
manager.sendStepFinishEvent(controller, chunk, context)
// Verify accumulation happened
expect(context.accumulatedUsage.inputTokens).toBe(15)
expect(context.accumulatedUsage.outputTokens).toBe(30)
expect(context.accumulatedUsage.totalTokens).toBe(45)
})
it('should handle missing usage gracefully', () => {
const controller = createMockStreamController()
const chunk: FinishStepChunk = {}
const context = createMockContext({
accumulatedUsage: {
inputTokens: 5,
outputTokens: 10,
totalTokens: 15
}
})
expect(() => manager.sendStepFinishEvent(controller, chunk, context)).not.toThrow()
// Verify accumulation did not change
expect(context.accumulatedUsage.inputTokens).toBe(5)
expect(context.accumulatedUsage.outputTokens).toBe(10)
expect(context.accumulatedUsage.totalTokens).toBe(15)
})
it('should use default finishReason of "stop" when not provided', () => {
const controller = createMockStreamController()
const chunk: FinishStepChunk = {}
const context = createMockContext()
manager.sendStepFinishEvent(controller, chunk, context)
expect(controller.enqueue).toHaveBeenCalledWith(
expect.objectContaining({
finishReason: 'stop'
})
)
})
})
describe('handleRecursiveCall', () => {
it('should reset hasExecutedToolsInCurrentStep flag', async () => {
const controller = createMockStreamController()
const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
chunks: [
{
type: 'text-delta',
id: 'test-id',
text: 'test'
} as TextStreamPart<EmptyToolSet>
],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const context = createMockContext({
hasExecutedToolsInCurrentStep: true,
recursiveCall: vi.fn().mockResolvedValue({
fullStream: mockStream
})
})
const params = { messages: [] }
await manager.handleRecursiveCall(controller, params, context)
expect(context.hasExecutedToolsInCurrentStep).toBe(false)
expect(context.recursiveCall).toHaveBeenCalledWith(params)
})
it('should pipe recursive stream to controller', async () => {
const enqueuedChunks: TextStreamPart<EmptyToolSet>[] = []
const controller = createMockStreamController()
controller.enqueue.mockImplementation((chunk: TextStreamPart<EmptyToolSet>) => {
enqueuedChunks.push(chunk)
})
const mockChunks: TextStreamPart<EmptyToolSet>[] = [
{ type: 'start' as const },
{ type: 'start-step' as const, request: {}, warnings: [] },
{ type: 'text-delta' as const, id: 'chunk-1', text: 'recursive' },
{ type: 'text-delta' as const, id: 'chunk-2', text: ' response' },
{
type: 'finish-step' as const,
finishReason: 'stop',
rawFinishReason: 'stop',
response: {
id: 'test-response-id',
timestamp: new Date(),
modelId: 'test-model'
},
usage: {
totalTokens: 0,
inputTokens: 0,
outputTokens: 0,
inputTokenDetails: {
noCacheTokens: 0,
cacheReadTokens: 0,
cacheWriteTokens: 0
},
outputTokenDetails: {
textTokens: 0,
reasoningTokens: 0
}
},
providerMetadata: undefined
},
{
type: 'finish' as const,
finishReason: 'stop',
rawFinishReason: 'stop',
totalUsage: {
totalTokens: 0,
inputTokens: 0,
outputTokens: 0,
inputTokenDetails: {
noCacheTokens: 0,
cacheReadTokens: 0,
cacheWriteTokens: 0
},
outputTokenDetails: {
textTokens: 0,
reasoningTokens: 0
}
}
}
]
const mockStream = simulateReadableStream<TextStreamPart<EmptyToolSet>>({
chunks: mockChunks,
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const context = createMockContext({
hasExecutedToolsInCurrentStep: true,
recursiveCall: vi.fn().mockResolvedValue({
fullStream: mockStream
})
})
await manager.handleRecursiveCall(controller, {}, context)
// Should skip 'start' type and stop at 'finish' type
expect(enqueuedChunks).toHaveLength(4)
expect(enqueuedChunks[0]).toEqual({ type: 'start-step', request: {}, warnings: [] })
expect(enqueuedChunks[1]).toEqual({ type: 'text-delta', id: 'chunk-1', text: 'recursive' })
expect(enqueuedChunks[2]).toEqual({ type: 'text-delta', id: 'chunk-2', text: ' response' })
expect(enqueuedChunks[3]).toMatchObject({
type: 'finish-step',
finishReason: 'stop',
rawFinishReason: 'stop',
providerMetadata: undefined,
usage: {
totalTokens: 0,
inputTokens: 0,
outputTokens: 0,
inputTokenDetails: {
noCacheTokens: 0,
cacheReadTokens: 0,
cacheWriteTokens: 0
},
outputTokenDetails: {
textTokens: 0,
reasoningTokens: 0
}
}
})
})
it('should warn when no fullStream is found', async () => {
const warnSpy = vi.spyOn(console, 'warn')
const controller = createMockStreamController()
const context = createMockContext({
hasExecutedToolsInCurrentStep: true,
recursiveCall: vi.fn().mockResolvedValue({
// No fullStream property
someOtherProperty: 'value'
})
})
await manager.handleRecursiveCall(controller, {}, context)
expect(warnSpy).toHaveBeenCalledWith(
expect.stringContaining('[MCP Prompt] No fullstream found'),
expect.any(Object)
)
warnSpy.mockRestore()
})
})
})

View File

@@ -1,559 +0,0 @@
import { createMockContext, createMockStreamParams, createMockTool, createMockToolSet } from '@test-utils'
import type { TextStreamPart, ToolSet } from 'ai'
import { simulateReadableStream } from 'ai'
import { convertReadableStreamToArray } from 'ai/test'
import { describe, expect, it, vi } from 'vitest'
import { createPromptToolUsePlugin, DEFAULT_SYSTEM_PROMPT } from '../promptToolUsePlugin'
describe('promptToolUsePlugin', () => {
describe('Factory Function', () => {
it('should return AiPlugin with correct name', () => {
const plugin = createPromptToolUsePlugin()
expect(plugin.name).toBe('built-in:prompt-tool-use')
expect(plugin.transformParams).toBeDefined()
expect(plugin.transformStream).toBeDefined()
})
it('should accept empty configuration', () => {
const plugin = createPromptToolUsePlugin({})
expect(plugin).toBeDefined()
expect(plugin.name).toBe('built-in:prompt-tool-use')
})
it('should accept custom buildSystemPrompt', () => {
const customBuildSystemPrompt = vi.fn((userSystemPrompt: string) => userSystemPrompt)
const plugin = createPromptToolUsePlugin({
buildSystemPrompt: customBuildSystemPrompt
})
expect(plugin).toBeDefined()
})
it('should accept custom parseToolUse', () => {
const customParseToolUse = vi.fn(() => ({ results: [], content: '' }))
const plugin = createPromptToolUsePlugin({
parseToolUse: customParseToolUse
})
expect(plugin).toBeDefined()
})
it('should accept enabled flag', () => {
const pluginDisabled = createPromptToolUsePlugin({ enabled: false })
const pluginEnabled = createPromptToolUsePlugin({ enabled: true })
expect(pluginDisabled).toBeDefined()
expect(pluginEnabled).toBeDefined()
})
})
describe('transformParams', () => {
it('should separate provider and prompt tools', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
tools: createMockToolSet({
provider_tool: 'provider',
prompt_tool: 'function'
})
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
// Provider tools should remain in tools
expect(result.tools).toBeDefined()
expect(result.tools).toHaveProperty('provider_tool')
expect(result.tools).not.toHaveProperty('prompt_tool')
// Prompt tools should be moved to context.mcpTools
expect(context.mcpTools).toBeDefined()
expect(context.mcpTools).toHaveProperty('prompt_tool')
expect(context.mcpTools).not.toHaveProperty('provider_tool')
})
it('should handle only provider tools', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
tools: createMockToolSet({
provider_tool1: 'provider',
provider_tool2: 'provider'
})
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result.tools).toEqual(params.tools)
expect(context.mcpTools).toBeUndefined()
})
it('should handle only prompt tools', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
tools: createMockToolSet({
prompt_tool1: 'function',
prompt_tool2: 'function'
})
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result.tools).toBeUndefined()
expect(context.mcpTools).toEqual(params.tools)
})
it('should build system prompt for prompt tools', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
system: 'Original system prompt',
tools: {
test_tool: createMockTool('test_tool', 'Test tool description')
}
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result.system).toBeDefined()
expect(typeof result.system).toBe('string')
expect(result.system).toContain('In this environment you have access to a set of tools')
expect(result.system).toContain('test_tool')
expect(result.system).toContain('Test tool description')
expect(result.system).toContain('Original system prompt')
})
it('should handle empty user system prompt', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
tools: {
test_tool: createMockTool('test_tool')
}
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result.system).toBeDefined()
expect(result.system).toContain('In this environment you have access to a set of tools')
})
it('should skip system prompt when disabled', async () => {
const plugin = createPromptToolUsePlugin({ enabled: false })
const context = createMockContext()
const params = createMockStreamParams({
system: 'Original',
tools: {
test_tool: createMockTool('test_tool')
}
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result).toEqual(params)
expect(context.mcpTools).toBeUndefined()
})
it('should skip when no tools provided', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
system: 'Original'
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result).toEqual(params)
})
it('should skip when tools is not an object', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
system: 'Original',
tools: 'invalid' as any
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(result).toEqual(params)
})
it('should use custom buildSystemPrompt when provided', async () => {
const customBuildSystemPrompt = vi.fn((userSystemPrompt: string, tools: ToolSet) => {
return `Custom prompt with ${Object.keys(tools).length} tools and user prompt: ${userSystemPrompt}`
})
const plugin = createPromptToolUsePlugin({
buildSystemPrompt: customBuildSystemPrompt
})
const context = createMockContext()
const params = createMockStreamParams({
system: 'User prompt',
tools: {
tool1: createMockTool('tool1')
}
})
const result = await Promise.resolve(plugin.transformParams!(params, context))
expect(customBuildSystemPrompt).toHaveBeenCalled()
expect(result.system).toBe('Custom prompt with 1 tools and user prompt: User prompt')
})
it('should save originalParams to context', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
system: 'Original',
tools: {
test: createMockTool('test')
}
})
await Promise.resolve(plugin.transformParams!(params, context))
expect(context.originalParams).toBeDefined()
expect(context.originalParams.system).toBeDefined()
})
it('should NOT rebuild system prompt on recursive call', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
const params = createMockStreamParams({
system: 'User system prompt',
tools: {
test_tool: createMockTool('test_tool', 'A test tool')
}
})
// First call: build the system prompt with tools
const firstResult = await Promise.resolve(plugin.transformParams!(params, context))
const firstSystemPrompt = firstResult.system as string
// Verify first call includes tool definitions
expect(firstSystemPrompt).toContain('test_tool')
// Simulate recursive call: isRecursiveCall is true
context.isRecursiveCall = true
const recursiveParams = createMockStreamParams({
system: firstSystemPrompt,
tools: {
test_tool: createMockTool('test_tool', 'A test tool')
}
})
const recursiveResult = await Promise.resolve(plugin.transformParams!(recursiveParams, context))
// System prompt should NOT be rebuilt - it should remain the same
expect(recursiveResult.system).toBe(firstSystemPrompt)
// Count occurrences of tool definition to ensure no duplication
const toolOccurrences = (recursiveResult.system as string).split('test_tool').length - 1
const firstToolOccurrences = firstSystemPrompt.split('test_tool').length - 1
expect(toolOccurrences).toBe(firstToolOccurrences)
})
})
describe('transformStream', () => {
it('should return identity transform when disabled', async () => {
const plugin = createPromptToolUsePlugin({ enabled: false })
const context = createMockContext()
const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
{ type: 'text-delta', text: 'Hello' },
{ type: 'text-delta', text: ' World' }
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
expect(result).toEqual(inputChunks)
})
it('should return identity transform when no mcpTools in context', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
// Don't set context.mcpTools
const inputChunks: Array<{ type: 'text-delta'; text: string }> = [
{ type: 'text-delta', text: 'Hello' },
{ type: 'text-delta', text: ' World' }
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
expect(result).toEqual(inputChunks)
})
it('should initialize accumulatedUsage in context', () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
plugin.transformStream!(createMockStreamParams(), context)()
expect(context.accumulatedUsage).toBeDefined()
expect(context.accumulatedUsage).toEqual({
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
reasoningTokens: 0,
cachedInputTokens: 0
})
})
it('should filter tool tags from text-delta chunks', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
const inputChunks = [
{ type: 'text-start' as const },
{ type: 'text-delta' as const, text: 'Before ' },
{ type: 'text-delta' as const, text: '<tool_use>' },
{ type: 'text-delta' as const, text: '<name>test</name>' },
{ type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
{ type: 'text-delta' as const, text: '</tool_use>' },
{ type: 'text-delta' as const, text: ' After' },
{ type: 'text-end' as const }
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
// Extract text from text-delta chunks
const textChunks = result.filter((chunk) => chunk.type === 'text-delta')
const fullText = textChunks.map((chunk) => 'text' in chunk && chunk.text).join('')
// Tool tags should be filtered out
expect(fullText).not.toContain('<tool_use>')
expect(fullText).not.toContain('</tool_use>')
expect(fullText).toContain('Before')
expect(fullText).toContain('After')
})
it('should hold text-start until non-tag content appears', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
// Only tool tags, no actual content
const inputChunks = [
{ type: 'text-start' as const },
{ type: 'text-delta' as const, text: '<tool_use>' },
{ type: 'text-delta' as const, text: '<name>test</name>' },
{ type: 'text-delta' as const, text: '<arguments>{}</arguments>' },
{ type: 'text-delta' as const, text: '</tool_use>' },
{ type: 'text-end' as const }
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
// Should not have text-start or text-end since all content was tool tags
expect(result.some((chunk) => chunk.type === 'text-start')).toBe(false)
expect(result.some((chunk) => chunk.type === 'text-end')).toBe(false)
})
it('should send text-start when non-tag content appears', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
const inputChunks = [
{ type: 'text-start' as const },
{ type: 'text-delta' as const, text: 'Actual content' },
{ type: 'text-end' as const }
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
// Should have text-start, text-delta, and text-end
expect(result.some((chunk) => chunk.type === 'text-start')).toBe(true)
expect(result.some((chunk) => chunk.type === 'text-delta')).toBe(true)
expect(result.some((chunk) => chunk.type === 'text-end')).toBe(true)
})
it('should pass through non-text events', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
const stepStartEvent = { type: 'start-step' as const, request: {}, warnings: [] }
const inputChunks = [stepStartEvent]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
expect(result[0]).toEqual(stepStartEvent)
})
it('should accumulate usage from finish-step events', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
const inputChunks = [
{
type: 'finish-step' as const,
finishReason: 'stop' as const,
usage: {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30
}
}
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
await convertReadableStreamToArray(inputStream.pipeThrough(transform))
// Verify usage was accumulated
expect(context.accumulatedUsage).toBeDefined()
expect(context.accumulatedUsage!.inputTokens).toBe(10)
expect(context.accumulatedUsage!.outputTokens).toBe(20)
expect(context.accumulatedUsage!.totalTokens).toBe(30)
})
it('should include accumulated usage in finish event', async () => {
const plugin = createPromptToolUsePlugin()
const context = createMockContext()
context.mcpTools = {
test: createMockTool('test')
}
// Pre-populate accumulated usage
context.accumulatedUsage = {
inputTokens: 5,
outputTokens: 10,
totalTokens: 15,
reasoningTokens: 0,
cachedInputTokens: 0
}
const inputChunks = [
{
type: 'finish' as const,
finishReason: 'stop' as const,
usage: {
inputTokens: 100,
outputTokens: 200,
totalTokens: 300
}
}
]
const inputStream = simulateReadableStream<TextStreamPart<ToolSet>>({
chunks: inputChunks as unknown as TextStreamPart<ToolSet>[],
initialDelayInMs: 0,
chunkDelayInMs: 0
})
const transform = plugin.transformStream!(createMockStreamParams(), context)()
const result = await convertReadableStreamToArray(inputStream.pipeThrough(transform))
const finishEvent = result.find((chunk) => chunk.type === 'finish')
expect(finishEvent).toBeDefined()
if (finishEvent && 'totalUsage' in finishEvent) {
expect(finishEvent.totalUsage).toEqual(context.accumulatedUsage)
}
})
})
describe('Type Safety', () => {
it('should have correct generic parameters for StreamTextParams and StreamTextResult', () => {
const plugin = createPromptToolUsePlugin()
// Type assertion to verify the plugin has the correct type
type PluginType = typeof plugin
const typeTest: PluginType = plugin
expect(typeTest.name).toBe('built-in:prompt-tool-use')
})
})
describe('DEFAULT_SYSTEM_PROMPT', () => {
it('should contain required sections', () => {
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Formatting')
expect(DEFAULT_SYSTEM_PROMPT).toContain('Tool Use Rules')
expect(DEFAULT_SYSTEM_PROMPT).toContain('Response rules')
})
it('should have placeholders for dynamic content', () => {
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ TOOLS_INFO }}')
expect(DEFAULT_SYSTEM_PROMPT).toContain('{{ USER_SYSTEM_PROMPT }}')
})
it('should contain XML tag examples', () => {
expect(DEFAULT_SYSTEM_PROMPT).toContain('<tool_use>')
expect(DEFAULT_SYSTEM_PROMPT).toContain('</tool_use>')
expect(DEFAULT_SYSTEM_PROMPT).toContain('<name>')
expect(DEFAULT_SYSTEM_PROMPT).toContain('<arguments>')
})
})
})

View File

@@ -1,487 +0,0 @@
/**
* 内置插件MCP Prompt 模式
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
* 内置默认逻辑,支持自定义覆盖
*/
import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index'
import type { AiPlugin, StreamTextParams, StreamTextResult } from '../../types'
import { StreamEventManager } from './StreamEventManager'
import { type TagConfig, TagExtractor } from './tagExtraction'
import { ToolExecutor } from './ToolExecutor'
import type { PromptToolUseConfig, ToolUseResult } from './type'
/**
* 工具使用标签配置
*/
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
export const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
You can use one or more tools per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
## Tool Use Formatting
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
<tool_use>
<name>{tool_name}</name>
<arguments>{json_arguments}</arguments>
</tool_use>
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. IMPORTANT: When writing JSON inside the <arguments> tag, any double quotes inside string values must be escaped with a backslash ("). For example:
<tool_use>
<name>search</name>
<arguments>{ "query": "browser,fetch" }</arguments>
</tool_use>
<tool_use>
<name>exec</name>
<arguments>{ "code": "const page = await CherryBrowser_fetch({ url: \\"https://example.com\\" })\nreturn page" }</arguments>
</tool_use>
The user will respond with the result of the tool use, which should be formatted as follows:
<tool_use_result>
<name>{tool_name}</name>
<result>{result}</result>
</tool_use_result>
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
For example, if the result of the tool use is an image file, you can use it in the next action like this:
<tool_use>
<name>image_transformer</name>
<arguments>{"image": "image_1.jpg"}</arguments>
</tool_use>
Always adhere to this format for the tool use to ensure proper parsing and execution.
## Tool Use Rules
Here are the rules you should always follow to solve your task:
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
3. If no tool call is needed, just answer the question directly.
4. Never re-do a tool call that you previously did with the exact same parameters.
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
{{ TOOLS_INFO }}
## Response rules
Respond in the language of the user's query, unless the user instructions specify additional requirements for the language to be used.
# User Instructions
{{ USER_SYSTEM_PROMPT }}
`
/**
* 默认工具使用示例(提取自 Cherry Studio
*/
const DEFAULT_TOOL_USE_EXAMPLES = `
Here are a few examples using notional tools:
---
User: Generate an image of the oldest person in this document.
A: I can use the document_qa tool to find out who the oldest person is in the document.
<tool_use>
<name>document_qa</name>
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
</tool_use>
User: <tool_use_result>
<name>document_qa</name>
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
</tool_use_result>
A: I can use the image_generator tool to create a portrait of John Doe.
<tool_use>
<name>image_generator</name>
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
</tool_use>
User: <tool_use_result>
<name>image_generator</name>
<result>image.png</result>
</tool_use_result>
A: the image is generated as image.png
---
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
A: I can use the python_interpreter tool to calculate the result of the operation.
<tool_use>
<name>python_interpreter</name>
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
</tool_use>
User: <tool_use_result>
<name>python_interpreter</name>
<result>1302.678</result>
</tool_use_result>
A: The result of the operation is 1302.678.
---
User: "Which city has the highest population , Guangzhou or Shanghai?"
A: I can use the search tool to find the population of Guangzhou.
<tool_use>
<name>search</name>
<arguments>{"query": "Population Guangzhou"}</arguments>
</tool_use>
User: <tool_use_result>
<name>search</name>
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
</tool_use_result>
A: I can use the search tool to find the population of Shanghai.
<tool_use>
<name>search</name>
<arguments>{"query": "Population Shanghai"}</arguments>
</tool_use>
User: <tool_use_result>
<name>search</name>
<result>26 million (2019)</result>
</tool_use_result>
A: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
/**
* 构建可用工具部分(提取自 Cherry Studio
*/
function buildAvailableTools(tools: ToolSet): string | null {
const availableTools = Object.keys(tools)
if (availableTools.length === 0) return null
const result = availableTools
.map((toolName: string) => {
const tool = tools[toolName]
return `
<tool>
<name>${toolName}</name>
<description>${tool.description || ''}</description>
<arguments>
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
</arguments>
</tool>
`
})
.join('\n')
return `<tools>
${result}
</tools>`
}
/**
* 默认的系统提示符构建函数(提取自 Cherry Studio
*/
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet, mcpMode?: string): string {
const availableTools = buildAvailableTools(tools)
if (availableTools === null) return userSystemPrompt
if (mcpMode == 'auto') {
return DEFAULT_SYSTEM_PROMPT.replace('{{ TOOLS_INFO }}', '').replace(
'{{ USER_SYSTEM_PROMPT }}',
userSystemPrompt || ''
)
}
const toolsInfo = `
## Tool Use Examples
{{ TOOL_USE_EXAMPLES }}
## Tool Use Available Tools
Above example were using notional tools that might not exist for you. You only have access to these tools:
{{ AVAILABLE_TOOLS }}`
.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOLS_INFO }}', toolsInfo).replace(
'{{ USER_SYSTEM_PROMPT }}',
userSystemPrompt || ''
)
return fullPrompt
}
/**
* 默认工具解析函数(提取自 Cherry Studio
* 解析 XML 格式的工具调用
*/
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
if (!content || !tools || Object.keys(tools).length === 0) {
return { results: [], content: content }
}
// 支持两种格式:
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
// 2. 只有内部内容(从 TagExtractor 提取出来的)
let contentToProcess = content
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
if (!content.includes('<tool_use>')) {
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
}
const toolUsePattern =
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
const results: ToolUseResult[] = []
let match
let idx = 0
// Find all tool use blocks
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
const fullMatch = match[0]
let toolName = match[2].trim()
switch (toolName.toLowerCase()) {
case 'search':
toolName = 'mcp__CherryHub__search'
break
case 'exec':
toolName = 'mcp__CherryHub__exec'
break
default:
break
}
const toolArgs = match[4].trim()
// Try to parse the arguments as JSON
let parsedArgs
try {
parsedArgs = JSON.parse(toolArgs)
} catch (error) {
// If parsing fails, use the string as is
parsedArgs = toolArgs
}
// Find the corresponding tool
const tool = tools[toolName]
if (!tool) {
console.warn(`Tool "${toolName}" not found in available tools`)
continue
}
// Add to results array
results.push({
id: `${toolName}-${idx++}`, // Unique ID for each tool use
toolName: toolName,
arguments: parsedArgs,
status: 'pending'
})
contentToProcess = contentToProcess.replace(fullMatch, '')
}
return { results, content: contentToProcess }
}
export const createPromptToolUsePlugin = (
config: PromptToolUseConfig = {}
): AiPlugin<StreamTextParams, StreamTextResult> => {
const {
enabled = true,
buildSystemPrompt = defaultBuildSystemPrompt,
parseToolUse = defaultParseToolUse,
mcpMode
} = config
return definePlugin<StreamTextParams, StreamTextResult>({
name: 'built-in:prompt-tool-use',
transformParams: (params, context) => {
if (!enabled || !params.tools || typeof params.tools !== 'object') {
return params
}
// 分离 provider 和其他类型的工具
const providerDefinedTools: ToolSet = {}
const promptTools: ToolSet = {}
for (const [toolName, tool] of Object.entries(params.tools)) {
if (tool.type === 'provider') {
// provider 类型的工具保留在 tools 参数中
providerDefinedTools[toolName] = tool
} else {
// 其他工具转换为 prompt 模式
promptTools[toolName] = tool
}
}
// 只有当有非 provider 工具时才保存到 context
if (Object.keys(promptTools).length > 0) {
context.mcpTools = promptTools
}
// 递归调用时,不重新构建 system prompt避免重复追加工具定义
if (context.isRecursiveCall) {
const transformedParams = {
...params,
tools: Object.keys(providerDefinedTools).length > 0 ? providerDefinedTools : undefined
}
context.originalParams = transformedParams
return transformedParams
}
// 构建系统提示符(只包含非 provider 工具)
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = buildSystemPrompt(userSystemPrompt, promptTools, mcpMode)
// 保留 provide tools移除其他 tools
const transformedParams = {
...params,
...(systemPrompt ? { system: systemPrompt } : {}),
tools: Object.keys(providerDefinedTools).length > 0 ? providerDefinedTools : undefined
}
context.originalParams = transformedParams
return transformedParams
},
transformStream: (_, context) => () => {
let textBuffer = ''
// let stepId = ''
// 如果没有需要 prompt 模式处理的工具,直接返回原始流
if (!context.mcpTools) {
return new TransformStream()
}
// 初始化 usage 累加器和工具执行状态
if (!context.accumulatedUsage) {
context.accumulatedUsage = {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
reasoningTokens: 0,
cachedInputTokens: 0
}
}
if (context.hasExecutedToolsInCurrentStep === undefined) {
context.hasExecutedToolsInCurrentStep = false
}
// 创建工具执行器、流事件管理器和标签提取器
const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager()
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
// 用于hold text-start事件直到确认有非工具标签内容
let pendingTextStart: TextStreamPart<TOOLS> | null = null
let hasStartedText = false
type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
async transform(
chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) {
// Hold住text-start事件直到确认有非工具标签内容
if ((chunk as any).type === 'text-start') {
pendingTextStart = chunk
return
}
// text-delta阶段收集文本内容并过滤工具标签
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
// stepId = chunk.id || ''
// 使用TagExtractor过滤工具标签只传递非标签内容到UI层
const extractionResults = tagExtractor.processText(chunk.text || '')
for (const result of extractionResults) {
// 只传递非标签内容到UI层
if (!result.isTagContent && result.content) {
// 如果还没有发送text-start且有pending的text-start先发送它
if (!hasStartedText && pendingTextStart) {
controller.enqueue(pendingTextStart)
hasStartedText = true
pendingTextStart = null
}
const filteredChunk = {
...chunk,
text: result.content
}
controller.enqueue(filteredChunk)
}
}
return
}
if (chunk.type === 'text-end') {
// 只有当已经发送了text-start时才发送text-end
if (hasStartedText) {
controller.enqueue(chunk)
}
return
}
if (chunk.type === 'finish-step') {
// 统一在finish-step阶段检查并执行工具调用
const tools = context.mcpTools
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
// 解析完整的textBuffer来检测工具调用
const { results: parsedTools } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
if (validToolUses.length > 0) {
context.hasExecutedToolsInCurrentStep = true
// 执行工具调用(不需要手动发送 start-step外部流已经处理
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// 发送步骤完成事件,使用 tool-calls 作为 finishReason
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
// 处理递归调用
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
return
}
}
// 如果没有执行工具调用,累加 usage 后透传 finish-step 事件
if (chunk.usage && context.accumulatedUsage) {
streamEventManager.accumulateUsage(context.accumulatedUsage, chunk.usage)
}
controller.enqueue(chunk)
// 清理状态
textBuffer = ''
return
}
// 处理 finish 类型,使用累加后的 totalUsage
if (chunk.type === 'finish') {
controller.enqueue({
...chunk,
totalUsage: context.accumulatedUsage
})
return
}
// 对于其他类型的事件直接传递不包括text-start已在上面处理
if ((chunk as any).type !== 'text-start') {
controller.enqueue(chunk)
}
},
flush() {
// 清理pending状态
pendingTextStart = null
hasStartedText = false
}
})
}
})
}

View File

@@ -1,196 +0,0 @@
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
/**
* Returns the index of the start of the searchedText in the text, or null if it
* is not found.
*/
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
// Return null immediately if searchedText is empty.
if (searchedText.length === 0) {
return null
}
// Check if the searchedText exists as a direct substring of text.
const directIndex = text.indexOf(searchedText)
if (directIndex !== -1) {
return directIndex
}
// Otherwise, look for the largest suffix of "text" that matches
// a prefix of "searchedText". We go from the end of text inward.
for (let i = text.length - 1; i >= 0; i--) {
const suffix = text.substring(i)
if (searchedText.startsWith(suffix)) {
return i
}
}
return null
}
export interface TagConfig {
openingTag: string
closingTag: string
separator?: string
}
export interface TagExtractionState {
textBuffer: string
isInsideTag: boolean
isFirstTag: boolean
isFirstText: boolean
afterSwitch: boolean
accumulatedTagContent: string
hasTagContent: boolean
}
export interface TagExtractionResult {
content: string
isTagContent: boolean
complete: boolean
tagContentExtracted?: string
}
/**
* 通用标签提取处理器
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
*/
export class TagExtractor {
private config: TagConfig
private state: TagExtractionState
constructor(config: TagConfig) {
this.config = config
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
/**
* 处理文本块,返回处理结果
*/
processText(newText: string): TagExtractionResult[] {
this.state.textBuffer += newText
const results: TagExtractionResult[] = []
// 处理标签提取逻辑
while (true) {
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
if (startIndex == null) {
const content = this.state.textBuffer
if (content.length > 0) {
results.push({
content: this.addPrefix(content),
isTagContent: this.state.isInsideTag,
complete: false
})
if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(content)
this.state.hasTagContent = true
}
}
this.state.textBuffer = ''
break
}
// 处理标签前的内容
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
if (contentBeforeTag.length > 0) {
results.push({
content: this.addPrefix(contentBeforeTag),
isTagContent: this.state.isInsideTag,
complete: false
})
if (this.state.isInsideTag) {
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
this.state.hasTagContent = true
}
}
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
if (foundFullMatch) {
// 如果找到完整的标签
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
if (this.state.isInsideTag && this.state.hasTagContent) {
results.push({
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
})
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
}
this.state.isInsideTag = !this.state.isInsideTag
this.state.afterSwitch = true
if (this.state.isInsideTag) {
this.state.isFirstTag = false
} else {
this.state.isFirstText = false
}
} else {
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
break
}
}
return results
}
/**
* 完成处理,返回任何剩余的标签内容
*/
finalize(): TagExtractionResult | null {
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
const result = {
content: '',
isTagContent: false,
complete: true,
tagContentExtracted: this.state.accumulatedTagContent
}
this.state.accumulatedTagContent = ''
this.state.hasTagContent = false
return result
}
return null
}
private addPrefix(text: string): string {
const needsPrefix =
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
this.state.afterSwitch = false
return prefix + text
}
/**
* 重置状态
*/
reset(): void {
this.state = {
textBuffer: '',
isInsideTag: false,
isFirstTag: true,
isFirstText: true,
afterSwitch: false,
accumulatedTagContent: '',
hasTagContent: false
}
}
}

View File

@@ -1,33 +0,0 @@
import type { ToolSet } from 'ai'
import type { AiRequestContext } from '../..'
/**
* 解析结果类型
* 表示从AI响应中解析出的工具使用意图
*/
export interface ToolUseResult {
id: string
toolName: string
arguments: any
status: 'pending' | 'invoking' | 'done' | 'error'
}
export interface BaseToolUsePluginConfig {
enabled?: boolean
}
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
// 自定义系统提示符构建函数(可选,有默认实现)
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
// 自定义工具解析函数(可选,有默认实现)
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
mcpMode?: string
}
/**
* 扩展的 AI 请求上下文,支持 MCP 工具存储
*/
export interface ToolUseRequestContext extends AiRequestContext {
mcpTools: ToolSet
}

View File

@@ -21,7 +21,6 @@ export interface AiRequestMetadata {
enableReasoning?: boolean
enableWebSearch?: boolean
enableGenerateImage?: boolean
isPromptToolUse?: boolean
isSupportedToolUse?: boolean
// 自定义元数据,使用 JSONValue 确保类型安全
custom?: JSONObject

View File

@@ -1,5 +1,5 @@
import type { ProviderV3 } from '@ai-sdk/provider'
import { LRUCache } from 'lru-cache'
import QuickLRU from 'quick-lru'
import { deepMergeObjects } from '../../utils'
import type { ProviderVariant, ToolFactoryMap } from '../types'
@@ -116,7 +116,7 @@ export class ProviderExtension<
>
> {
/** Provider 实例缓存 - 按 settings hash 存储LRU 自动清理 */
private instances: LRUCache<string, TProvider>
private instances: QuickLRU<string, TProvider>
/** In-flight promise map - 防止并发创建相同 settings 的 provider */
private pendingCreations: Map<string, Promise<TProvider>> = new Map()
@@ -126,9 +126,8 @@ export class ProviderExtension<
throw new Error('ProviderExtension: name is required')
}
this.instances = new LRUCache<string, TProvider>({
max: 10,
updateAgeOnGet: true
this.instances = new QuickLRU<string, TProvider>({
maxSize: 10
})
}
@@ -161,10 +160,13 @@ export class ProviderExtension<
return 'default'
}
const seen = new WeakSet()
const stableStringify = (obj: any): string => {
if (obj === null || obj === undefined) return 'null'
if (typeof obj === 'function') return '"[function]"'
if (typeof obj !== 'object') return JSON.stringify(obj)
if (seen.has(obj)) return '"[circular]"'
seen.add(obj)
if (Array.isArray(obj)) return `[${obj.map(stableStringify).join(',')}]`
const keys = Object.keys(obj).sort()
@@ -324,10 +326,10 @@ export class ProviderExtension<
* 获取已缓存的 provider 实例(如果存在)
*/
getCachedProvider(): TProvider | undefined {
for (const [key, value] of this.instances.entries()) {
for (const [key, value] of this.instances) {
if (!key.includes(':')) return value
}
for (const [, value] of this.instances.entries()) {
for (const [, value] of this.instances) {
return value
}
return undefined

View File

@@ -2,7 +2,7 @@
* Runtime 层类型定义
*/
import type { EmbeddingModelV3, ImageModelV3, ProviderV3 } from '@ai-sdk/provider'
import type { embedMany, generateImage, generateText, streamText } from 'ai'
import type { embedMany, Experimental_DownloadFunction, generateImage, generateText, streamText } from 'ai'
import { type AiPlugin } from '../plugins'
import type { CoreProviderSettingsMap, StringKeys } from '../providers/types'
@@ -31,6 +31,7 @@ export interface RuntimeConfig<
export type generateImageParams = Omit<Parameters<typeof generateImage>[0], 'model'> & {
model: string | ImageModelV3
experimental_download?: Experimental_DownloadFunction
}
export type generateImageResult = Awaited<ReturnType<typeof generateImage>>
export type generateTextParams = Parameters<typeof generateText>[0]

View File

@@ -1,7 +1,6 @@
import { W3CTraceContextPropagator } from '@opentelemetry/core'
import { OTLPTraceExporter } from '@opentelemetry/exporter-trace-otlp-http'
import type { SpanProcessor } from '@opentelemetry/sdk-trace-base'
import { BatchSpanProcessor, ConsoleSpanExporter } from '@opentelemetry/sdk-trace-base'
import { ConsoleSpanExporter, SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base'
import { WebTracerProvider } from '@opentelemetry/sdk-trace-web'
import type { TraceConfig } from '../trace-core/types/config'
@@ -21,7 +20,10 @@ export class WebTracer {
defaultConfig.headers = config.headers || defaultConfig.headers
defaultConfig.defaultTracerName = config.defaultTracerName || defaultConfig.defaultTracerName
}
this.processor = spanProcessor || new BatchSpanProcessor(this.getExporter())
// Callers are expected to pass a processor. The dev-only fallback logs
// spans to the console so that a misconfigured caller doesn't silently
// lose data when callers forget to inject a processor.
this.processor = spanProcessor || new SimpleSpanProcessor(new ConsoleSpanExporter())
this.provider = new WebTracerProvider({
spanProcessors: [this.processor]
})
@@ -30,16 +32,6 @@ export class WebTracer {
contextManager: contextManager
})
}
private static getExporter() {
if (defaultConfig.endpoint) {
return new OTLPTraceExporter({
url: `${defaultConfig.endpoint}/v1/traces`,
headers: defaultConfig.headers
})
}
return new ConsoleSpanExporter()
}
}
export const startContext = contextManager.startContextForTopic.bind(contextManager)

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,8 @@
import { describe, expect, it } from 'vitest'
import { buildRuntimeEndpointConfigs, lookupRegistryModel } from '../registry-utils'
import { buildRuntimeEndpointConfigs, inferAdapterFamily, lookupRegistryModel } from '../registry-utils'
import { ENDPOINT_TYPE } from '../schemas/enums'
import type { ModelConfig } from '../schemas/model'
import type { RegistryEndpointConfig } from '../schemas/provider'
import type { ProviderModelOverride } from '../schemas/provider-models'
@@ -187,4 +188,58 @@ describe('buildRuntimeEndpointConfigs', () => {
} as Record<string, RegistryEndpointConfig>)
expect(result).toBeNull()
})
it('copies adapterFamily through to runtime config', () => {
const result = buildRuntimeEndpointConfigs({
'openai-chat-completions': { baseUrl: 'https://x', adapterFamily: 'openai-compatible' },
'anthropic-messages': { baseUrl: 'https://y', adapterFamily: 'anthropic' }
} as Record<string, RegistryEndpointConfig>)
expect(result!['openai-chat-completions'].adapterFamily).toBe('openai-compatible')
expect(result!['anthropic-messages'].adapterFamily).toBe('anthropic')
})
it('adapterFamily alone is enough to retain an endpoint config', () => {
const result = buildRuntimeEndpointConfigs({
'openai-chat-completions': { adapterFamily: 'openai-compatible' }
} as Record<string, RegistryEndpointConfig>)
expect(result!['openai-chat-completions'].adapterFamily).toBe('openai-compatible')
expect(result!['openai-chat-completions'].baseUrl).toBeUndefined()
})
})
// ─────────────────────────────────────────────────────────────────────────────
// inferAdapterFamily
// ─────────────────────────────────────────────────────────────────────────────
describe('inferAdapterFamily', () => {
it('catalog adapterFamily wins over endpoint default', () => {
expect(inferAdapterFamily(ENDPOINT_TYPE.ANTHROPIC_MESSAGES, { adapterFamily: 'aihubmix' })).toBe('aihubmix')
})
it('falls back to endpoint default when catalog has no adapterFamily', () => {
expect(inferAdapterFamily(ENDPOINT_TYPE.ANTHROPIC_MESSAGES, {})).toBe('anthropic')
})
it('falls back to endpoint default when catalog is absent', () => {
expect(inferAdapterFamily(ENDPOINT_TYPE.ANTHROPIC_MESSAGES)).toBe('anthropic')
expect(inferAdapterFamily(ENDPOINT_TYPE.GOOGLE_GENERATE_CONTENT)).toBe('google')
expect(inferAdapterFamily(ENDPOINT_TYPE.OLLAMA_CHAT)).toBe('ollama')
expect(inferAdapterFamily(ENDPOINT_TYPE.OLLAMA_GENERATE)).toBe('ollama')
expect(inferAdapterFamily(ENDPOINT_TYPE.JINA_RERANK)).toBe('jina-rerank')
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_RESPONSES)).toBe('openai')
})
it('falls back to openai-compatible for endpoints with no specific default', () => {
// openai-chat-completions is intentionally generic — many vendors speak it
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_CHAT_COMPLETIONS)).toBe('openai-compatible')
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_IMAGE_GENERATION)).toBe('openai-compatible')
})
it('accepts both RegistryEndpointConfig and RuntimeEndpointConfig shapes', () => {
// Both schemas have adapterFamily — the function only needs to peek that
// one field so the input type is structural.
expect(inferAdapterFamily(ENDPOINT_TYPE.OPENAI_CHAT_COMPLETIONS, { adapterFamily: 'groq' })).toBe('groq')
})
})

View File

@@ -63,7 +63,15 @@ export { normalizeModelId } from './utils/normalize'
// Pure lookup and transformation utilities (no fs dependency)
export type { ModelLookupResult, RuntimeEndpointConfig } from './registry-utils'
export { buildRuntimeEndpointConfigs, lookupRegistryModel, lookupRegistryProvider } from './registry-utils'
export {
buildRuntimeEndpointConfigs,
inferAdapterFamily,
lookupRegistryModel,
lookupRegistryProvider
} from './registry-utils'
// Shared vendor identity regex, used by shared model helpers.
export { VENDOR_PATTERNS } from './patterns/vendor-patterns'
// Shared vendor identity regex — consumed by @shared capability inference
// and @cherrystudio/ui icon routing. Single source of truth for "which
// vendor does this raw model ID belong to".
export type { VendorKey } from './patterns/vendor-patterns'
export { isVendor, matchVendor, VENDOR_PATTERNS } from './patterns/vendor-patterns'

View File

@@ -2,13 +2,17 @@
* Vendor identity regex patterns — the single source of truth for
* "which vendor does this raw model ID belong to".
*
* Shared by:
* Shared across three call sites:
* - `@shared/utils/model` — vendor check functions (`isAnthropicModel`
* etc.) and capability inference (e.g. deciding which IDs to mark
* `REASONING` in the schema).
* - `@cherrystudio/ui` icon registry — vendor-level icon routing for
* models whose ID doesn't have a dedicated SKU icon.
* - Future callers doing vendor dispatch.
*
* Keeping these regex in the registry layer lets model capability
* inference use provider-owned vendor taxonomy instead of renderer config.
* Keeping these regex in the registry layer means both capability
* inference and icon lookup stay in lockstep when a new vendor /
* naming convention lands.
*
* Scope: **vendor identity only**. SKU-level patterns (`gpt-5.1-codex-mini`,
* `claude-sonnet-4-6`, etc.) stay in their specific consumer modules —
@@ -80,3 +84,27 @@ export const VENDOR_PATTERNS = {
/** Mistral family */
mistral: /mistral|pixtral|codestral|ministral|voxtral|devstral|mixtral|magistral/i
} as const satisfies Record<string, RegExp>
export type VendorKey = keyof typeof VENDOR_PATTERNS
/**
* Return the vendor slug for a normalized model ID, or `undefined` if
* no vendor pattern matches. Iteration order is stable (key insertion
* order) but not semantically important — patterns don't overlap.
*/
export function matchVendor(normalizedId: string): VendorKey | undefined {
for (const [vendor, pattern] of Object.entries(VENDOR_PATTERNS) as [VendorKey, RegExp][]) {
if (pattern.test(normalizedId)) return vendor
}
return undefined
}
/**
* Lightweight vendor predicate factory. Exported primarily so consumers
* can spell the check as `isVendor('anthropic')(id)` when composing
* higher-level logic.
*/
export function isVendor(vendor: VendorKey): (normalizedId: string) => boolean {
const pattern = VENDOR_PATTERNS[vendor]
return (id: string) => pattern.test(id)
}

View File

@@ -184,6 +184,11 @@ export class RegistryLoader {
return this.modelById!.get(modelId) ?? this.modelByNormId!.get(normalizeModelId(modelId)) ?? null
}
findProvider(providerId: string): ProviderConfig | null {
const providers = this.loadProviders()
return providers.find((p) => p.id === providerId) ?? null
}
findOverride(providerId: string, modelId: string): ProviderModelOverride | null {
this.loadProviderModels()
const key = `${providerId}::${modelId}`

View File

@@ -3,6 +3,7 @@
* Safe to import from browser/renderer contexts.
*/
import { ENDPOINT_TYPE, type EndpointType } from './schemas/enums'
import type { ModelConfig } from './schemas/model'
import type { ProviderConfig, RegistryEndpointConfig } from './schemas/provider'
import type { ProviderModelOverride } from './schemas/provider-models'
@@ -51,6 +52,7 @@ export interface RuntimeEndpointConfig {
baseUrl?: string
modelsApiUrls?: { default?: string; embedding?: string; reranker?: string }
reasoningFormatType?: string
adapterFamily?: string
}
/**
@@ -70,9 +72,45 @@ export function buildRuntimeEndpointConfigs(
if (regConfig.baseUrl) config.baseUrl = regConfig.baseUrl
if (regConfig.modelsApiUrls) config.modelsApiUrls = regConfig.modelsApiUrls
if (regConfig.reasoningFormat?.type) config.reasoningFormatType = regConfig.reasoningFormat.type
if (regConfig.adapterFamily) config.adapterFamily = regConfig.adapterFamily
if (Object.keys(config).length > 0) configs[k] = config
}
return Object.keys(configs).length > 0 ? configs : null
}
/**
* Default AI SDK adapter family per endpoint type. Used when the catalog
* doesn't specify one and no more-specific signal (e.g. legacy provider type)
* is available. The mapping is purely protocol-derived — any endpoint that
* speaks anthropic-messages format needs the `anthropic` adapter, etc.
*/
const ENDPOINT_TYPE_TO_DEFAULT_ADAPTER_FAMILY: Partial<Record<EndpointType, string>> = {
[ENDPOINT_TYPE.ANTHROPIC_MESSAGES]: 'anthropic',
[ENDPOINT_TYPE.GOOGLE_GENERATE_CONTENT]: 'google',
[ENDPOINT_TYPE.OLLAMA_CHAT]: 'ollama',
[ENDPOINT_TYPE.OLLAMA_GENERATE]: 'ollama',
[ENDPOINT_TYPE.JINA_RERANK]: 'jina-rerank',
[ENDPOINT_TYPE.OPENAI_RESPONSES]: 'openai'
}
/**
* Compute the AI SDK adapter family for an endpoint. Single source of truth
* for seeder / migrator / UI creation paths — `adapterFamily` is a derived,
* write-time value; the runtime resolver only reads it.
*
* 1. Catalog `adapterFamily` wins when present (encodes vendor-specific
* relay routing like `aihubmix` for anthropic-messages on AiHubMix).
* 2. Otherwise, fall back to the endpoint-type default
* (`anthropic-messages` → `anthropic`, etc.).
* 3. Final fallback `openai-compatible` covers `openai-chat-completions`
* and any future openai-protocol endpoint without a more specific match.
*/
export function inferAdapterFamily(
endpointType: EndpointType,
catalogConfig?: Pick<RegistryEndpointConfig, 'adapterFamily'> | Pick<RuntimeEndpointConfig, 'adapterFamily'> | null
): string {
if (catalogConfig?.adapterFamily) return catalogConfig.adapterFamily
return ENDPOINT_TYPE_TO_DEFAULT_ADAPTER_FAMILY[endpointType] ?? 'openai-compatible'
}

View File

@@ -31,10 +31,8 @@ export const ApiFeaturesSchema = z.object({
developerRole: z.boolean().default(false),
/** Whether the provider supports service tier selection (OpenAI/Groq-specific) */
serviceTier: z.boolean().default(false),
/** Whether the provider supports verbosity settings (Gemini-specific) */
verbosity: z.boolean().default(false),
/** Whether the provider supports enable_thinking parameter */
enableThinking: z.boolean().default(true)
/** Whether the provider supports verbosity settings (OpenAI-specific) */
verbosity: z.boolean().default(false)
})
// ═══════════════════════════════════════════════════════════════════════════════
@@ -197,7 +195,13 @@ export const RegistryEndpointConfigSchema = z.object({
})
.optional(),
/** How this endpoint type expects reasoning parameters to be formatted */
reasoningFormat: ProviderReasoningFormatSchema.optional()
reasoningFormat: ProviderReasoningFormatSchema.optional(),
/**
* AI SDK adapter family that handles this endpoint. Aligns with the IDs
* registered in `appProviderIds`. Resolvers should prefer this over
* heuristic id/baseUrl inference when present.
*/
adapterFamily: z.string().optional()
})
export const ProviderConfigSchema = z

View File

@@ -300,7 +300,7 @@ export function EntitySelector<T extends EntityItemBase>(props: EntitySelectorPr
userOnOpenAutoFocus?.(event)
}}
className={cn(
'flex max-h-[var(--radix-popover-content-available-height)] flex-col overflow-hidden rounded-2xs border-border/60 bg-popover p-0 shadow-lg',
'flex max-h-[var(--radix-popover-content-available-height)] flex-col overflow-hidden rounded-lg border-border/60 bg-popover p-0 shadow-lg',
userPopoverClassName,
className
)}>

View File

@@ -2,6 +2,13 @@ import { MODEL_ICON_CATALOG, type ModelIconKey } from './models/catalog'
import { PROVIDER_ICON_CATALOG, type ProviderIconKey } from './providers/catalog'
import type { CompoundIcon } from './types'
// NOTE: the vendor-level regex below duplicate `@cherrystudio/provider-registry`'s
// `VENDOR_PATTERNS` (anthropic, gemini, gemma, grok, doubao, hunyuan, kimi, zhipu,
// mimo, ling, qwen). Kept in sync manually until UI's build surface lets us
// import from `@cherrystudio/provider-registry` directly. When adding / tweaking
// a vendor pattern, update BOTH places — or, better, fix the UI → registry import
// story and swap these inline regex for `VENDOR_PATTERNS.<vendor>`.
/**
* Model ID regex patterns mapped to MODEL_ICON_CATALOG keys.
* Order matters: more specific patterns must come before general ones.

1238
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,14 @@ const workspaceConfigPath = path.join(__dirname, '..', 'pnpm-workspace.yaml')
// if you want to add new prebuild binaries packages with different architectures, you can add them here
// please add to allX64 and allArm64 from pnpm-lock.yaml
const packages = [
'@anthropic-ai/claude-agent-sdk-darwin-arm64',
'@anthropic-ai/claude-agent-sdk-darwin-x64',
'@anthropic-ai/claude-agent-sdk-linux-arm64',
'@anthropic-ai/claude-agent-sdk-linux-arm64-musl',
'@anthropic-ai/claude-agent-sdk-linux-x64',
'@anthropic-ai/claude-agent-sdk-linux-x64-musl',
'@anthropic-ai/claude-agent-sdk-win32-arm64',
'@anthropic-ai/claude-agent-sdk-win32-x64',
'@img/sharp-darwin-arm64',
'@img/sharp-darwin-x64',
'@img/sharp-libvips-darwin-arm64',
@@ -132,7 +140,7 @@ exports.default = async function (context) {
}
return f !== `${arch}-${platform}`
})
.map((f) => '!node_modules/@anthropic-ai/claude-agent-sdk/vendor/ripgrep/' + f + '/**')
.map((f) => '!node_modules/@cherrystudio/ripgrep/vendor/ripgrep/' + f + '/**')
// Exclude rtk binaries for other platform-arch combinations
const currentPlatformKey = `${platform}-${arch}`

View File

@@ -1,52 +0,0 @@
import { builtinModules } from 'node:module'
import { resolve } from 'path'
import { build as viteBuild, type Plugin } from 'vite'
interface BuildProxyBootstrapPluginOptions {
dependencies: string[]
isProd: boolean
rootDir: string
}
export const buildProxyBootstrapPlugin = ({
dependencies,
isProd,
rootDir
}: BuildProxyBootstrapPluginOptions): Plugin => {
return {
name: 'cherry-build-proxy-bootstrap',
apply: 'build',
async closeBundle() {
await viteBuild({
configFile: false,
publicDir: false,
resolve: {
mainFields: ['module', 'jsnext:main', 'jsnext'],
conditions: ['node']
},
build: {
outDir: resolve(rootDir, 'out/proxy'),
target: 'node22',
minify: false,
reportCompressedSize: false,
copyPublicDir: false,
lib: {
entry: resolve(rootDir, 'src/main/services/proxy/bootstrap.ts'),
formats: ['cjs'],
fileName: () => 'index.js'
},
rollupOptions: {
external: [
'electron',
/^electron\/.+/,
...builtinModules.flatMap((moduleName) => [moduleName, `node:${moduleName}`]),
...dependencies
]
}
},
esbuild: isProd ? { legalComments: 'none' } : {}
})
}
}
}

View File

@@ -35,10 +35,8 @@ const { mockSyncBuiltinSkill } = vi.hoisted(() => ({
mockSyncBuiltinSkill: vi.fn()
}))
vi.mock('../services/agents/skills/SkillService', () => ({
skillService: {
syncBuiltinSkill: mockSyncBuiltinSkill
}
vi.mock('@main/ai/skills/SkillService', () => ({
skillService: { syncBuiltinSkill: mockSyncBuiltinSkill }
}))
// Matches the stub in tests/main.setup.ts → mockApplicationFactory().getPath

591
src/main/ai/AiService.ts Normal file
View File

@@ -0,0 +1,591 @@
import { embedMany as aiCoreEmbedMany, generateImage as aiCoreGenerateImage } from '@cherrystudio/ai-core'
import { assistantDataService } from '@data/services/AssistantService'
import type { PersonGeneration } from '@google/genai'
import { loggerService } from '@logger'
import { application } from '@main/core/application'
import { BaseService, DependsOn, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
import { messageService } from '@main/data/services/MessageService'
import { modelService } from '@main/data/services/ModelService'
import { providerService } from '@main/data/services/ProviderService'
import { type TranslateOpenRequest, translateService } from '@main/services/translate/translateService'
import { downloadImageAsBase64 } from '@main/utils/downloadAsBase64'
import { applyApprovalDecisions } from '@shared/ai/transport'
import { type Assistant } from '@shared/data/types/assistant'
import type { FileEntry } from '@shared/data/types/file/fileEntry'
import { type Model, parseUniqueModelId } from '@shared/data/types/model'
import type { Base64String } from '@shared/file/types/common'
import { IpcChannel } from '@shared/IpcChannel'
import { isEmbeddingModel } from '@shared/utils/model'
import {
type EmbeddingModelUsage,
isToolUIPart,
type LanguageModelUsage,
type ModelMessage,
type UIMessageChunk
} from 'ai'
import * as z from 'zod'
import { isAgentSessionTopic } from './agentSession/topic'
import { resolveUIMessageFileUrls } from './messages/messageConverter'
import { listModels as listModelsFromProvider } from './provider/listModels'
import { Agent } from './runtime/aiSdk/Agent'
import type { AgentLoopHooks } from './runtime/aiSdk/loop'
import { mergeUsage, ZERO_USAGE } from './runtime/aiSdk/observers/usage'
import { buildAgentParams } from './runtime/aiSdk/params/buildAgentParams'
import type { RequestFeature } from './runtime/aiSdk/params/feature'
import { WebContentsListener } from './streamManager/listeners/WebContentsListener'
import { registerBuiltinTools } from './tools/adapters/aiSdk/builtin'
import type { AppProviderSettingsMap } from './types'
import type { AiBaseRequest, AiStreamRequest, AiTransportOptions, ListModelsRequest } from './types/requests'
import { buildImageProviderOptions, normalizeAspectRatio } from './utils/imageOptions'
const logger = loggerService.withContext('AiService')
// ── Request types ──────────────────────────────────────────────────
/** In-process variant of `AiTransportOptions` — adds `signal`, which is not IPC-serialisable. */
export interface AiRequestOptions extends AiTransportOptions {
/** In-process only. Renderer payloads use `AiTransportOptions` (no signal). */
signal?: AbortSignal
}
/** Widens `requestOptions` to accept the in-process shape on `AiService.*` method signatures. */
export type AsInProcess<T extends AiBaseRequest> = Omit<T, 'requestOptions'> & {
requestOptions?: AiRequestOptions
}
/** Non-streaming text generation request — pure transport data. */
export interface AiGenerateRequest extends AiBaseRequest {
system?: string
prompt?: string
messages?: ModelMessage[]
}
// ── SDK extensions ─────────────────────────────────────────────────
/** Result of non-streaming text generation. */
export interface AiGenerateResult {
text: string
usage?: LanguageModelUsage
}
/** Image generation request. */
export interface AiImageRequest extends AiBaseRequest {
prompt: string
/** Input images for editing (base64 data URLs or URLs). If provided, uses edit mode. */
inputImages?: string[]
/** Mask for inpainting (only with inputImages). */
mask?: string
n?: number
size?: string
negativePrompt?: string
seed?: number
quality?: string
numInferenceSteps?: number
guidanceScale?: number
promptEnhancement?: boolean
personGeneration?: PersonGeneration
aspectRatio?: string
background?: string
moderation?: string
style?: string
/** Vendor-specific image params keyed by provider id; mapped to AI SDK provider options in main. */
providerOptions?: Record<string, Record<string, unknown>>
}
/** Image generation result — persisted file entries (main writes the bytes). */
export interface AiImageResult {
files: FileEntry[]
}
/** Embedding request. */
export interface AiEmbedRequest extends AiBaseRequest {
values: string[]
}
/** Embedding result. */
export interface AiEmbedResult {
embeddings: number[][]
usage?: EmbeddingModelUsage
}
/** Validates the `Ai_ToolApproval_Respond` IPC payload at the renderer boundary. */
const ToolApprovalRespondSchema = z.object({
approvalId: z.string().min(1),
approved: z.boolean(),
reason: z.string().optional(),
updatedInput: z.record(z.string(), z.unknown()).optional(),
topicId: z.string().optional(),
anchorId: z.string().optional()
})
// ── Service ────────────────────────────────────────────────────────
/**
* Lifecycle AI service. See `docs/references/ai/core-architecture.md`.
*
* DO NOT mirror `@DependsOn(['AiService'])` on AiStreamManager —
* `runExecutionLoop` looks AiService up at runtime, and every `send()`
* caller routes through AiService first.
*/
@Injectable('AiService')
@ServicePhase(Phase.WhenReady)
@DependsOn(['McpRuntimeService', 'McpCatalogService', 'AiStreamManager'])
export class AiService extends BaseService {
// Per-request AbortControllers for `Ai_GenerateImage`, paired with the
// `Ai_AbortImage` channel. Key is the renderer-generated requestId
// (see `src/preload/index.ts`). Entries are self-cleaning via the
// handler's `finally` block; abort on an unknown id is a no-op.
// TODO(abort-registry): collapse with MCP/stream/LAN registries once
// the shared `ipcHandleWithAbort` helper lands.
private readonly imageRequests = new Map<string, AbortController>()
protected async onInit(): Promise<void> {
registerBuiltinTools()
this.registerIpcHandlers()
logger.info('AiService initialized')
}
private registerIpcHandlers(): void {
this.ipcHandle(IpcChannel.Ai_GenerateText, async (_, request: AiGenerateRequest) => {
return this.generateText(request)
})
this.ipcHandle(IpcChannel.Ai_CheckModel, async (_, request: AiBaseRequest & { timeout?: number }) => {
return this.checkModel(request)
})
this.ipcHandle(IpcChannel.Ai_EmbedMany, async (_, request: AiEmbedRequest) => {
return this.embedMany(request)
})
this.ipcHandle(IpcChannel.Ai_GenerateImage, async (_, request: { requestId: string; payload: AiImageRequest }) => {
const { requestId, payload } = request
const controller = new AbortController()
this.imageRequests.set(requestId, controller)
try {
return await this.generateImage({
...payload,
requestOptions: { ...payload.requestOptions, signal: controller.signal }
})
} finally {
this.imageRequests.delete(requestId)
}
})
this.ipcOn(IpcChannel.Ai_AbortImage, (_, request: { requestId: string }) => {
this.imageRequests.get(request.requestId)?.abort()
})
this.ipcHandle(IpcChannel.Ai_ListModels, async (_, request: ListModelsRequest) => {
return this.listModels(request)
})
this.ipcHandle(IpcChannel.Ai_Translate_Open, async (event, request: TranslateOpenRequest) => {
return translateService.open(event.sender, request)
})
this.ipcHandle(IpcChannel.Ai_ToolApproval_Respond, async (event, rawPayload: unknown): Promise<{ ok: boolean }> => {
// Validate the renderer payload at the IPC boundary before any registry dispatch or DB read.
const parsed = ToolApprovalRespondSchema.safeParse(rawPayload)
if (!parsed.success) {
logger.warn('Tool-approval response rejected: invalid payload', { issues: parsed.error.issues })
return { ok: false }
}
const payload = parsed.data
// Claude-Agent fast-path: live registry entry unblocks `canUseTool`.
const dispatched = application.get('AgentSessionRuntimeService').respondToolApproval(payload.approvalId, {
approved: payload.approved,
reason: payload.reason,
updatedInput: payload.updatedInput
})
if (dispatched) return { ok: true }
// MCP path: write decisions to DB, then dispatch continue-conversation when nothing is pending.
if (!payload.topicId || !payload.anchorId) {
logger.warn('Tool-approval response had no live registry entry and no anchor context', {
approvalId: payload.approvalId
})
return { ok: false }
}
// Main is the single authority for the approval mutation: the
// renderer no longer PATCHes (it sourced parts from a DB projection
// that didn't carry the overlay-only `approval-requested` part and
// raced/overwrote the persisted row). The decision is carried
// explicitly in the IPC payload; apply it here to the DB-authoritative
// parts (the original stream's terminal persistence wrote the
// `approval-requested` part onto this row) and persist.
const decision = {
approvalId: payload.approvalId,
approved: payload.approved,
...(payload.reason !== undefined && { reason: payload.reason })
}
// A stale click on a deleted message must resolve through the documented
// result shape, not throw out of the handler (getById rejects when the
// anchor is missing), consistent with the no-context branch above.
let anchor: Awaited<ReturnType<typeof messageService.getById>>
try {
anchor = await messageService.getById(payload.anchorId)
} catch {
logger.warn('Tool-approval response anchor is missing or deleted', {
approvalId: payload.approvalId,
anchorId: payload.anchorId
})
return { ok: false }
}
const beforeParts = anchor.data.parts ?? []
const targetPresent = beforeParts.some(
(p) => isToolUIPart(p) && p.state === 'approval-requested' && p.approval?.id === decision.approvalId
)
const afterParts = applyApprovalDecisions(beforeParts, [decision])
// Only write parts when this approval is present on the DB row.
// `applyApprovalDecisions` always returns a fresh array, so writing
// unconditionally would overwrite real (or not-yet-persisted) parts
// with an unchanged set. When the part is overlay-only (persist not
// landed yet), the continue dispatch below carries the decision and
// the continue provider applies it authoritatively where it reads parts.
if (targetPresent) {
await messageService.update(payload.anchorId, { data: { parts: afterParts } })
}
// Only resume once every approval on this turn is decided — a turn
// can request several tools at once; the not-yet-decided ones keep
// their cards.
const anyStillPending = afterParts.some((p) => isToolUIPart(p) && p.state === 'approval-requested')
if (anyStillPending) {
return { ok: true }
}
const aiStreamManager = application.get('AiStreamManager')
const subscriber = new WebContentsListener(event.sender, payload.topicId)
await aiStreamManager.dispatch(subscriber, {
trigger: 'continue-conversation',
topicId: payload.topicId,
parentAnchorId: payload.anchorId,
// Idempotent against the conditional write above; safety net when the part wasn't on the row.
approvalDecisions: [decision]
})
return { ok: true }
})
}
// ── Streaming chat (agent.stream) ──
/**
* Raw `UIMessageChunk` stream from `Agent.stream`. Caller (usually
* `AiStreamManager`) owns read/multicast/accumulation/terminal dispatch.
* Pre-stream errors reject the Promise; mid-stream errors come through
* the stream itself.
*/
async streamText(
request: AsInProcess<AiStreamRequest>,
extraFeatures: readonly RequestFeature[] = []
): Promise<ReadableStream<UIMessageChunk>> {
logger.info('streamText started', { chatId: request.chatId })
const signal = request.requestOptions?.signal
if (!signal) {
throw new Error('streamText requires requestOptions.signal — no AbortController was attached by the caller')
}
if (request.runtime?.kind === 'agent-session') {
return application.get('AgentSessionRuntimeService').openTurnStream({
sessionId: request.runtime.sessionId,
turnId: request.runtime.turnId,
signal
})
}
if (isAgentSessionTopic(request.chatId)) {
throw new Error(`Agent session stream ${request.chatId} requires an agent-session runtime request`)
}
const { sdkConfig, tools, plugins, system, options, model, hookParts } = await this.buildAgentParamsFor(
request,
signal,
extraFeatures
)
const preparedMessages = await resolveUIMessageFileUrls(request.messages ?? [])
const agent = new Agent({
providerId: sdkConfig.providerId,
providerSettings: sdkConfig.providerSettings,
modelId: sdkConfig.modelId,
messageId: request.messageId,
plugins,
tools,
system,
options,
hookParts: [this.analyticsHookPart(model), ...hookParts]
})
return agent.stream(preparedMessages, signal)
}
private analyticsHookPart(model: Model): Partial<AgentLoopHooks> {
let total: LanguageModelUsage = ZERO_USAGE
return {
onStepFinish: (step) => {
if (step.usage) total = mergeUsage(total, step.usage)
},
onFinish: () => this.trackUsage(model, total)
}
}
// ── Non-streaming text generation (agent.generate) ──
async generateText(
request: AsInProcess<AiGenerateRequest>,
extraFeatures: readonly RequestFeature[] = []
): Promise<AiGenerateResult> {
logger.info('generateText started', { assistantId: request.assistantId })
const signal = request.requestOptions?.signal
const { sdkConfig, tools, plugins, system, options, model, hookParts } = await this.buildAgentParamsFor(
request,
signal,
extraFeatures
)
const agent = new Agent({
providerId: sdkConfig.providerId,
providerSettings: sdkConfig.providerSettings,
modelId: sdkConfig.modelId,
plugins,
tools,
system: request.system ?? system,
options,
hookParts: [this.analyticsHookPart(model), ...hookParts]
})
// prompt and messages are mutually exclusive in AI SDK; preserve that.
return agent.generate(request.prompt ? { prompt: request.prompt } : { messages: request.messages ?? [] }, signal)
}
// ── Image generation ──
async generateImage(request: AsInProcess<AiImageRequest>): Promise<AiImageResult> {
logger.info('generateImage started', { assistantId: request.assistantId, uniqueModelId: request.uniqueModelId })
const signal = request.requestOptions?.signal
const { sdkConfig } = await this.buildAgentParamsFor(request, signal)
const promptParam = request.inputImages
? { text: request.prompt, images: request.inputImages, ...(request.mask && { mask: request.mask }) }
: request.prompt
// Map the canonical painting params onto each vendor's real image-API field
// names (negative_prompt / seed / imageConfig / …). AI SDK image models
// spread `providerOptions[<providerId>]` into the request body, so this is
// how negativePrompt/seed/steps/guidance/aspectRatio actually reach vendors.
const imageProviderOptions = buildImageProviderOptions(sdkConfig.providerId, {
negativePrompt: request.negativePrompt,
seed: request.seed !== undefined ? String(request.seed) : undefined,
numInferenceSteps: request.numInferenceSteps,
guidanceScale: request.guidanceScale,
promptEnhancement: request.promptEnhancement,
personGeneration: request.personGeneration,
quality: request.quality,
aspectRatio: request.aspectRatio,
imageSize: request.size,
providerOptions: request.providerOptions,
background: request.background,
moderation: request.moderation,
style: request.style
})
const aspectRatio = normalizeAspectRatio(request.aspectRatio)
const imageParams = {
model: sdkConfig.modelId,
prompt: promptParam,
n: request.n ?? 1,
// Client-side default: when the caller omits `size`, fall back to 1024x1024
// rather than letting the server pick its own default. Dropping this fallback
// (to truly let the server choose) is a behavior decision, not done here.
size: (request.size ?? '1024x1024') as `${number}x${number}`,
...(request.negativePrompt ? { negativePrompt: request.negativePrompt } : {}),
...(request.seed !== undefined ? { seed: request.seed } : {}),
...(request.quality ? { quality: request.quality } : {}),
...(request.numInferenceSteps !== undefined ? { numInferenceSteps: request.numInferenceSteps } : {}),
...(request.guidanceScale !== undefined ? { guidanceScale: request.guidanceScale } : {}),
...(request.promptEnhancement !== undefined ? { promptEnhancement: request.promptEnhancement } : {}),
...(aspectRatio ? { aspectRatio: aspectRatio as `${number}:${number}` } : {}),
...(Object.keys(imageProviderOptions).length > 0 ? { providerOptions: imageProviderOptions } : {}),
...(signal ? { abortSignal: signal } : {}),
experimental_download: async (downloads) => {
return Promise.all(
downloads.map(async ({ url }) => {
if (signal?.aborted) return null
const downloaded = await downloadImageAsBase64(url.toString())
if (signal?.aborted) return null
if (!downloaded) return null
return {
data: Buffer.from(downloaded.data, 'base64'),
mediaType: downloaded.media_type
}
})
)
}
}
const result = await aiCoreGenerateImage<AppProviderSettingsMap>(
sdkConfig.providerId,
sdkConfig.providerSettings,
imageParams
)
const dataUrls: Base64String[] = []
let filteredCount = 0
for (const image of result.images ?? []) {
if (image.base64) {
dataUrls.push(`data:${image.mediaType || 'image/png'};base64,${image.base64}`)
continue
}
filteredCount += 1
}
if (filteredCount > 0) {
logger.warn('Filtered invalid generated images', {
uniqueModelId: request.uniqueModelId,
providerId: sdkConfig.providerId,
modelId: sdkConfig.modelId,
filteredCount
})
}
const fileManager = application.get('FileManager')
const files = await Promise.all(dataUrls.map((data) => fileManager.createInternalEntry({ source: 'base64', data })))
return { files }
}
// ── Embedding ──
async embedMany(request: AsInProcess<AiEmbedRequest>): Promise<AiEmbedResult> {
logger.info('embedMany started', { assistantId: request.assistantId, count: request.values.length })
const signal = request.requestOptions?.signal
const { sdkConfig, model } = await this.buildAgentParamsFor(request, signal)
const result = await aiCoreEmbedMany<AppProviderSettingsMap>(sdkConfig.providerId, sdkConfig.providerSettings, {
model: sdkConfig.modelId,
values: request.values,
...(signal ? { abortSignal: signal } : {})
})
this.trackUsage(model, { inputTokens: result.usage?.tokens ?? 0, outputTokens: 0 })
return { embeddings: result.embeddings, usage: result.usage }
}
// ── Model listing ──
async listModels(request: ListModelsRequest): Promise<Partial<Model>[]> {
let providerId = request.providerId
if (!providerId && request.assistantId) {
const assistant = await assistantDataService.getById(request.assistantId).catch(() => undefined)
if (assistant?.modelId) {
providerId = parseUniqueModelId(assistant.modelId).providerId
}
}
if (!providerId) {
throw new Error('Cannot resolve providerId: not in request and assistant has no model')
}
const provider = await providerService.getByProviderId(providerId)
return listModelsFromProvider(provider, undefined, { throwOnError: request.throwOnError })
}
// ── API validation ──
/** Dispatches to `embedMany` for embedding models, `generateText` otherwise. */
async checkModel(request: AiBaseRequest & { timeout?: number }): Promise<{ latency: number }> {
const { model } = await this.getProviderAndModel(request)
const start = performance.now()
const timeout = request.timeout ?? 15000
// AbortController on timeout so the HTTP work cancels too (otherwise tokens keep burning).
const controller = new AbortController()
let timeoutHandle: ReturnType<typeof setTimeout> | undefined
const timeoutPromise = new Promise<never>((_, reject) => {
timeoutHandle = setTimeout(() => {
controller.abort(new Error('Check model timeout'))
reject(new Error('Check model timeout'))
}, timeout)
})
const probeRequest = {
...request,
requestOptions: { ...request.requestOptions, signal: controller.signal }
}
const probe = isEmbeddingModel(model)
? this.embedMany({ ...probeRequest, values: ['test'] })
: this.generateText({ ...probeRequest, system: 'test', prompt: 'hi' })
try {
await Promise.race([probe, timeoutPromise])
return { latency: performance.now() - start }
} finally {
if (timeoutHandle) clearTimeout(timeoutHandle)
}
}
// ── Shared agent parameter resolution ──
private async buildAgentParamsFor(
request: AsInProcess<AiBaseRequest> & { chatId?: string },
signal: AbortSignal | undefined,
extraFeatures: readonly RequestFeature[] = []
) {
const { provider, model, assistant } = await this.getProviderAndModel(request)
const built = await buildAgentParams({ request, signal, provider, model, assistant, extraFeatures })
return { ...built, provider, model, assistant }
}
// ── Token usage tracking ──
private trackUsage(model: Model, usage?: { inputTokens?: number; outputTokens?: number }): void {
if (!usage || !model.providerId || !model.apiModelId) return
const inputTokens = usage.inputTokens ?? 0
const outputTokens = usage.outputTokens ?? 0
if (inputTokens === 0 && outputTokens === 0) return
try {
const analyticsService = application.get('AnalyticsService')
analyticsService.trackTokenUsage({
provider: model.providerId,
model: model.apiModelId ?? model.id,
input_tokens: inputTokens,
output_tokens: outputTokens
})
} catch {
// AnalyticsService may not be activated (data collection disabled)
}
}
/** Priority: explicit `uniqueModelId` > `assistant.modelId`. */
private async getProviderAndModel(request: AiBaseRequest & { chatId?: string }) {
let assistant: Assistant | undefined
if (request.assistantId) {
assistant = await assistantDataService.getById(request.assistantId).catch(() => undefined)
}
let providerId: string | undefined
let modelId: string | undefined
if (request.uniqueModelId) {
const parsed = parseUniqueModelId(request.uniqueModelId)
providerId = parsed.providerId
modelId = parsed.modelId
} else if (assistant?.modelId) {
const parsed = parseUniqueModelId(assistant.modelId)
providerId = parsed.providerId
modelId = parsed.modelId
}
if (!providerId) throw new Error('Cannot resolve providerId: not in request and assistant has no model')
if (!modelId) throw new Error('Cannot resolve modelId: not in request and assistant has no model')
const provider = await providerService.getByProviderId(providerId)
const model = await modelService.getByKey(providerId, modelId)
return { provider, model, assistant }
}
}

View File

@@ -0,0 +1,418 @@
import { BaseService } from '@main/core/lifecycle/BaseService'
import { IpcChannel } from '@shared/IpcChannel'
import { ipcMain } from 'electron'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const mockGenerateImage = vi.fn()
const mockDownloadImageAsBase64 = vi.fn()
const mockApplicationGet = vi.fn()
vi.mock('@main/core/application', () => ({
application: {
get: mockApplicationGet
}
}))
vi.mock('@main/utils/downloadAsBase64', () => ({
downloadImageAsBase64: (...args: unknown[]) => mockDownloadImageAsBase64(...args)
}))
vi.mock('@cherrystudio/ai-core', () => ({
createAgent: vi.fn(),
embedMany: vi.fn(),
generateImage: (...args: unknown[]) => mockGenerateImage(...args)
}))
const { AiService } = await import('../AiService')
const { messageService } = await import('@main/data/services/MessageService')
/**
* Instantiate `AiService` directly (without going through the lifecycle
* container) so unit tests can drive its methods in isolation.
*/
function createService(): InstanceType<typeof AiService> {
BaseService.resetInstances()
return new (AiService as any)()
}
describe('AiService', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('routes agent-session runtime requests directly to the runtime service', async () => {
const service = createService()
const stream = new ReadableStream()
const openTurnStream = vi.fn(() => stream)
mockApplicationGet.mockReturnValue({ openTurnStream })
await expect(
service.streamText({
chatId: 'agent-session:session-1',
trigger: 'submit-message',
runtime: { kind: 'agent-session', sessionId: 'session-1', turnId: 'turn-1' },
requestOptions: { signal: new AbortController().signal }
} as any)
).resolves.toBe(stream)
expect(mockApplicationGet).toHaveBeenCalledWith('AgentSessionRuntimeService')
expect(openTurnStream).toHaveBeenCalledWith({
sessionId: 'session-1',
turnId: 'turn-1',
signal: expect.any(AbortSignal)
})
})
it('rejects agent-session streams that do not carry a runtime request', async () => {
const service = createService()
const buildAgentParamsFor = vi.spyOn(service as any, 'buildAgentParamsFor')
await expect(
service.streamText({
chatId: 'agent-session:session-1',
trigger: 'submit-message',
requestOptions: { signal: new AbortController().signal }
} as any)
).rejects.toThrow('requires an agent-session runtime request')
expect(buildAgentParamsFor).not.toHaveBeenCalled()
expect(mockApplicationGet).not.toHaveBeenCalled()
})
it('normalizes base64 and url images from ai-core generateImage', async () => {
const service = createService()
vi.spyOn(service as never, 'buildAgentParamsFor').mockResolvedValue({
sdkConfig: {
providerId: 'test-provider',
providerSettings: {},
modelId: 'test-model'
}
} as never)
mockGenerateImage.mockResolvedValue({
images: [{ base64: 'abc123', mediaType: 'image/png' }, { nonsense: true }],
providerMetadata: {
testProvider: {
images: [{ url: 'https://example.com/image.png' }]
}
}
})
mockDownloadImageAsBase64.mockResolvedValue({
data: 'url-base64',
media_type: 'image/jpeg'
})
const fileEntry = { id: 'file-1', origin: 'internal', ext: 'png', name: 'img', size: 3, createdAt: 0 }
const createInternalEntry = vi.fn().mockResolvedValue(fileEntry)
mockApplicationGet.mockImplementation((name: string) =>
name === 'FileManager' ? { createInternalEntry } : undefined
)
const result = await service.generateImage({
uniqueModelId: 'test-provider::test-model',
prompt: 'draw a cat',
n: 2,
size: '1024x1024',
negativePrompt: 'blurry',
seed: 7,
quality: 'high',
numInferenceSteps: 30,
guidanceScale: 4.5,
promptEnhancement: true,
requestOptions: { signal: new AbortController().signal }
})
expect(mockGenerateImage).toHaveBeenCalledWith(
'test-provider',
{},
expect.objectContaining({
model: 'test-model',
prompt: 'draw a cat',
n: 2,
size: '1024x1024',
negativePrompt: 'blurry',
seed: 7,
quality: 'high',
numInferenceSteps: 30,
guidanceScale: 4.5,
promptEnhancement: true
})
)
const callOptions = mockGenerateImage.mock.calls[0]?.[2]
expect(callOptions.experimental_download).toBeTypeOf('function')
const downloaded = await callOptions.experimental_download([
{
url: new URL('https://example.com/image.png'),
isUrlSupportedByModel: false
}
])
expect(mockDownloadImageAsBase64).toHaveBeenCalledWith('https://example.com/image.png')
expect(downloaded).toEqual([
{
data: Buffer.from('url-base64', 'base64'),
mediaType: 'image/jpeg'
}
])
expect(createInternalEntry).toHaveBeenCalledWith({ source: 'base64', data: 'data:image/png;base64,abc123' })
expect(result).toEqual({ files: [fileEntry] })
})
})
describe('AiService tool approval', () => {
/** A fake renderer event whose `sender` satisfies `WebContentsListener`'s constructor. */
function fakeEvent() {
return {
sender: {
id: 1,
once: vi.fn(),
isDestroyed: () => false,
send: vi.fn()
}
} as never
}
/** A minimal `approval-requested` tool UI part (passes `isToolUIPart`). */
function pendingToolPart(approvalId: string, toolName = 'mcp_write') {
return {
type: `tool-${toolName}`,
toolCallId: `tc-${approvalId}`,
state: 'approval-requested',
input: {},
approval: { id: approvalId }
}
}
/**
* Instantiate `AiService`, register its IPC handlers against the mocked
* `ipcMain`, and return the captured `Ai_ToolApproval_Respond` listener.
*/
function getApprovalHandler() {
const service = createService()
;(service as unknown as { registerIpcHandlers(): void }).registerIpcHandlers()
const call = vi
.mocked(ipcMain.handle)
.mock.calls.find(([channel]) => channel === IpcChannel.Ai_ToolApproval_Respond)
if (!call) throw new Error('Ai_ToolApproval_Respond handler was not registered')
return call[1] as (
event: unknown,
payload: {
approvalId: string
approved: boolean
reason?: string
updatedInput?: Record<string, unknown>
topicId?: string
anchorId?: string
}
) => Promise<{ ok: boolean }>
}
beforeEach(() => {
vi.clearAllMocks()
})
it('takes the Claude-Agent fast-path when the live registry dispatches the decision', async () => {
const respondToolApproval = vi.fn(() => true)
const dispatch = vi.fn()
mockApplicationGet.mockImplementation((name: string) => {
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
if (name === 'AiStreamManager') return { dispatch }
return undefined
})
const getById = vi.spyOn(messageService, 'getById')
const handler = getApprovalHandler()
const result = await handler(fakeEvent(), {
approvalId: 'agent-approval-1',
approved: true
})
expect(result).toEqual({ ok: true })
expect(respondToolApproval).toHaveBeenCalledWith('agent-approval-1', {
approved: true,
reason: undefined,
updatedInput: undefined
})
// Fast-path short-circuits before any DB read or continue dispatch.
expect(getById).not.toHaveBeenCalled()
expect(dispatch).not.toHaveBeenCalled()
})
it('returns { ok: false } when there is no live entry and no anchor context', async () => {
const respondToolApproval = vi.fn(() => false)
mockApplicationGet.mockImplementation((name: string) =>
name === 'AgentSessionRuntimeService' ? { respondToolApproval } : undefined
)
const getById = vi.spyOn(messageService, 'getById')
const handler = getApprovalHandler()
const result = await handler(fakeEvent(), {
approvalId: 'orphan-approval-1',
approved: true
// no topicId / anchorId
})
expect(result).toEqual({ ok: false })
expect(getById).not.toHaveBeenCalled()
})
it('persists the flipped parts and dispatches continue-conversation for an MCP approval present on the row', async () => {
const respondToolApproval = vi.fn(() => false)
const dispatch = vi.fn().mockResolvedValue(undefined)
mockApplicationGet.mockImplementation((name: string) => {
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
if (name === 'AiStreamManager') return { dispatch }
return undefined
})
const beforeParts = [{ type: 'text', text: 'hello' }, pendingToolPart('mcp-approval-1')]
vi.spyOn(messageService, 'getById').mockResolvedValue({ data: { parts: beforeParts } } as never)
const update = vi.spyOn(messageService, 'update').mockResolvedValue({} as never)
const handler = getApprovalHandler()
const result = await handler(fakeEvent(), {
approvalId: 'mcp-approval-1',
approved: true,
topicId: 'topic-1',
anchorId: 'anchor-1'
})
expect(result).toEqual({ ok: true })
// Target part was on the row → write the flipped parts.
expect(update).toHaveBeenCalledTimes(1)
const [updatedId, updateDto] = update.mock.calls[0]
expect(updatedId).toBe('anchor-1')
const writtenParts = (updateDto as { data: { parts: Array<{ state?: string }> } }).data.parts
expect(writtenParts[1].state).toBe('approval-responded')
// Nothing left pending → resume via continue-conversation.
expect(dispatch).toHaveBeenCalledTimes(1)
expect(dispatch).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
trigger: 'continue-conversation',
topicId: 'topic-1',
parentAnchorId: 'anchor-1',
approvalDecisions: [{ approvalId: 'mcp-approval-1', approved: true }]
})
)
})
it('does not write parts when the approval is overlay-only (not present on the row) but still dispatches', async () => {
const respondToolApproval = vi.fn(() => false)
const dispatch = vi.fn().mockResolvedValue(undefined)
mockApplicationGet.mockImplementation((name: string) => {
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
if (name === 'AiStreamManager') return { dispatch }
return undefined
})
// Row carries no approval-requested part matching this approvalId.
vi.spyOn(messageService, 'getById').mockResolvedValue({
data: { parts: [{ type: 'text', text: 'hello' }] }
} as never)
const update = vi.spyOn(messageService, 'update').mockResolvedValue({} as never)
const handler = getApprovalHandler()
const result = await handler(fakeEvent(), {
approvalId: 'mcp-approval-missing',
approved: false,
topicId: 'topic-1',
anchorId: 'anchor-1'
})
expect(result).toEqual({ ok: true })
// Part absent on the row → no overwrite of the persisted parts...
expect(update).not.toHaveBeenCalled()
// ...but the decision still rides the continue dispatch idempotently.
expect(dispatch).toHaveBeenCalledTimes(1)
expect(dispatch).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
trigger: 'continue-conversation',
approvalDecisions: [{ approvalId: 'mcp-approval-missing', approved: false }]
})
)
})
it('does not finalize while another approval on the turn is still pending', async () => {
const respondToolApproval = vi.fn(() => false)
const dispatch = vi.fn().mockResolvedValue(undefined)
mockApplicationGet.mockImplementation((name: string) => {
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
if (name === 'AiStreamManager') return { dispatch }
return undefined
})
// Two outstanding approvals on the same row; we only decide the first.
const beforeParts = [pendingToolPart('mcp-approval-1'), pendingToolPart('mcp-approval-2', 'mcp_read')]
vi.spyOn(messageService, 'getById').mockResolvedValue({ data: { parts: beforeParts } } as never)
const update = vi.spyOn(messageService, 'update').mockResolvedValue({} as never)
const handler = getApprovalHandler()
const result = await handler(fakeEvent(), {
approvalId: 'mcp-approval-1',
approved: true,
topicId: 'topic-1',
anchorId: 'anchor-1'
})
expect(result).toEqual({ ok: true })
// The decided part is persisted...
expect(update).toHaveBeenCalledTimes(1)
// ...but the still-pending sibling gates the resume.
expect(dispatch).not.toHaveBeenCalled()
})
it('returns { ok: false } when the anchor message is missing or deleted', async () => {
const respondToolApproval = vi.fn(() => false)
const dispatch = vi.fn().mockResolvedValue(undefined)
mockApplicationGet.mockImplementation((name: string) => {
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
if (name === 'AiStreamManager') return { dispatch }
return undefined
})
// A stale click on a deleted message: getById rejects.
const getById = vi.spyOn(messageService, 'getById').mockRejectedValue(new Error('Message not found'))
const update = vi.spyOn(messageService, 'update')
const handler = getApprovalHandler()
const result = await handler(fakeEvent(), {
approvalId: 'mcp-approval-1',
approved: true,
topicId: 'topic-1',
anchorId: 'deleted-anchor'
})
// Resolves gracefully through the documented result shape instead of throwing.
expect(result).toEqual({ ok: false })
expect(getById).toHaveBeenCalledWith('deleted-anchor')
expect(update).not.toHaveBeenCalled()
expect(dispatch).not.toHaveBeenCalled()
})
it('returns { ok: false } when the IPC payload is invalid (rejected at the boundary)', async () => {
const respondToolApproval = vi.fn(() => true)
const dispatch = vi.fn()
mockApplicationGet.mockImplementation((name: string) => {
if (name === 'AgentSessionRuntimeService') return { respondToolApproval }
if (name === 'AiStreamManager') return { dispatch }
return undefined
})
const getById = vi.spyOn(messageService, 'getById')
const handler = getApprovalHandler()
// Missing `approved` boolean and empty `approvalId` → schema rejects.
const result = await handler(fakeEvent(), { approvalId: '' } as never)
expect(result).toEqual({ ok: false })
// Rejected before any registry dispatch or DB read.
expect(respondToolApproval).not.toHaveBeenCalled()
expect(getById).not.toHaveBeenCalled()
expect(dispatch).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,12 @@
import type { Assistant, AssistantSettings } from '@shared/data/types/assistant'
import { DEFAULT_ASSISTANT_SETTINGS } from '@shared/data/types/assistant'
export function makeAssistant(
overrides: Partial<Omit<Assistant, 'settings'>> & { settings?: Partial<AssistantSettings> } = {}
): Assistant {
const { settings, ...rest } = overrides
return {
settings: { ...DEFAULT_ASSISTANT_SETTINGS, ...settings },
...rest
} as Assistant
}

View File

@@ -0,0 +1,3 @@
export { makeAssistant } from './assistant'
export { makeModel } from './model'
export { makeEndpointConfig, makeProvider } from './provider'

View File

@@ -0,0 +1,21 @@
import type { Model } from '@shared/data/types/model'
/**
* Minimal valid Model fixture for main/ai tests.
*
* Defaults satisfy ModelSchema's required fields (id, providerId, name,
* capabilities, supportsStreaming, isEnabled, isHidden). Pass overrides for
* whatever the SUT actually reads.
*/
export function makeModel(overrides: Partial<Model> = {}): Model {
return {
id: 'openai::gpt-4',
providerId: 'openai',
name: 'GPT-4',
capabilities: [],
supportsStreaming: true,
isEnabled: true,
isHidden: false,
...overrides
} as Model
}

View File

@@ -0,0 +1,26 @@
import type { EndpointConfig, Provider } from '@shared/data/types/provider'
import { DEFAULT_API_FEATURES, DEFAULT_PROVIDER_SETTINGS } from '@shared/data/types/provider'
/**
* Minimal valid Provider fixture for main/ai tests.
*
* Defaults satisfy ProviderSchema's required fields (apiKeys, authType,
* apiFeatures, settings, isEnabled). Pass overrides for whatever the SUT
* actually reads.
*/
export function makeProvider(overrides: Partial<Provider> = {}): Provider {
return {
id: 'fake',
name: 'Fake',
apiKeys: [],
authType: 'api-key',
apiFeatures: { ...DEFAULT_API_FEATURES },
settings: { ...DEFAULT_PROVIDER_SETTINGS },
isEnabled: true,
...overrides
} as Provider
}
export function makeEndpointConfig(overrides: Partial<EndpointConfig> = {}): EndpointConfig {
return { ...overrides }
}

View File

@@ -0,0 +1,847 @@
import { agentService } from '@data/services/AgentService'
import { agentSessionMessageService } from '@data/services/AgentSessionMessageService'
import { loggerService } from '@logger'
import { application } from '@main/core/application'
import { BaseService, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
import { topicNamingService } from '@main/services/TopicNamingService'
import type { Span } from '@opentelemetry/api'
import type { AgentEntity, AgentPermissionMode, UpdateAgentDto } from '@shared/data/api/schemas/agents'
import type { AgentSessionMessageEntity } from '@shared/data/types/agent'
import type { CherryUIMessage } from '@shared/data/types/message'
import { parseUniqueModelId, type UniqueModelId } from '@shared/data/types/model'
import { serializeError } from '@shared/types/error'
import type { UIMessageChunk } from 'ai'
import { v7 as uuidv7 } from 'uuid'
import { startAiTurnTrace } from '../observability'
import {
type AgentRuntimeConnection,
type AgentRuntimeEvent,
type AgentRuntimePolicyUpdate,
type AgentRuntimeTraceContext,
runtimeDriverRegistry
} from '../runtime'
import { type DispatchDecision, toolApprovalRegistry } from '../runtime/claudeCode/ToolApprovalRegistry'
import { PersistenceListener } from '../streamManager/listeners/PersistenceListener'
import { TraceFlushListener } from '../streamManager/listeners/TraceFlushListener'
import type { StreamDoneResult, StreamErrorResult, StreamListener, StreamPausedResult } from '../streamManager/types'
import { AgentSessionMessageBackend } from './persistence/AgentSessionMessageBackend'
const logger = loggerService.withContext('AgentSessionRuntimeService')
const DEFAULT_IDLE_TTL_MS = 5 * 60 * 1000
export type AgentSessionRuntimeStatus = 'active' | 'idle'
export type AgentSessionRuntimeTerminalStatus = 'success' | 'paused' | 'error'
/**
* Why an in-flight turn is being stopped — selects how much runtime state to
* tear down (see {@link STOP_POLICY}). Derived from the service's own typed
* state (`turn.interruptRequested`), never from the abort signal's `reason`, so
* a user Stop can't be misread as a steer interrupt. An abort with nothing
* requested defaults to `user-stop` — the safe-failure direction (full teardown
* closes the connection, killing the runtime/subagent).
*/
type AgentTurnStopIntent = 'interrupt' | 'user-stop'
interface TurnStopPolicy {
/** Terminal status stamped on the turn when the session survives the stop. */
turnStatus: AgentSessionRuntimeTerminalStatus
/** Tear the whole session down (connection + entry), not just the turn. */
closeSession: boolean
}
/**
* `interrupt` (steer): pause this turn but keep the connection + session so the
* queued message can open the next turn (the runtime was gracefully interrupted,
* not closed — the warm query survives). `user-stop`: tear the session down so
* `connection.close()` kills the runtime query and its subagent.
*/
const STOP_POLICY: Record<AgentTurnStopIntent, TurnStopPolicy> = {
interrupt: { turnStatus: 'paused', closeSession: false },
'user-stop': { turnStatus: 'paused', closeSession: true }
}
export interface BeginAgentSessionTurnInput {
sessionId: string
topicId: string
agentId: string
agentType: string
modelId: UniqueModelId
assistantMessageId?: string
userMessage?: AgentSessionMessageEntity
traceId?: string
rootSpanId?: string
}
export interface AgentSessionRuntimeHandle {
listeners: StreamListener[]
turnId: string
}
export interface OpenAgentSessionTurnStreamInput {
sessionId: string
turnId: string
signal: AbortSignal
}
export interface AgentSessionRuntimeSnapshot {
sessionId: string
topicId?: string
assistantMessageId?: string
status: AgentSessionRuntimeStatus
pendingMessageCount: number
lastTerminalStatus?: AgentSessionRuntimeTerminalStatus
resumeToken?: string
activeToolCount: number
interruptRequested: boolean
}
type AgentSessionTurn = {
turnId: string
assistantMessageId?: string
userMessage: AgentSessionMessageEntity
modelId: UniqueModelId
admitted: boolean
terminalStatus?: AgentSessionRuntimeTerminalStatus
controller?: ReadableStreamDefaultController<UIMessageChunk>
activeToolIds: Set<string>
interruptRequested: boolean
trace?: AgentRuntimeTraceContext
}
type AgentSessionRuntimeEntry = {
sessionId: string
topicId: string
agentId: string
agentType: string
modelId: UniqueModelId
status: AgentSessionRuntimeStatus
pendingTurns: AgentSessionMessageEntity[]
connection?: AgentRuntimeConnection
connectionLoop?: Promise<void>
/** In-flight {@link ensureConnection} promise — shared by concurrent callers so only one connect runs. */
connecting?: Promise<boolean>
currentTurn?: AgentSessionTurn
lastResumeToken?: string
lastTerminalStatus?: AgentSessionRuntimeTerminalStatus
idleTimer?: ReturnType<typeof setTimeout>
startingNextTurn?: boolean
}
class AgentSessionRuntimeTerminalListener implements StreamListener {
readonly id: string
constructor(
private readonly service: AgentSessionRuntimeService,
private readonly sessionId: string
) {
this.id = `agent-runtime:${sessionId}`
}
onChunk(): void {}
onDone(result: StreamDoneResult): void {
if (result.isTopicDone === false) return
this.service.markTurnTerminal(this.sessionId, 'success')
}
onPaused(result: StreamPausedResult): void {
if (result.isTopicDone === false) return
this.service.markTurnTerminal(this.sessionId, 'paused')
}
onError(result: StreamErrorResult): void {
if (result.isTopicDone === false) return
this.service.markTurnTerminal(this.sessionId, 'error')
}
isAlive(): boolean {
return true
}
}
@Injectable('AgentSessionRuntimeService')
@ServicePhase(Phase.WhenReady)
export class AgentSessionRuntimeService extends BaseService {
private readonly entries = new Map<string, AgentSessionRuntimeEntry>()
protected async onInit(): Promise<void> {
// Resolve agent-session assistant rows a prior main-process crash left `pending` — at boot the
// in-memory entry map is empty, so every such row is stale. Mirrors AiStreamManager's chat
// reconcile so both message tables are settled on restart (neither stays a frozen "thinking"
// bubble); agent sessions additionally recover conversation context via the resume token.
await this.reconcileStalePendingMessages()
this.registerDisposable(
agentService.onAgentUpdated(({ agentId, updates, agent }) => {
void this.handleAgentUpdated(agentId, updates, agent).catch((error) => {
logger.warn('Failed to apply live agent policy update', { agentId, error })
})
})
)
}
private async reconcileStalePendingMessages(): Promise<void> {
try {
const staleIds = await agentSessionMessageService.findPendingAssistantMessageIds()
if (staleIds.length === 0) return
logger.info('Reconciling crash-orphaned pending agent-session messages', { count: staleIds.length })
await agentSessionMessageService.markMessagesError(staleIds)
} catch (error) {
logger.error('Failed to reconcile stale pending agent-session messages', { error })
}
}
beginTurn(input: BeginAgentSessionTurnInput): AgentSessionRuntimeHandle {
const turnId = crypto.randomUUID()
const userMessage = input.userMessage ?? createSyntheticUserMessage(input.sessionId)
const existing = this.entries.get(input.sessionId)
const turn: AgentSessionTurn = {
turnId,
assistantMessageId: input.assistantMessageId,
userMessage,
modelId: input.modelId,
admitted: false,
activeToolIds: new Set(),
interruptRequested: false,
trace: this.createTraceContext(input, turnId, input.traceId, input.rootSpanId)
}
if (existing?.status === 'idle') {
this.clearIdleTimer(existing)
existing.pendingTurns = []
existing.topicId = input.topicId
existing.agentId = input.agentId
existing.agentType = input.agentType
existing.modelId = input.modelId
existing.status = 'active'
existing.currentTurn = turn
return {
listeners: [
this.createPersistenceListener(existing, userMessage),
new AgentSessionRuntimeTerminalListener(this, input.sessionId),
new TraceFlushListener(input.topicId)
],
turnId
}
}
if (existing) this.closeSession(input.sessionId)
const entry: AgentSessionRuntimeEntry = {
sessionId: input.sessionId,
topicId: input.topicId,
agentId: input.agentId,
agentType: input.agentType,
modelId: input.modelId,
status: 'active',
pendingTurns: [],
currentTurn: turn
}
this.entries.set(input.sessionId, entry)
return {
listeners: [
this.createPersistenceListener(entry, userMessage),
new AgentSessionRuntimeTerminalListener(this, input.sessionId),
new TraceFlushListener(input.topicId)
],
turnId
}
}
async applyAgentPolicyUpdate(agentId: string, update: AgentRuntimePolicyUpdate): Promise<void> {
const updates: Array<Promise<boolean> | boolean> = []
for (const entry of this.entries.values()) {
if (entry.agentId !== agentId || !entry.connection?.applyPolicyUpdate) continue
updates.push(entry.connection.applyPolicyUpdate(update))
}
await Promise.allSettled(updates)
}
private async handleAgentUpdated(agentId: string, updates: UpdateAgentDto, agent: AgentEntity): Promise<void> {
const configuration = updates.configuration as { permission_mode?: unknown } | undefined
const hasPermissionModeUpdate =
configuration !== undefined && Object.prototype.hasOwnProperty.call(configuration, 'permission_mode')
if (hasPermissionModeUpdate) {
await this.applyAgentPolicyUpdate(agentId, {
type: 'permission-mode',
permissionMode: configuration.permission_mode as AgentPermissionMode | undefined
})
}
if (
Object.prototype.hasOwnProperty.call(updates, 'allowedTools') ||
Object.prototype.hasOwnProperty.call(updates, 'mcps')
) {
await this.applyAgentPolicyUpdate(agentId, { type: 'tool-policy', agent })
}
}
openTurnStream(input: OpenAgentSessionTurnStreamInput): ReadableStream<UIMessageChunk> {
const entry = this.entries.get(input.sessionId)
const turn = entry?.currentTurn
if (!entry || !turn || turn.turnId !== input.turnId) {
throw new Error(`No active agent runtime turn ${input.turnId} for session ${input.sessionId}`)
}
return new ReadableStream<UIMessageChunk>({
start: async (controller) => {
try {
this.clearIdleTimer(entry)
turn.controller = controller
const onAbort = () => this.stopTurn(entry, turn.interruptRequested ? 'interrupt' : 'user-stop')
if (input.signal.aborted) {
onAbort()
return
} else {
input.signal.addEventListener('abort', onAbort, { once: true })
}
controller.enqueue({ type: 'start' })
const connected = await this.ensureConnection(entry)
if (!connected || !this.isCurrentEntry(entry) || turn.terminalStatus) return
await this.admitTurn(entry, turn)
} catch (error) {
controller.error(error)
}
},
cancel: () => {
this.closeCurrentTurn(entry, 'paused')
}
})
}
enqueueUserMessage(sessionId: string, message: AgentSessionMessageEntity): void {
const entry = this.entries.get(sessionId)
if (!entry) return
entry.pendingTurns.push(message)
entry.status = 'active'
this.clearIdleTimer(entry)
const turn = entry.currentTurn
if (!turn || turn.terminalStatus) {
this.scheduleNextTurn(entry)
return
}
if (turn.activeToolIds.size > 0) return
queueMicrotask(() => {
const latest = this.entries.get(sessionId)
if (!latest?.currentTurn || latest.currentTurn.terminalStatus) {
if (latest) this.scheduleNextTurn(latest)
return
}
this.requestInterruptWhenSafe(latest)
})
}
markTurnTerminal(sessionId: string, status: AgentSessionRuntimeTerminalStatus): void {
const entry = this.entries.get(sessionId)
if (!entry) return
entry.status = 'idle'
entry.lastTerminalStatus = status
if (entry.currentTurn) entry.currentTurn.terminalStatus = status
if (this.shouldCloseConnectionAfterTurn(entry)) {
// close() may be async on some drivers; swallow rejection so it can't become unhandled.
void Promise.resolve(this.closeConnection(entry)?.close()).catch((error) =>
logger.warn('Agent runtime connection close failed', { sessionId: entry.sessionId, error })
)
}
if (entry.pendingTurns.length > 0) {
this.scheduleNextTurn(entry)
} else {
this.refreshIdleTimer(entry)
}
}
closeSession(sessionId: string): void {
const entry = this.entries.get(sessionId)
if (!entry) return
this.entries.delete(sessionId)
this.closeEntry(entry)
}
/**
* Whether the session has a turn in flight or about to start: a non-terminal current turn,
* a next-turn drain in progress (`startingNextTurn`), or queued follow-ups. The dispatcher
* uses this — NOT `AiStreamManager.hasLiveStream` — to decide enqueue-vs-begin, because
* `hasLiveStream` is false during the inter-turn drain window while the entry is still
* mid-transition; a fresh dispatch trusting `hasLiveStream` there would clobber the drain via
* `beginTurn`.
*/
isSessionBusy(sessionId: string): boolean {
const entry = this.entries.get(sessionId)
if (!entry) return false
return (
entry.startingNextTurn === true ||
entry.pendingTurns.length > 0 ||
(entry.currentTurn !== undefined && entry.currentTurn.terminalStatus === undefined)
)
}
inspect(sessionId: string): AgentSessionRuntimeSnapshot | undefined {
const entry = this.entries.get(sessionId)
if (!entry) return undefined
const turn = entry.currentTurn
return {
sessionId: entry.sessionId,
topicId: entry.topicId,
assistantMessageId: turn?.assistantMessageId,
status: entry.status,
pendingMessageCount: entry.pendingTurns.length,
lastTerminalStatus: entry.lastTerminalStatus,
resumeToken: entry.lastResumeToken,
activeToolCount: turn?.activeToolIds.size ?? 0,
interruptRequested: turn?.interruptRequested ?? false
}
}
/**
* Resolve a Claude `canUseTool` approval that was registered against the live
* driver session. Returns `false` if no live entry matches — the caller
* falls back to MCP/DB path.
*/
respondToolApproval(approvalId: string, decision: DispatchDecision): boolean {
return toolApprovalRegistry.dispatch(approvalId, decision)
}
protected onStop(): void {
this.closeAll()
toolApprovalRegistry.clear('agent-session-runtime-stop')
}
protected onDestroy(): void {
this.closeAll()
toolApprovalRegistry.clear('agent-session-runtime-destroy')
}
private isCurrentEntry(entry: AgentSessionRuntimeEntry): boolean {
return this.entries.get(entry.sessionId) === entry
}
private async ensureConnection(entry: AgentSessionRuntimeEntry): Promise<boolean> {
if (!this.isCurrentEntry(entry)) return false
if (entry.connection) return true
// Share a single in-flight connect across concurrent callers so two streams opening at once
// can't each spin up a connection (the second would leak/clobber the first).
if (entry.connecting) return entry.connecting
const connecting = this.connect(entry).finally(() => {
if (entry.connecting === connecting) entry.connecting = undefined
})
entry.connecting = connecting
return connecting
}
private async connect(entry: AgentSessionRuntimeEntry): Promise<boolean> {
const driver = runtimeDriverRegistry.getAgentSessionDriver(entry.agentType)
if (!driver) throw new Error(`Unsupported agent runtime type: ${entry.agentType}`)
await this.hydrateResumeToken(entry)
if (!this.isCurrentEntry(entry)) return false
const connection = await driver.connect({
sessionId: entry.sessionId,
agentId: entry.agentId,
modelId: entry.modelId,
resumeToken: entry.lastResumeToken,
trace: entry.currentTurn?.trace
})
if (!this.isCurrentEntry(entry)) {
void Promise.resolve(connection.close()).catch((error) =>
logger.warn('Agent runtime connection close failed', { sessionId: entry.sessionId, error })
)
return false
}
entry.connection = connection
entry.connectionLoop = this.runConnectionLoop(entry, connection).finally(() => {
if (entry.connection === connection) entry.connection = undefined
if (entry.connectionLoop) entry.connectionLoop = undefined
})
return true
}
private async hydrateResumeToken(entry: AgentSessionRuntimeEntry): Promise<void> {
if (entry.lastResumeToken) return
const runtimeResumeToken = await agentSessionMessageService.getLastRuntimeResumeToken(entry.sessionId)
if (runtimeResumeToken) entry.lastResumeToken = runtimeResumeToken
}
private async runConnectionLoop(entry: AgentSessionRuntimeEntry, connection: AgentRuntimeConnection): Promise<void> {
try {
for await (const event of connection.events) {
this.handleRuntimeEvent(entry, event)
}
} catch (error) {
this.handleRuntimeError(entry, error)
}
}
private handleRuntimeEvent(entry: AgentSessionRuntimeEntry, event: AgentRuntimeEvent): void {
switch (event.type) {
case 'resume-token':
entry.lastResumeToken = event.token
break
case 'chunk': {
const turn = entry.currentTurn
if (turn?.controller && !turn.terminalStatus) this.enqueueTurnChunk(entry, turn, event.chunk)
break
}
case 'turn-complete':
this.closeCurrentTurn(entry, 'success')
break
case 'error':
this.handleRuntimeError(entry, event.error)
break
}
}
private handleRuntimeError(entry: AgentSessionRuntimeEntry, error: unknown): void {
const turn = entry.currentTurn
if (turn?.controller && !turn.terminalStatus) {
turn.controller.error(error)
// Mark terminal synchronously: the listener's markTurnTerminal arrives async (after the
// stream error propagates), so a trailing `chunk` event in the same connection loop would
// otherwise hit enqueueTurnChunk and throw on the now-errored controller.
turn.terminalStatus = 'error'
} else if (isAbortError(error)) {
// Expected when a turn was interrupted/closed — the connection ending is not a fault.
logger.warn('Agent runtime connection ended without an active turn', { sessionId: entry.sessionId, error })
} else {
// No turn to surface this on, so a real runtime failure would otherwise vanish — log it loudly
// so the next reconnect-into-the-same-failure is at least traceable.
logger.error('Agent runtime connection ended without an active turn', { sessionId: entry.sessionId, error })
}
}
private async admitTurn(entry: AgentSessionRuntimeEntry, turn: AgentSessionTurn): Promise<void> {
if (!this.isCurrentEntry(entry) || entry.currentTurn !== turn || turn.terminalStatus) return
if (turn.admitted) return
turn.admitted = true
entry.status = 'active'
await entry.connection?.send({ message: turn.userMessage })
if (entry.pendingTurns.length > 0) {
queueMicrotask(() => this.requestInterruptWhenSafe(entry))
}
}
private enqueueTurnChunk(entry: AgentSessionRuntimeEntry, turn: AgentSessionTurn, chunk: UIMessageChunk): void {
const toolChunk = chunk as { type?: string; toolCallId?: string }
if ((toolChunk.type === 'tool-input-start' || toolChunk.type === 'tool-input-available') && toolChunk.toolCallId) {
turn.activeToolIds.add(toolChunk.toolCallId)
} else if (
(toolChunk.type === 'tool-output-available' ||
toolChunk.type === 'tool-output-error' ||
toolChunk.type === 'tool-output-denied') &&
toolChunk.toolCallId
) {
turn.activeToolIds.delete(toolChunk.toolCallId)
}
turn.controller?.enqueue(chunk)
if (turn.activeToolIds.size === 0 && entry.pendingTurns.length > 0) this.requestInterruptWhenSafe(entry)
}
private requestInterruptWhenSafe(entry: AgentSessionRuntimeEntry): void {
const turn = entry.currentTurn
if (!turn || turn.terminalStatus || !turn.admitted || turn.interruptRequested) return
const canInterrupt = entry.connection?.canInterruptNow?.() ?? turn.activeToolIds.size === 0
if (!canInterrupt) return
turn.interruptRequested = true
this.interruptCurrentTurn(entry)
}
private interruptCurrentTurn(entry: AgentSessionRuntimeEntry): void {
const turn = entry.currentTurn
if (!turn || turn.terminalStatus) return
void entry.connection?.interrupt?.().catch((error) => {
logger.warn('Agent runtime interrupt failed', { sessionId: entry.sessionId, error })
})
application.get('AiStreamManager').pauseRuntimeTurn(entry.topicId, 'agent-runtime-interrupt')
}
private stopTurn(entry: AgentSessionRuntimeEntry, intent: AgentTurnStopIntent): void {
const policy = STOP_POLICY[intent]
if (policy.closeSession) {
this.closeSession(entry.sessionId)
return
}
this.closeCurrentTurn(entry, policy.turnStatus)
}
private closeCurrentTurn(entry: AgentSessionRuntimeEntry, status: AgentSessionRuntimeTerminalStatus): void {
const turn = entry.currentTurn
if (!turn || turn.terminalStatus) return
turn.terminalStatus = status
try {
turn.controller?.close()
} catch {
// Already closed by the stream reader.
}
turn.controller = undefined
turn.activeToolIds.clear()
}
private scheduleNextTurn(entry: AgentSessionRuntimeEntry): void {
if (entry.startingNextTurn) return
entry.startingNextTurn = true
// Keep `startingNextTurn` set for the WHOLE drain — `startNextTurn` spans a DB round-trip,
// and `isSessionBusy` relies on this flag so a concurrent dispatch landing in the inter-turn
// window enqueues instead of beginning a clobbering fresh turn. Clear it only once the drain
// settles (turn established, bailed, or errored).
queueMicrotask(() => {
void this.startNextTurn(entry)
.catch((error) => {
logger.error('Failed to start next agent runtime turn', { sessionId: entry.sessionId, error })
})
.finally(() => {
entry.startingNextTurn = false
})
})
}
private async startNextTurn(entry: AgentSessionRuntimeEntry): Promise<void> {
const nextMessage = entry.pendingTurns.shift()
if (!nextMessage) {
this.refreshIdleTimer(entry)
return
}
const { rootSpan, traceId, rootSpanId } = this.startRuntimeRootSpan(entry)
let assistantMessage: Awaited<ReturnType<typeof agentSessionMessageService.saveMessage>>
try {
assistantMessage = await agentSessionMessageService.saveMessage({
sessionId: entry.sessionId,
message: {
role: 'assistant',
status: 'pending',
data: { parts: [] },
modelId: entry.modelId,
traceId
}
})
} catch (error) {
// The placeholder save failed, so there is no assistant row to drive to `error` and no
// point re-queuing the message — the retry would just fail the same way, and a re-queued
// message is silently cleared by the idle TTL anyway. Instead surface the failure to the
// live renderer and settle the turn so the session doesn't sit idle on a doomed message.
rootSpan.end()
application.get('AiStreamManager').broadcastTopicError(entry.topicId, entry.modelId, serializeError(error))
this.markTurnTerminal(entry.sessionId, 'error')
return
}
// The DB save above yields the event loop; the session may have been torn down
// (shutdown / a fresh beginTurn) in the meantime. Re-check before mutating the entry,
// mirroring every other async method here — otherwise a dead entry gets resurrected
// into a doomed runtime turn with no backing agent connection.
if (!this.isCurrentEntry(entry)) {
rootSpan.end()
return
}
const assistantMessageId = assistantMessage.id
const turnId = crypto.randomUUID()
entry.currentTurn = {
turnId,
assistantMessageId,
userMessage: nextMessage,
modelId: entry.modelId,
admitted: false,
activeToolIds: new Set(),
interruptRequested: false,
trace: this.createTraceContext(entry, turnId, traceId, rootSpanId)
}
application.get('AiStreamManager').startRuntimeTurn({
topicId: entry.topicId,
modelId: entry.modelId,
rootSpan,
request: {
chatId: entry.topicId,
trigger: 'submit-message',
messageId: assistantMessageId,
messages: createRuntimeSeedMessages(nextMessage, assistantMessageId),
runtime: { kind: 'agent-session', sessionId: entry.sessionId, turnId }
},
listeners: [
this.createPersistenceListener(entry, nextMessage),
new AgentSessionRuntimeTerminalListener(this, entry.sessionId),
new TraceFlushListener(entry.topicId)
]
})
}
private startRuntimeRootSpan(entry: AgentSessionRuntimeEntry): {
rootSpan: Span
traceId: string
rootSpanId: string
} {
const turnTrace = startAiTurnTrace(
'chat.turn',
{
attributes: {
'cs.topic_id': entry.topicId,
'cs.trigger': 'submit-message',
'cs.model_id': entry.modelId,
'cs.role': 'assistant',
'cs.agent_id': entry.agentId,
'cs.session_id': entry.sessionId
}
},
{ topicId: entry.topicId, modelName: parseUniqueModelId(entry.modelId).modelId }
)
return { rootSpan: turnTrace.rootSpan, traceId: turnTrace.traceId, rootSpanId: turnTrace.rootSpanId }
}
private createTraceContext(
input: Pick<BeginAgentSessionTurnInput, 'topicId' | 'sessionId' | 'modelId'>,
turnId: string,
traceId?: string,
rootSpanId?: string
): AgentRuntimeTraceContext | undefined {
if (!traceId || !rootSpanId) return undefined
return {
topicId: input.topicId,
traceId,
rootSpanId,
sessionId: input.sessionId,
turnId,
modelName: parseUniqueModelId(input.modelId).modelId
}
}
private shouldCloseConnectionAfterTurn(entry: AgentSessionRuntimeEntry): boolean {
return entry.connection?.shouldCloseAfterTurn?.() ?? false
}
private createPersistenceListener(
entry: AgentSessionRuntimeEntry,
userMessage: AgentSessionMessageEntity
): StreamListener {
const userText = extractMessageText(userMessage)
return new PersistenceListener({
topicId: entry.topicId,
modelId: entry.modelId,
backend: new AgentSessionMessageBackend({
sessionId: entry.sessionId,
modelId: entry.modelId,
runtimeResumeToken: () => entry.lastResumeToken,
afterPersist: async (finalMessage) => {
await topicNamingService.maybeRenameAgentSession(entry.agentId, entry.sessionId, userText, finalMessage)
}
}),
onPersistFailed: (error) =>
application.get('AiStreamManager').broadcastTopicError(entry.topicId, entry.modelId, error)
})
}
private refreshIdleTimer(entry: AgentSessionRuntimeEntry): void {
this.clearIdleTimer(entry)
entry.idleTimer = setTimeout(() => {
const { sessionId, agentType, lastResumeToken } = entry
this.closeSession(sessionId)
if (lastResumeToken) {
runtimeDriverRegistry.getAgentSessionDriver(agentType)?.onSessionIdle?.(sessionId)
}
}, DEFAULT_IDLE_TTL_MS)
entry.idleTimer.unref?.()
}
private clearIdleTimer(entry: AgentSessionRuntimeEntry): void {
if (entry.idleTimer) {
clearTimeout(entry.idleTimer)
entry.idleTimer = undefined
}
}
private closeAll(): void {
for (const sessionId of [...this.entries.keys()]) {
this.closeSession(sessionId)
}
}
private closeEntry(entry: AgentSessionRuntimeEntry): void {
this.clearIdleTimer(entry)
this.closeCurrentTurn(entry, 'paused')
entry.pendingTurns = []
const connection = this.closeConnection(entry)
entry.currentTurn = undefined
entry.startingNextTurn = false
void Promise.resolve(connection?.close()).catch((error) =>
logger.warn('Agent runtime connection close failed', { sessionId: entry.sessionId, error })
)
}
private closeConnection(entry: AgentSessionRuntimeEntry): AgentRuntimeConnection | undefined {
const connection = entry.connection
entry.connection = undefined
entry.connectionLoop = undefined
return connection
}
}
function isAbortError(error: unknown): boolean {
return !!error && typeof error === 'object' && 'name' in error && (error as { name: unknown }).name === 'AbortError'
}
function createRuntimeSeedMessages(
userMessage: AgentSessionMessageEntity,
assistantMessageId: string
): CherryUIMessage[] {
return [
{
id: userMessage.id,
role: 'user',
parts: userMessage.data?.parts ?? []
},
{
id: assistantMessageId,
role: 'assistant',
parts: []
}
] as CherryUIMessage[]
}
function createSyntheticUserMessage(sessionId: string): AgentSessionMessageEntity {
const now = new Date().toISOString()
return {
id: uuidv7(),
sessionId,
role: 'user',
data: { parts: [] },
status: 'success',
searchableText: '',
modelId: null,
modelSnapshot: null,
traceId: null,
stats: null,
runtimeResumeToken: null,
createdAt: now,
updatedAt: now
}
}
function extractMessageText(message: AgentSessionMessageEntity): string {
return (
message.data?.parts
?.filter((part): part is { type: 'text'; text: string } => part.type === 'text' && 'text' in part)
.map((part) => part.text)
.join('\n') ?? ''
)
}

View File

@@ -0,0 +1,940 @@
import { BaseService } from '@main/core/lifecycle/BaseService'
import { mockMainLoggerService } from '@test-mocks/MainLoggerService'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const mocks = vi.hoisted(() => ({
saveMessage: vi.fn(),
getLastRuntimeResumeToken: vi.fn(),
findPendingAssistantMessageIds: vi.fn(),
markMessagesError: vi.fn(),
maybeRenameAgentSession: vi.fn(),
applicationGet: vi.fn(),
startRuntimeTurn: vi.fn(),
pauseRuntimeTurn: vi.fn(),
broadcastTopicError: vi.fn(),
spanCacheSetTopicId: vi.fn()
}))
vi.mock('@data/services/AgentSessionMessageService', () => ({
agentSessionMessageService: {
saveMessage: mocks.saveMessage,
getLastRuntimeResumeToken: mocks.getLastRuntimeResumeToken,
findPendingAssistantMessageIds: mocks.findPendingAssistantMessageIds,
markMessagesError: mocks.markMessagesError
}
}))
vi.mock('@main/services/TopicNamingService', () => ({
topicNamingService: { maybeRenameAgentSession: mocks.maybeRenameAgentSession }
}))
vi.mock('@main/core/application', () => ({
application: { get: mocks.applicationGet }
}))
const { AgentSessionRuntimeService } = await import('../AgentSessionRuntimeService')
const { runtimeDriverRegistry } = await import('../../runtime')
const baseTurnInput = {
sessionId: 'session-1',
topicId: 'agent-session:session-1',
agentId: 'agent-1',
agentType: 'test-runtime',
modelId: 'claude-code::claude-sonnet-4-5' as any,
assistantMessageId: 'assistant-1'
}
function userMessage(id: string) {
return {
id,
topicId: 'agent-session:session-1',
parentId: null,
role: 'user',
data: { parts: [{ type: 'text', text: 'hello' }] },
status: 'success',
createdAt: '',
updatedAt: ''
} as any
}
function terminalListener(handle: { listeners: any[] }) {
const listener = handle.listeners.find((item) => item.id === 'agent-runtime:session-1')
if (!listener) throw new Error('terminal listener missing')
return listener
}
function persistenceListener(handle: { listeners: any[] }) {
const listener = handle.listeners.find((item) => String(item.id).startsWith('persistence:agents-db:'))
if (!listener) throw new Error('persistence listener missing')
return listener
}
function getEntry(service: InstanceType<typeof AgentSessionRuntimeService>) {
return (service as any).entries.get('session-1')
}
function createAsyncQueue<T>() {
const items: T[] = []
const waiters: Array<(value: IteratorResult<T>) => void> = []
return {
push(item: T) {
const waiter = waiters.shift()
if (waiter) waiter({ value: item, done: false })
else items.push(item)
},
iterable: {
[Symbol.asyncIterator](): AsyncIterator<T> {
return {
next: () => {
const item = items.shift()
if (item) return Promise.resolve({ value: item, done: false })
return new Promise<IteratorResult<T>>((resolve) => waiters.push(resolve))
}
}
}
}
}
}
function createDeferred<T>() {
let resolve!: (value: T) => void
let reject!: (reason?: unknown) => void
const promise = new Promise<T>((resolvePromise, rejectPromise) => {
resolve = resolvePromise
reject = rejectPromise
})
return { promise, resolve, reject }
}
describe('AgentSessionRuntimeService', () => {
beforeEach(() => {
BaseService.resetInstances()
runtimeDriverRegistry.clearForTest()
vi.clearAllMocks()
mocks.saveMessage.mockImplementation(async ({ message }) => ({
...message,
id: message.id ?? 'generated-message-id'
}))
mocks.getLastRuntimeResumeToken.mockResolvedValue(null)
mocks.findPendingAssistantMessageIds.mockResolvedValue([])
mocks.markMessagesError.mockResolvedValue(undefined)
mocks.applicationGet.mockImplementation((name: string) => {
if (name === 'AiStreamManager') {
return {
startRuntimeTurn: mocks.startRuntimeTurn,
pauseRuntimeTurn: mocks.pauseRuntimeTurn,
broadcastTopicError: mocks.broadcastTopicError
}
}
if (name === 'SpanCacheService') return { setTopicId: mocks.spanCacheSetTopicId }
throw new Error(`Unexpected application.get(${name})`)
})
})
describe('isSessionBusy — inter-turn drain window (issue ①)', () => {
it('is false with no entry and true while a turn is live', () => {
const service = new AgentSessionRuntimeService()
expect(service.isSessionBusy('session-1')).toBe(false)
service.beginTurn(baseTurnInput)
expect(service.isSessionBusy('session-1')).toBe(true)
})
it('is false once a turn settles with no queued follow-ups', () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
service.markTurnTerminal('session-1', 'success')
expect(service.isSessionBusy('session-1')).toBe(false)
})
it('stays busy throughout the next-turn drain, closing the clobber window', async () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
service.enqueueUserMessage('session-1', userMessage('user-2'))
// Hold the drain's assistant-placeholder save so we can observe the in-flight window.
const deferred = createDeferred<any>()
mocks.saveMessage.mockImplementationOnce(() => deferred.promise)
service.markTurnTerminal('session-1', 'success') // current turn → terminal, schedules the drain
await new Promise((resolve) => setTimeout(resolve, 0)) // flush microtasks → drain parks on saveMessage
const entry = getEntry(service)
// The bug window: the queued message was shifted (pendingTurns empty) and the old turn is
// terminal — pre-fix nothing reported the session busy here.
expect(entry.pendingTurns.length).toBe(0)
expect(entry.currentTurn.terminalStatus).toBe('success')
expect(entry.startingNextTurn).toBe(true) // flag now spans the whole drain
expect(service.isSessionBusy('session-1')).toBe(true)
deferred.resolve({ id: 'assistant-2' })
await new Promise((resolve) => setTimeout(resolve, 0)) // drain completes → fresh live turn
expect(service.isSessionBusy('session-1')).toBe(true)
expect(getEntry(service).startingNextTurn).toBe(false)
})
it('does not resurrect a session torn down during the next-turn placeholder save (REGRESSION agent-session-1)', async () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
service.enqueueUserMessage('session-1', userMessage('user-2'))
// Hold the drain's placeholder save so we can tear the session down mid-await.
const deferred = createDeferred<any>()
mocks.saveMessage.mockImplementationOnce(() => deferred.promise)
service.markTurnTerminal('session-1', 'success') // schedules the drain → parks on saveMessage
await new Promise((resolve) => setTimeout(resolve, 0))
const startCallsBefore = mocks.startRuntimeTurn.mock.calls.length
// Session is torn down (shutdown / a fresh beginTurn) while the save is still in flight.
service.closeSession('session-1')
deferred.resolve({ id: 'assistant-2' })
await new Promise((resolve) => setTimeout(resolve, 0))
// The dead entry must NOT be resurrected into a runtime turn.
expect(mocks.startRuntimeTurn.mock.calls.length).toBe(startCallsBefore)
expect(getEntry(service)).toBeUndefined()
})
})
describe('reconcileStalePendingMessages — boot crash recovery', () => {
it('marks crash-orphaned pending assistant messages as errored on init', async () => {
mocks.findPendingAssistantMessageIds.mockResolvedValue(['stale-1', 'stale-2'])
const service = new AgentSessionRuntimeService()
await (service as any).onInit()
expect(mocks.findPendingAssistantMessageIds).toHaveBeenCalledOnce()
expect(mocks.markMessagesError).toHaveBeenCalledWith(['stale-1', 'stale-2'])
})
it('does not mark anything when there are no stale messages', async () => {
mocks.findPendingAssistantMessageIds.mockResolvedValue([])
const service = new AgentSessionRuntimeService()
await (service as any).onInit()
expect(mocks.markMessagesError).not.toHaveBeenCalled()
})
it('logs and does not rethrow when the reconcile lookup throws, so boot is not blocked', async () => {
const failure = new Error('db down')
mocks.findPendingAssistantMessageIds.mockRejectedValue(failure)
const service = new AgentSessionRuntimeService()
await expect((service as any).onInit()).resolves.toBeUndefined()
expect(mocks.markMessagesError).not.toHaveBeenCalled()
expect(mockMainLoggerService.error).toHaveBeenCalledWith(
'Failed to reconcile stale pending agent-session messages',
{ error: failure }
)
})
})
it('creates an active runtime with a session-level pending queue', () => {
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn(baseTurnInput)
service.enqueueUserMessage('session-1', userMessage('user-2'))
expect(terminalListener(handle).id).toBe('agent-runtime:session-1')
expect(persistenceListener(handle).id).toContain('persistence:agents-db:agent-session:session-1')
expect(service.inspect('session-1')).toMatchObject({
sessionId: 'session-1',
topicId: 'agent-session:session-1',
assistantMessageId: 'assistant-1',
status: 'active',
pendingMessageCount: 1,
lastTerminalStatus: undefined,
activeToolCount: 0,
interruptRequested: false
})
})
it('marks the runtime idle when the terminal listener observes done', () => {
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn(baseTurnInput)
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
expect(service.inspect('session-1')).toMatchObject({
status: 'idle',
pendingMessageCount: 0,
lastTerminalStatus: 'success'
})
})
it('hands an idle session with a resume token to the driver onSessionIdle hook', () => {
vi.useFakeTimers()
try {
const onSessionIdle = vi.fn()
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn(),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([]),
onSessionIdle
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn(baseTurnInput)
getEntry(service).lastResumeToken = 'resume-1'
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
vi.advanceTimersByTime(5 * 60 * 1000)
expect(onSessionIdle).toHaveBeenCalledWith('session-1')
expect(service.inspect('session-1')).toBeUndefined()
} finally {
vi.useRealTimers()
}
})
it('does not call onSessionIdle for an idle session without a resume token', () => {
vi.useFakeTimers()
try {
const onSessionIdle = vi.fn()
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn(),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([]),
onSessionIdle
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn(baseTurnInput)
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
vi.advanceTimersByTime(5 * 60 * 1000)
expect(onSessionIdle).not.toHaveBeenCalled()
expect(service.inspect('session-1')).toBeUndefined()
} finally {
vi.useRealTimers()
}
})
it('reuses an idle runtime for the next fresh turn', () => {
const service = new AgentSessionRuntimeService()
const first = service.beginTurn(baseTurnInput)
const entry = getEntry(service)
const connection = { close: vi.fn(), send: vi.fn(), events: [] }
entry.lastResumeToken = 'resume-1'
entry.connection = connection
void terminalListener(first).onDone({ status: 'success', isTopicDone: true })
const second = service.beginTurn({
...baseTurnInput,
assistantMessageId: 'assistant-2',
userMessage: userMessage('user-2')
})
expect(second).not.toBe(first)
expect(getEntry(service).connection).toBe(connection)
expect(getEntry(service).pendingTurns).toEqual([])
expect(service.inspect('session-1')).toMatchObject({
assistantMessageId: 'assistant-2',
status: 'active',
pendingMessageCount: 0,
resumeToken: 'resume-1'
})
})
it('ignores per-execution terminal events until the topic is done', () => {
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn(baseTurnInput)
void terminalListener(handle).onPaused({ status: 'paused', isTopicDone: false })
expect(service.inspect('session-1')).toMatchObject({
status: 'active',
lastTerminalStatus: undefined
})
})
it('clears the runtime and closes the connection on closeSession', () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
const connection = { close: vi.fn(), send: vi.fn(), events: [] }
const entry = getEntry(service)
entry.connection = connection
entry.connectionLoop = Promise.resolve()
entry.startingNextTurn = true
service.closeSession('session-1')
expect(connection.close).toHaveBeenCalled()
expect(entry.connection).toBeUndefined()
expect(entry.connectionLoop).toBeUndefined()
expect(entry.currentTurn).toBeUndefined()
expect(entry.startingNextTurn).toBe(false)
expect(service.inspect('session-1')).toBeUndefined()
})
it('does not throw and logs a warning when the connection close rejects on closeSession (REGRESSION agent-session-5)', async () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
const closeError = new Error('close failed')
const connection = { close: vi.fn().mockRejectedValue(closeError), send: vi.fn(), events: [] }
const entry = getEntry(service)
entry.connection = connection
entry.connectionLoop = Promise.resolve()
expect(() => service.closeSession('session-1')).not.toThrow()
expect(connection.close).toHaveBeenCalled()
expect(service.inspect('session-1')).toBeUndefined()
await vi.waitFor(() =>
expect(mockMainLoggerService.warn).toHaveBeenCalledWith(
'Agent runtime connection close failed',
expect.objectContaining({ sessionId: 'session-1', error: closeError })
)
)
})
it('persists assistant turns with the latest resume token', async () => {
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
getEntry(service).lastResumeToken = 'resume-1'
await persistenceListener(handle).onDone({
status: 'success',
isTopicDone: true,
finalMessage: { id: 'assistant-1', role: 'assistant', parts: [{ type: 'text', text: 'hi' }] }
})
expect(mocks.saveMessage).toHaveBeenCalledWith({
sessionId: 'session-1',
runtimeResumeToken: 'resume-1',
message: {
id: 'assistant-1',
role: 'assistant',
status: 'success',
data: { parts: [{ type: 'text', text: 'hi' }] },
modelId: 'claude-code::claude-sonnet-4-5'
}
})
})
it('routes runtime events from the selected driver into the active turn', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
interrupt: vi.fn(),
close: vi.fn()
}
const connect = vi.fn().mockResolvedValue(connection)
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect,
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: new AbortController().signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
events.push({ type: 'resume-token', token: 'resume-1' })
await vi.waitFor(() => expect(service.inspect('session-1')).toMatchObject({ resumeToken: 'resume-1' }))
events.push({ type: 'chunk', chunk: { type: 'text-delta', id: 'text-1', delta: 'hello' } })
await expect(reader.read()).resolves.toMatchObject({
value: { type: 'text-delta', id: 'text-1', delta: 'hello' },
done: false
})
events.push({ type: 'turn-complete' })
await expect(reader.read()).resolves.toMatchObject({ done: true })
})
it('surfaces a runtime error event via controller.error and drops trailing chunks (REGRESSION agent-session-3)', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
interrupt: vi.fn(),
close: vi.fn()
}
const connect = vi.fn().mockResolvedValue(connection)
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect,
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: new AbortController().signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalled())
// A runtime `error` event surfaces through the active turn's controller.
events.push({ type: 'error', error: new Error('runtime boom') })
await expect(reader.read()).rejects.toThrow('runtime boom')
// The turn is marked terminal synchronously, so a trailing chunk in the same connection
// loop is dropped instead of being enqueued on the now-errored controller (which would throw).
await vi.waitFor(() => expect(getEntry(service).currentTurn?.terminalStatus).toBe('error'))
events.push({ type: 'chunk', chunk: { type: 'text-delta', id: 't', delta: 'late' } })
await new Promise((resolve) => setTimeout(resolve, 0))
expect(getEntry(service).currentTurn?.terminalStatus).toBe('error')
})
it('passes trace context to the runtime driver and closes the connection after trace turns', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
shouldCloseAfterTurn: () => true,
close: vi.fn()
}
const connect = vi.fn().mockResolvedValue(connection)
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect,
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({
...baseTurnInput,
userMessage: userMessage('user-1'),
traceId: '0'.repeat(32),
rootSpanId: '1'.repeat(16)
})
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: new AbortController().signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() =>
expect(connect).toHaveBeenCalledWith({
sessionId: 'session-1',
agentId: 'agent-1',
modelId: 'claude-code::claude-sonnet-4-5',
resumeToken: undefined,
trace: {
topicId: 'agent-session:session-1',
traceId: '0'.repeat(32),
rootSpanId: '1'.repeat(16),
sessionId: 'session-1',
turnId: handle.turnId,
modelName: 'claude-sonnet-4-5'
}
})
)
void terminalListener(handle).onDone({ status: 'success', isTopicDone: true })
expect(connection.close).toHaveBeenCalledOnce()
expect(getEntry(service).connection).toBeUndefined()
await reader.cancel().catch(() => undefined)
})
it('hydrates the persisted resume token before connecting a cold historical session', async () => {
mocks.getLastRuntimeResumeToken.mockResolvedValue('resume-db')
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
close: vi.fn()
}
const connect = vi.fn().mockResolvedValue(connection)
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect,
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: new AbortController().signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() =>
expect(connect).toHaveBeenCalledWith({
sessionId: 'session-1',
agentId: 'agent-1',
modelId: 'claude-code::claude-sonnet-4-5',
resumeToken: 'resume-db',
trace: undefined
})
)
expect(mocks.getLastRuntimeResumeToken).toHaveBeenCalledWith('session-1')
expect(service.inspect('session-1')).toMatchObject({ resumeToken: 'resume-db' })
service.closeSession('session-1')
await reader.cancel().catch(() => undefined)
})
it('closes the runtime session when the active turn is aborted by the user', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
close: vi.fn()
}
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn().mockResolvedValue(connection),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const controller = new AbortController()
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: controller.signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
controller.abort('user-requested')
await vi.waitFor(() => expect(connection.close).toHaveBeenCalledOnce())
expect(service.inspect('session-1')).toBeUndefined()
await reader.cancel().catch(() => undefined)
})
it('closes a late runtime connection when the user aborts before connect resolves', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
close: vi.fn()
}
const pendingConnection = createDeferred<typeof connection>()
const connect = vi.fn().mockReturnValue(pendingConnection.promise)
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect,
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const controller = new AbortController()
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: controller.signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connect).toHaveBeenCalledOnce())
controller.abort('user-requested')
expect(service.inspect('session-1')).toBeUndefined()
pendingConnection.resolve(connection)
await vi.waitFor(() => expect(connection.close).toHaveBeenCalledOnce())
expect(connection.send).not.toHaveBeenCalled()
await reader.cancel().catch(() => undefined)
})
describe('interrupt-when-safe — live follow-up', () => {
it('defers the interrupt while a tool is mid-flight, then fires once the tool settles', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
interrupt: vi.fn().mockResolvedValue(undefined),
close: vi.fn()
}
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn().mockResolvedValue(connection),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: new AbortController().signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
// A tool is now in flight — the turn is not safe to interrupt.
events.push({ type: 'chunk', chunk: { type: 'tool-input-start', toolCallId: 'tool-1' } })
await vi.waitFor(() => expect(getEntry(service).currentTurn.activeToolIds.has('tool-1')).toBe(true))
// The follow-up queues but must NOT interrupt while the tool runs.
service.enqueueUserMessage('session-1', userMessage('user-2'))
await new Promise((resolve) => setTimeout(resolve, 0))
expect(connection.interrupt).not.toHaveBeenCalled()
expect(mocks.pauseRuntimeTurn).not.toHaveBeenCalled()
// Tool settles → now safe → interrupt fires and the runtime turn is paused.
events.push({ type: 'chunk', chunk: { type: 'tool-output-available', toolCallId: 'tool-1' } })
await vi.waitFor(() => expect(connection.interrupt).toHaveBeenCalledOnce())
expect(mocks.pauseRuntimeTurn).toHaveBeenCalledWith('agent-session:session-1', 'agent-runtime-interrupt')
service.closeSession('session-1')
await reader.cancel().catch(() => undefined)
})
it('interrupts immediately on the next microtask when no tool is active', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
interrupt: vi.fn().mockResolvedValue(undefined),
close: vi.fn()
}
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn().mockResolvedValue(connection),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: new AbortController().signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
// No tool in flight (activeToolIds empty) → the queued follow-up interrupts on the next microtask.
expect(getEntry(service).currentTurn.activeToolIds.size).toBe(0)
service.enqueueUserMessage('session-1', userMessage('user-2'))
expect(connection.interrupt).not.toHaveBeenCalled()
await vi.waitFor(() => expect(connection.interrupt).toHaveBeenCalledOnce())
expect(mocks.pauseRuntimeTurn).toHaveBeenCalledWith('agent-session:session-1', 'agent-runtime-interrupt')
service.closeSession('session-1')
await reader.cancel().catch(() => undefined)
})
})
it('keeps the runtime session alive when a steer interrupt pauses the turn', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
close: vi.fn()
}
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn().mockResolvedValue(connection),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const controller = new AbortController()
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: controller.signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
// The steer path marks the turn before aborting; the abort reason is irrelevant.
getEntry(service).currentTurn.interruptRequested = true
controller.abort()
await expect(reader.read()).resolves.toMatchObject({ done: true })
expect(connection.close).not.toHaveBeenCalled()
expect(service.inspect('session-1')).toMatchObject({
sessionId: 'session-1',
status: 'active'
})
service.closeSession('session-1')
})
it('tears the session down on abort with an interrupt-looking reason when none was requested', async () => {
const events = createAsyncQueue<any>()
const connection = {
events: events.iterable,
send: vi.fn(),
close: vi.fn()
}
runtimeDriverRegistry.register({
type: 'test-runtime',
capabilities: ['agent-session'],
connect: vi.fn().mockResolvedValue(connection),
validateSession: vi.fn(),
listAvailableTools: vi.fn().mockResolvedValue([])
})
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
const controller = new AbortController()
const stream = service.openTurnStream({
sessionId: 'session-1',
turnId: handle.turnId,
signal: controller.signal
})
const reader = stream.getReader()
await expect(reader.read()).resolves.toMatchObject({ value: { type: 'start' }, done: false })
await vi.waitFor(() => expect(connection.send).toHaveBeenCalledWith({ message: userMessage('user-1') }))
// Reason matches the old interrupt sentinel, but no interrupt was requested —
// teardown is driven by `interruptRequested`, not the signal reason.
controller.abort('agent-runtime-interrupt')
await vi.waitFor(() => expect(connection.close).toHaveBeenCalledOnce())
expect(service.inspect('session-1')).toBeUndefined()
await reader.cancel().catch(() => undefined)
})
it('persists errored assistant turns with the latest resume token', async () => {
const service = new AgentSessionRuntimeService()
const handle = service.beginTurn({ ...baseTurnInput, userMessage: userMessage('user-1') })
getEntry(service).lastResumeToken = 'resume-init'
await persistenceListener(handle).onError({
status: 'error',
isTopicDone: true,
error: { name: 'Error', message: 'boom' },
finalMessage: { id: 'assistant-1', role: 'assistant', parts: [] }
})
expect(mocks.saveMessage).toHaveBeenCalledWith({
sessionId: 'session-1',
runtimeResumeToken: 'resume-init',
message: {
id: 'assistant-1',
role: 'assistant',
status: 'error',
data: { parts: [{ type: 'data-error', data: { name: 'Error', message: 'boom' } }] },
modelId: 'claude-code::claude-sonnet-4-5'
}
})
})
it('starts queued turns with runtime request metadata and assistant seed', async () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
const entry = getEntry(service)
entry.lastResumeToken = 'resume-1'
entry.currentTurn.activeToolIds.add('tool-1')
entry.pendingTurns.push(userMessage('user-2'))
await (service as any).startNextTurn(entry)
const savedMessage = mocks.saveMessage.mock.calls[0][0].message
expect(mocks.saveMessage).toHaveBeenCalledWith({
sessionId: 'session-1',
message: {
role: 'assistant',
status: 'pending',
data: { parts: [] },
modelId: 'claude-code::claude-sonnet-4-5',
traceId: expect.any(String)
}
})
expect(mocks.spanCacheSetTopicId).toHaveBeenCalledWith(savedMessage.traceId, 'agent-session:session-1')
expect(mocks.startRuntimeTurn).toHaveBeenCalledWith({
topicId: 'agent-session:session-1',
modelId: 'claude-code::claude-sonnet-4-5',
rootSpan: expect.anything(),
request: {
chatId: 'agent-session:session-1',
trigger: 'submit-message',
messageId: 'generated-message-id',
messages: [
{ id: 'user-2', role: 'user', parts: [{ type: 'text', text: 'hello' }] },
{ id: 'generated-message-id', role: 'assistant', parts: [] }
],
runtime: { kind: 'agent-session', sessionId: 'session-1', turnId: expect.any(String) }
},
listeners: [
expect.objectContaining({ id: expect.stringContaining('persistence:agents-db:') }),
expect.objectContaining({ id: 'agent-runtime:session-1' }),
expect.objectContaining({ id: 'persistence:trace:agent-session:session-1' })
]
})
const request = mocks.startRuntimeTurn.mock.calls[0][0].request
expect(request.messageId).toBe(request.messages[1].id)
expect(getEntry(service).currentTurn.trace).toMatchObject({
topicId: 'agent-session:session-1',
traceId: savedMessage.traceId,
rootSpanId: expect.any(String),
sessionId: 'session-1',
turnId: request.runtime.turnId,
modelName: 'claude-sonnet-4-5'
})
})
it('surfaces the error and settles the turn when the next-turn placeholder save rejects (R3)', async () => {
const service = new AgentSessionRuntimeService()
service.beginTurn(baseTurnInput)
const entry = getEntry(service)
const queued = userMessage('user-2')
entry.pendingTurns.push(queued)
const saveError = new Error('db down')
mocks.saveMessage.mockRejectedValueOnce(saveError)
// The placeholder save failed: re-queuing would just fail again and the idle TTL would
// silently clear it, so the message is dropped, the failure is surfaced to the live renderer,
// and the turn is settled to `error` (not left silently idle).
await expect((service as any).startNextTurn(entry)).resolves.toBeUndefined()
expect(entry.pendingTurns).toEqual([])
expect(mocks.startRuntimeTurn).not.toHaveBeenCalled()
expect(mocks.broadcastTopicError).toHaveBeenCalledWith(
entry.topicId,
entry.modelId,
expect.objectContaining({ message: expect.stringContaining('db down') })
)
expect(entry.status).toBe('idle')
expect(entry.lastTerminalStatus).toBe('error')
})
})

View File

@@ -0,0 +1,63 @@
/**
* Agent-session DB backend — writes assistant turns to the `agent_session_message`
* table via `agentSessionMessageService`. The user message is persisted
* by AgentChatContextProvider before streaming starts (not here).
*
* The listener folds any error into `finalMessage.parts` upstream, so a
* single `persistAssistant` handles success / paused / error uniformly.
*/
import { agentSessionMessageService } from '@data/services/AgentSessionMessageService'
import type { CherryMessagePart, CherryUIMessage } from '@shared/data/types/message'
import type { UniqueModelId } from '@shared/data/types/model'
import { v7 as uuidv7 } from 'uuid'
import {
finalizeInterruptedParts,
type PersistAssistantInput,
type PersistenceBackend
} from '../../streamManager/persistence/PersistenceBackend'
export interface AgentSessionMessageBackendOptions {
/** Cherry Studio agent-session id. */
sessionId: string
/** Model id used for this assistant message. */
modelId?: UniqueModelId
/** Opaque runtime resume token persisted for future recovery; `undefined` when unknown. */
runtimeResumeToken?: string | (() => string | undefined)
/** Post-success hook — typically session auto-rename. */
afterPersist?: (finalMessage: CherryUIMessage) => Promise<void>
}
export class AgentSessionMessageBackend implements PersistenceBackend {
readonly kind = 'agents-db'
readonly afterPersist?: (finalMessage: CherryUIMessage) => Promise<void>
constructor(private readonly opts: AgentSessionMessageBackendOptions) {
this.afterPersist = opts.afterPersist
}
async persistAssistant(input: PersistAssistantInput): Promise<void> {
const { finalMessage, status, stats } = input
const parts = finalizeInterruptedParts((finalMessage?.parts ?? []) as CherryMessagePart[], status)
const runtimeResumeToken = this.getRuntimeResumeToken()
await agentSessionMessageService.saveMessage({
sessionId: this.opts.sessionId,
...(runtimeResumeToken ? { runtimeResumeToken } : {}),
message: {
id: finalMessage?.id ?? uuidv7(),
role: 'assistant',
status,
data: { parts },
modelId: this.opts.modelId,
...(stats ? { stats } : {})
}
})
}
private getRuntimeResumeToken(): string | undefined {
return typeof this.opts.runtimeResumeToken === 'function'
? this.opts.runtimeResumeToken()
: this.opts.runtimeResumeToken
}
}

View File

@@ -0,0 +1,19 @@
const AGENT_SESSION_PREFIX = 'agent-session:'
/** Check if a topicId represents an agent session (vs a normal chat). */
export function isAgentSessionTopic(topicId: string): boolean {
return topicId.startsWith(AGENT_SESSION_PREFIX)
}
/** Extract the agent session ID from a topic ID. Throws if not an agent session topic. */
export function extractAgentSessionId(topicId: string): string {
if (!isAgentSessionTopic(topicId)) {
throw new Error(`Not an agent session topicId: ${topicId}`)
}
return topicId.slice(AGENT_SESSION_PREFIX.length)
}
/** Build the topic id for an agent session. */
export function buildAgentSessionTopicId(sessionId: string): string {
return `${AGENT_SESSION_PREFIX}${sessionId}`
}

View File

@@ -0,0 +1,19 @@
import { application } from '@main/core/application'
import { BaseService, DependsOn, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
import { IpcChannel } from '@shared/IpcChannel'
import { AgentTaskJobHandler } from './AgentTaskJobHandler'
@Injectable('AgentJobsService')
@ServicePhase(Phase.WhenReady)
@DependsOn(['JobManager'])
export class AgentJobsService extends BaseService {
protected async onInit(): Promise<void> {
const jobManager = application.get('JobManager')
jobManager.registerHandler('agent.task', AgentTaskJobHandler)
this.ipcHandle(IpcChannel.Ai_Agent_RunTask, async (_event, taskId: string) => {
return jobManager.triggerJobScheduleNowById(taskId)
})
}
}

View File

@@ -0,0 +1,84 @@
/**
* Job handler for `agent.task` — scheduled agent prompts.
*
* Thin metadata + execute wrapper; business logic lives in `./runAgentTask`.
* Failure backstop: after three consecutive failed terminal jobs on the same
* schedule, pauses the schedule via `JobManager.pauseJobScheduleById`. The
* `jobTable` rows are the single source of truth — no in-memory counter
* (the legacy `SchedulerService.consecutiveErrors` map reset on every process
* restart, making the breaker effectively unreachable in practice).
*/
import { application } from '@application'
import { jobService } from '@data/services/JobService'
import { loggerService } from '@logger'
import type { JobHandler } from '@main/core/job/types'
import { runAgentTask } from './runAgentTask'
declare module '@main/core/job/jobRegistry' {
interface JobRegistry {
'agent.task': {
agentId: string
prompt: string
/** Per-task timeout in minutes. Enforced inside `runAgentTask`; handler-level
* `defaultTimeoutMs` is intentionally unset so each task may set its own value. */
timeoutMinutes: number
}
}
}
const logger = loggerService.withContext('AgentTaskJobHandler')
const RECENT_TERMINAL_WINDOW = 3
export const AgentTaskJobHandler: JobHandler<{ agentId: string; prompt: string; timeoutMinutes: number }> = {
/**
* 'retry': non-terminal jobs from a previous run are re-pended on startup
* so the recovered job dispatches against the latest agent configuration.
* This matches the legacy poll-loop semantics where a task missed by a
* crash was simply picked up on the next 60s tick.
*/
recovery: 'retry',
/**
* Per-agent serialization queue: a single agent never runs two scheduled
* tasks concurrently (Claude Code subprocess + workspace state would
* collide). Cross-agent parallelism is unaffected.
*/
defaultQueue: (input) => `agent:${input.agentId}`,
defaultConcurrency: 1,
/**
* Schedule-driven tasks do not retry inside the Job runtime — failure
* surfaces to `onSettled` and the circuit breaker decides whether to pause.
* Re-attempting an LLM call automatically is rarely helpful and can rack
* up token spend without diagnostic value.
*/
defaultRetryPolicy: { maxAttempts: 1, backoff: 'none', baseDelayMs: 0, maxDelayMs: 0 },
async execute(ctx) {
return await runAgentTask(ctx)
},
async onSettled(event) {
if (event.status !== 'failed' || !event.scheduleId) return
const recent = await jobService.listRecentTerminalByScheduleId(event.scheduleId, RECENT_TERMINAL_WINDOW)
if (recent.length < RECENT_TERMINAL_WINDOW) return
if (!recent.every((j) => j.status === 'failed')) return
logger.warn('Agent task schedule failed in last N terminal runs — pausing', {
scheduleId: event.scheduleId,
window: RECENT_TERMINAL_WINDOW
})
try {
await application.get('JobManager').pauseJobScheduleById(event.scheduleId)
} catch (err) {
logger.error('Failed to pause schedule after consecutive failures', err as Error, {
scheduleId: event.scheduleId
})
}
}
}

View File

@@ -0,0 +1,176 @@
import type { JobContext, JobSettledEvent } from '@main/core/job/types'
import type { JobSnapshot } from '@shared/data/api/schemas/jobs'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@application', async () => {
const mod = await import('@test-mocks/main/application')
return mod.mockApplicationFactory()
})
vi.mock('@data/services/JobService', () => ({
jobService: {
listRecentTerminalByScheduleId: vi.fn()
}
}))
vi.mock('../runAgentTask', () => ({
runAgentTask: vi.fn()
}))
import { application } from '@application'
import { jobService } from '@data/services/JobService'
import { AgentTaskJobHandler } from '../AgentTaskJobHandler'
import { runAgentTask } from '../runAgentTask'
function makeTerminal(status: 'completed' | 'failed' | 'cancelled', id = `job-${status}`): JobSnapshot {
return {
id,
type: 'agent.task',
status,
priority: 0,
queue: 'agent:a1',
idempotencyKey: null,
scheduleId: 's1',
scheduledAt: '2026-05-20T00:00:00.000Z',
startedAt: '2026-05-20T00:00:01.000Z',
finishedAt: '2026-05-20T00:00:02.000Z',
attempt: 0,
maxAttempts: 1,
input: {},
output: null,
error: null,
parentId: null,
cancelRequested: false,
metadata: {},
timeoutMs: null,
createdAt: '2026-05-20T00:00:00.000Z',
updatedAt: '2026-05-20T00:00:02.000Z'
}
}
function makeSettled(overrides: Partial<JobSettledEvent>): JobSettledEvent {
return {
jobId: 'job-1',
type: 'agent.task',
scheduleId: 's1',
status: 'failed',
output: null,
error: { code: 'TEST', message: 'boom', retryable: false },
attempt: 0,
...overrides
} as JobSettledEvent
}
describe('AgentTaskJobHandler', () => {
const pauseSpy = vi.fn()
beforeEach(() => {
vi.mocked(application.get).mockImplementation((name: string) => {
if (name === 'JobManager') return { pauseJobScheduleById: pauseSpy } as never
throw new Error(`Unexpected application.get('${name}')`)
})
pauseSpy.mockReset()
pauseSpy.mockResolvedValue(true)
vi.mocked(jobService.listRecentTerminalByScheduleId).mockReset()
vi.mocked(runAgentTask).mockReset()
})
afterEach(() => {
vi.clearAllMocks()
})
describe('metadata', () => {
it('declares per-agent queue + concurrency 1 + retry-once policy', () => {
expect(AgentTaskJobHandler.recovery).toBe('retry')
expect(AgentTaskJobHandler.defaultConcurrency).toBe(1)
expect(AgentTaskJobHandler.defaultRetryPolicy).toEqual({
maxAttempts: 1,
backoff: 'none',
baseDelayMs: 0,
maxDelayMs: 0
})
expect(AgentTaskJobHandler.defaultQueue?.({ agentId: 'a-42', prompt: 'x', timeoutMinutes: 2 })).toBe('agent:a-42')
})
})
describe('execute', () => {
it('delegates to runAgentTask with the JobContext', async () => {
vi.mocked(runAgentTask).mockResolvedValueOnce({ sessionId: 'sess-1', result: 'ok' })
const ctx = { jobId: 'j1', input: { agentId: 'a', prompt: 'p', timeoutMinutes: 2 } } as JobContext<{
agentId: string
prompt: string
timeoutMinutes: number
}>
const out = await AgentTaskJobHandler.execute(ctx)
expect(out).toEqual({ sessionId: 'sess-1', result: 'ok' })
expect(runAgentTask).toHaveBeenCalledWith(ctx)
})
})
describe('onSettled circuit breaker', () => {
it('pauses schedule after 3 consecutive failures', async () => {
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
makeTerminal('failed', 'a'),
makeTerminal('failed', 'b'),
makeTerminal('failed', 'c')
])
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))
expect(jobService.listRecentTerminalByScheduleId).toHaveBeenCalledWith('s1', 3)
expect(pauseSpy).toHaveBeenCalledWith('s1')
})
it('does not pause when the latest is failed but a recent one is completed', async () => {
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
makeTerminal('failed', 'a'),
makeTerminal('completed', 'b'),
makeTerminal('failed', 'c')
])
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))
expect(pauseSpy).not.toHaveBeenCalled()
})
it('does not pause when the recent-terminal window is not yet full', async () => {
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
makeTerminal('failed', 'a'),
makeTerminal('failed', 'b')
])
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))
expect(pauseSpy).not.toHaveBeenCalled()
})
it('does not act on non-failed terminal events', async () => {
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'completed' }))
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'cancelled' }))
expect(jobService.listRecentTerminalByScheduleId).not.toHaveBeenCalled()
expect(pauseSpy).not.toHaveBeenCalled()
})
it('does not act when the failed job has no scheduleId (ad-hoc enqueue)', async () => {
await AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed', scheduleId: null }))
expect(jobService.listRecentTerminalByScheduleId).not.toHaveBeenCalled()
expect(pauseSpy).not.toHaveBeenCalled()
})
it('swallows pauseJobScheduleById errors so onSettled cannot throw', async () => {
vi.mocked(jobService.listRecentTerminalByScheduleId).mockResolvedValueOnce([
makeTerminal('failed', 'a'),
makeTerminal('failed', 'b'),
makeTerminal('failed', 'c')
])
pauseSpy.mockRejectedValueOnce(new Error('db lost'))
await expect(AgentTaskJobHandler.onSettled?.(makeSettled({ status: 'failed' }))).resolves.not.toThrow()
})
})
})

View File

@@ -0,0 +1,353 @@
/**
* Phase 1 coverage: focuses on the pure branches that do not engage the
* Claude Code subprocess (heartbeat skip + agent-not-found). The full
* streaming path is exercised by integration tests / Phase 5 manual e2e.
*
* Each fire creates a fresh session — there is no cross-fire session reuse.
*/
import type { JobContext } from '@main/core/job/types'
import type { AgentEntity } from '@shared/data/api/schemas/agents'
import type { JobSnapshot } from '@shared/data/api/schemas/jobs'
import type { AgentSessionEntity } from '@shared/data/api/schemas/sessions'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
const { mockAbort, mockGetAdapter, mockStartRun, captured } = vi.hoisted(() => {
const captured: { listeners: Array<Record<string, (arg?: unknown) => void>> } = { listeners: [] }
return {
mockAbort: vi.fn(),
mockGetAdapter: vi.fn(() => undefined),
mockStartRun: vi.fn(async (opts: { listeners: typeof captured.listeners }) => {
captured.listeners = opts.listeners
}),
captured
}
})
vi.mock('@application', async () => {
const mod = await import('@test-mocks/main/application')
return mod.mockApplicationFactory({
// ChannelManager + AiStreamManager aren't in the default mock service set; the
// streaming path (post heartbeat-skip) reads both, so wire minimal stubs here.
ChannelManager: { getAdapter: mockGetAdapter },
AiStreamManager: { abort: mockAbort }
} as never)
})
vi.mock('@main/ai/streamManager/api/startAgentSessionRun', () => ({
startAgentSessionRun: mockStartRun
}))
vi.mock('@data/services/AgentChannelService', () => ({
agentChannelService: { getSubscribedChannels: vi.fn() }
}))
vi.mock('@data/services/AgentService', () => ({
agentService: { getAgent: vi.fn() }
}))
vi.mock('@data/services/SessionService', () => ({
sessionService: { createSession: vi.fn(), findAgentWorkspacePath: vi.fn() }
}))
vi.mock('@data/services/JobScheduleService', () => ({
jobScheduleService: { getById: vi.fn() }
}))
vi.mock('@data/services/JobService', () => ({
jobService: { getById: vi.fn() }
}))
vi.mock('@main/ai/agents/cherryclaw/heartbeat', () => ({
readHeartbeat: vi.fn()
}))
import { agentChannelService } from '@data/services/AgentChannelService'
import { agentService } from '@data/services/AgentService'
import { jobScheduleService } from '@data/services/JobScheduleService'
import { jobService } from '@data/services/JobService'
import { sessionService } from '@data/services/SessionService'
import { readHeartbeat } from '@main/ai/agents/cherryclaw/heartbeat'
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
import { runAgentTask } from '../runAgentTask'
function makeJobSnapshot(scheduleId: string | null = 's1'): JobSnapshot {
return {
id: 'j1',
type: 'agent.task',
status: 'running',
priority: 0,
queue: 'agent:a1',
idempotencyKey: null,
scheduleId,
scheduledAt: '2026-05-20T00:00:00.000Z',
startedAt: '2026-05-20T00:00:00.000Z',
finishedAt: null,
attempt: 0,
maxAttempts: 1,
input: {},
output: null,
error: null,
parentId: null,
cancelRequested: false,
metadata: {},
timeoutMs: null,
createdAt: '2026-05-20T00:00:00.000Z',
updatedAt: '2026-05-20T00:00:00.000Z'
}
}
function makeCtx(overrides: Partial<JobContext<{ agentId: string; prompt: string; timeoutMinutes: number }>> = {}) {
return {
jobId: 'j1',
input: { agentId: 'a1', prompt: '__heartbeat__', timeoutMinutes: 2 },
attempt: 0,
signal: new AbortController().signal,
metadata: {},
patchMetadata: vi.fn(),
reportProgress: vi.fn(),
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() } as never,
...overrides
} as JobContext<{ agentId: string; prompt: string; timeoutMinutes: number }>
}
function makeAgent(config: Record<string, unknown> = {}): AgentEntity {
return {
id: 'a1',
type: 'claude-code',
name: 'Agent A',
model: 'sonnet' as never,
configuration: config as never,
createdAt: '2026-05-20T00:00:00.000Z',
updatedAt: '2026-05-20T00:00:00.000Z',
modelName: null
}
}
function makeSession(workspacePath: string | null = '/ws/a'): AgentSessionEntity {
return {
id: 'sess-new',
agentId: 'a1',
name: 'Scheduled task',
workspaceId: workspacePath ? 'ws-1' : null,
workspace: workspacePath
? {
id: 'ws-1',
name: 'ws',
path: workspacePath,
orderKey: 'k',
createdAt: '2026-05-20T00:00:00.000Z',
updatedAt: '2026-05-20T00:00:00.000Z'
}
: null,
orderKey: 'k',
createdAt: '2026-05-20T00:00:00.000Z',
updatedAt: '2026-05-20T00:00:00.000Z'
} as AgentSessionEntity
}
function makeSchedule(name: string | null = 'heartbeat') {
return {
id: 's1',
type: 'agent.task',
name,
trigger: { kind: 'interval', ms: 60_000 },
jobInputTemplate: {},
enabled: true,
nextRun: null,
lastRun: null,
catchUpPolicy: { kind: 'skip-missed' },
metadata: {},
createdAt: '2026-05-20T00:00:00.000Z',
updatedAt: '2026-05-20T00:00:00.000Z'
} as never
}
describe('runAgentTask', () => {
beforeEach(() => {
vi.mocked(jobService.getById).mockReset()
vi.mocked(jobScheduleService.getById).mockReset()
vi.mocked(agentService.getAgent).mockReset()
vi.mocked(sessionService.createSession).mockReset()
vi.mocked(sessionService.findAgentWorkspacePath).mockReset()
vi.mocked(readHeartbeat).mockReset()
vi.mocked(agentChannelService.getSubscribedChannels).mockReset().mockResolvedValue([])
mockStartRun.mockClear()
mockAbort.mockClear()
mockGetAdapter.mockClear()
captured.listeners = []
})
afterEach(() => {
vi.clearAllMocks()
})
it('throws when the agent cannot be found', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(null as never)
await expect(runAgentTask(makeCtx())).rejects.toThrow('Agent not found: a1')
})
// A disabled heartbeat must short-circuit BEFORE createSession — that call also
// lazily provisions a workspace on first fire, so creating a session for a fire
// we're going to drop would accrete a session row (and workspace) every interval.
it('skips a disabled heartbeat WITHOUT creating a session', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: false }))
const out = await runAgentTask(makeCtx())
expect(out).toEqual({ sessionId: null, result: 'Skipped (disabled)' })
expect(sessionService.createSession).not.toHaveBeenCalled()
expect(readHeartbeat).not.toHaveBeenCalled()
})
it('skips an enabled heartbeat with no workspace WITHOUT creating a session', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: true }))
vi.mocked(sessionService.findAgentWorkspacePath).mockResolvedValueOnce(null)
const out = await runAgentTask(makeCtx())
expect(out).toEqual({ sessionId: null, result: 'Skipped (no file)' })
expect(sessionService.createSession).not.toHaveBeenCalled()
expect(readHeartbeat).not.toHaveBeenCalled()
})
it('skips an enabled heartbeat with no heartbeat.md WITHOUT creating a session', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: true }))
vi.mocked(sessionService.findAgentWorkspacePath).mockResolvedValueOnce('/ws/a')
vi.mocked(readHeartbeat).mockResolvedValueOnce(undefined)
const out = await runAgentTask(makeCtx())
expect(out).toEqual({ sessionId: null, result: 'Skipped (no file)' })
expect(sessionService.createSession).not.toHaveBeenCalled()
expect(readHeartbeat).toHaveBeenCalledWith('/ws/a')
})
it('creates a session and runs when an enabled heartbeat has content', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('heartbeat'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent({ heartbeat_enabled: true }))
vi.mocked(sessionService.findAgentWorkspacePath).mockResolvedValueOnce('/ws/a')
vi.mocked(readHeartbeat).mockResolvedValueOnce('check the inbox')
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
const promise = runAgentTask(makeCtx())
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
captured.listeners[0].onDone({ status: 'completed' })
await promise
expect(readHeartbeat).toHaveBeenCalledWith('/ws/a')
expect(sessionService.createSession).toHaveBeenCalledWith({ agentId: 'a1', name: 'heartbeat' })
})
// C1 (agents-jobs-3): a `text-delta` chunk's payload is on `.delta`, not `.text`.
// The previous `as { text }` cast silently accumulated nothing, so every run
// persisted the `'Completed'` fallback instead of the model's reply.
it('accumulates text-delta chunks via .delta into the result', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
const promise = runAgentTask(makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0 } }))
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
const sentinel = captured.listeners[0]
sentinel.onChunk({ type: 'text-delta', delta: 'Hello ' })
sentinel.onChunk({ type: 'text-delta', delta: 'world' })
sentinel.onChunk({ type: 'reasoning-delta', delta: 'ignored' })
sentinel.onDone({ status: 'completed' })
const out = await promise
expect(out).toEqual({ sessionId: 'sess-new', result: 'Hello world' })
})
// agents-jobs-4: on a non-abort error, a subscribed channel must be notified exactly
// once. The channel listener's generic `Error: …` is suppressed for task runs so only
// the richer `[Task failed]` summary from notifyTaskError is delivered (no double-send).
it('notifies a subscribed channel exactly once on a non-abort run error', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot('s1'))
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
vi.mocked(agentChannelService.getSubscribedChannels).mockResolvedValueOnce([{ id: 'ch1' }] as never)
const adapter = {
channelId: 'ch1',
connected: true,
notifyChatIds: ['chat-1'],
sendMessage: vi.fn<(chatId: string, text: string) => Promise<void>>(async () => {}),
onTextUpdate: vi.fn(async () => {}),
onStreamComplete: vi.fn(async () => true)
}
mockGetAdapter.mockReturnValue(adapter as never)
const promise = runAgentTask(makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0 } }))
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
// Simulate the stream manager dispatching the error to every listener (sentinel + channel).
const errorResult = { error: new Error('boom'), status: 'error' }
for (const listener of captured.listeners) {
listener.onError?.(errorResult as never)
}
await expect(promise).rejects.toThrow('boom')
// Exactly one channel message, and it's the task-framed summary — not the bare `Error: …`.
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
expect(adapter.sendMessage.mock.calls[0][1]).toContain('[Task failed]')
expect(adapter.sendMessage.mock.calls[0][1]).not.toMatch(/^Error:/)
})
// C2 (agents-jobs-1) + agents-jobs-7: aborting the run (JobManager cancel or
// per-task timeout) must abort the upstream stream AND settle the handler
// promise — otherwise it leaks until the JobManager force-finalize timeout.
it('aborts the upstream stream and rejects when the run signal aborts', async () => {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
const controller = new AbortController()
const promise = runAgentTask(
makeCtx({ signal: controller.signal, input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0 } })
)
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
controller.abort(new Error('cancelled by manager'))
await expect(promise).rejects.toThrow('cancelled by manager')
expect(mockAbort).toHaveBeenCalledWith(buildAgentSessionTopicId('sess-new'), 'cancelled by manager')
})
// agents-jobs-5: a non-zero `timeoutMinutes` arms a per-task timeout timer in
// makeRunSignal. When the stream never settles, the timer must fire, abort the
// upstream stream, and reject the handler with the timeout error.
it('aborts the upstream stream and rejects when the per-task timeout fires', async () => {
vi.useFakeTimers()
try {
vi.mocked(jobService.getById).mockResolvedValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockResolvedValueOnce(makeSchedule('daily-summary'))
vi.mocked(agentService.getAgent).mockResolvedValueOnce(makeAgent())
vi.mocked(sessionService.createSession).mockResolvedValueOnce(makeSession('/ws/a'))
const promise = runAgentTask(makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 1 } }))
const assertion = expect(promise).rejects.toThrow('Task timed out after 1 minute(s)')
// Flush the awaited setup chain (getById/getAgent/createSession/startRun) and
// arm the timer, then advance past the 1-minute timeout so it fires. Never
// settle the stream — the timeout is the only thing that resolves the run.
await vi.advanceTimersByTimeAsync(60_000)
await assertion
expect(mockAbort).toHaveBeenCalledWith(buildAgentSessionTopicId('sess-new'), 'Task timed out after 1 minute(s)')
} finally {
vi.useRealTimers()
}
})
})

View File

@@ -62,7 +62,7 @@ export interface BuiltinAgentConfig {
* Writes .claude/skills/ and .claude/plugins.json to the agent's
* working directory so the SDK can auto-discover them.
*
* @param workspacePath - The agent's working directory (accessible_paths[0])
* @param workspacePath - The agent session's workspace directory
* @param builtinRole - The built-in role identifier ('assistant' or 'skill-creator')
* @returns The parsed agent.json config, or undefined if not found
*/

View File

@@ -14,11 +14,11 @@ vi.mock('node:fs/promises', () => ({
import { readdir, readFile, stat } from 'node:fs/promises'
import type { CherryClawConfiguration } from '@types'
import type { AgentConfiguration } from '@shared/data/types/agent'
import { PromptBuilder } from '../prompt'
const baseConfig: CherryClawConfiguration = {
const baseConfig: AgentConfiguration = {
permission_mode: 'bypassPermissions',
max_turns: 100,
env_vars: {},

View File

@@ -2,7 +2,7 @@ import { readdir, readFile, stat } from 'node:fs/promises'
import path from 'node:path'
import { loggerService } from '@logger'
import type { CherryClawConfiguration } from '@types'
import type { AgentConfiguration } from '@shared/data/types/agent'
import { BOOTSTRAP_INSTRUCTIONS, SOUL_CONTENT_THRESHOLD } from './seedWorkspace'
@@ -142,15 +142,15 @@ ${sections}`
* instructions. Returns a synchronous string no I/O.
*
* Memory files layout (Soul Mode only):
* {workspace}/soul.md personality, tone, communication style
* {workspace}/user.md user profile, preferences, context
* {workspace}/SOUL.md personality, tone, communication style
* {workspace}/USER.md user profile, preferences, context
* {workspace}/memory/FACT.md durable project knowledge, technical decisions
* {workspace}/memory/JOURNAL.jsonl timestamped event log (managed by memory tool)
*/
export class PromptBuilder {
private cache = new Map<string, CacheEntry>()
async buildSystemPrompt(workspacePath: string, config?: CherryClawConfiguration): Promise<string> {
async buildSystemPrompt(workspacePath: string, config?: AgentConfiguration): Promise<string> {
const parts: string[] = []
// Basic prompt: workspace system.md (case-insensitive) > embedded default
@@ -227,7 +227,7 @@ ${content}
* - If SOUL.md has substantial non-template content, skip (legacy agent migration).
* - Otherwise, run bootstrap.
*/
private async shouldRunBootstrap(workspacePath: string, config?: CherryClawConfiguration): Promise<boolean> {
private async shouldRunBootstrap(workspacePath: string, config?: AgentConfiguration): Promise<boolean> {
if (config?.bootstrap_completed === true) {
return false
}

View File

@@ -63,7 +63,7 @@ Your goal in this conversation is to:
- Rename yourself using \`mcp__claw__config\` (action: "rename", name: the chosen name)
- Update \`SOUL.md\` with your role definition, personality, tone, principles, and boundaries using the Edit tool
- Update \`USER.md\` with everything you learned about the user using the Edit tool
- Log the bootstrap completion using \`mcp__claw__memory\` (append action, tag: "bootstrap")
- Log the bootstrap completion using \`mcp__agent-memory__memory\` (append action, tag: "bootstrap")
- Mark bootstrap as complete using \`mcp__claw__config\` (action: "complete_bootstrap")
Guidelines:

View File

@@ -0,0 +1,251 @@
/**
* Business logic for `agent.task` jobs — owned by `AgentTaskJobHandler`.
*
* Each fire creates a fresh agent session. Per-fire sessions are recorded in
* `job.output.sessionId` for audit only — there is no cross-fire session
* reuse pointer on the schedule. Scheduled tasks are discrete background
* invocations (heartbeat, periodic summary, polling), not conversations, so
* carrying context across fires would only stuff the model's window with
* stale state. Persistent agent memory belongs in workspace files
* (`heartbeat.md`, agent memory) instead of session history.
*/
import { agentChannelService } from '@data/services/AgentChannelService'
import { agentService } from '@data/services/AgentService'
import { jobScheduleService } from '@data/services/JobScheduleService'
import { jobService } from '@data/services/JobService'
import { sessionService } from '@data/services/SessionService'
import { loggerService } from '@logger'
import { readHeartbeat } from '@main/ai/agents/cherryclaw/heartbeat'
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
import { ChannelAdapterListener, type StreamListener } from '@main/ai/streamManager'
import { startAgentSessionRun } from '@main/ai/streamManager/api/startAgentSessionRun'
import { application } from '@main/core/application'
import type { JobContext } from '@main/core/job/types'
const logger = loggerService.withContext('runAgentTask')
const HEARTBEAT_PROMPT_SENTINEL = '__heartbeat__'
const HEARTBEAT_TASK_NAME = 'heartbeat'
export type AgentTaskInput = {
agentId: string
prompt: string
timeoutMinutes: number
}
export type AgentTaskOutput = {
/** Session created for this fire. Persisted to `jobTable.output` purely as
* an audit trail — the task scheduler never reads this back for continuity. */
sessionId: string | null
/** First 200 chars of the assistant reply, or a status marker for skipped runs. */
result: string
}
/** Combine the JobManager-provided abort signal with an optional per-task timeout. */
function makeRunSignal(
outerSignal: AbortSignal,
timeoutMinutes: number | undefined
): { signal: AbortSignal; dispose: () => void } {
if (!timeoutMinutes || timeoutMinutes <= 0) {
return { signal: outerSignal, dispose: () => {} }
}
// Own the timeout so `dispose()` can actually release the timer on normal
// completion (an `AbortSignal.timeout` keeps a live timer until it fires).
const timeoutController = new AbortController()
const timer = setTimeout(
() => timeoutController.abort(new Error(`Task timed out after ${timeoutMinutes} minute(s)`)),
timeoutMinutes * 60_000
)
const signal = AbortSignal.any([outerSignal, timeoutController.signal])
return { signal, dispose: () => clearTimeout(timer) }
}
export async function runAgentTask(ctx: JobContext<AgentTaskInput>): Promise<AgentTaskOutput> {
const { agentId, prompt, timeoutMinutes } = ctx.input
// schedule-fired jobs carry `scheduleId` on the row; manual ad-hoc enqueues
// (no schedule) degrade gracefully: skip channel notification.
const jobSnapshot = await jobService.getById(ctx.jobId)
const scheduleId = jobSnapshot?.scheduleId ?? null
const scheduleSnapshot = scheduleId ? await jobScheduleService.getById(scheduleId) : null
const taskName = scheduleSnapshot?.name ?? null
const agent = await agentService.getAgent(agentId)
if (!agent) {
throw new Error(`Agent not found: ${agentId}`)
}
const config = agent.configuration ?? {}
const isHeartbeat = taskName === HEARTBEAT_TASK_NAME && prompt === HEARTBEAT_PROMPT_SENTINEL
let effectivePrompt = prompt
// All heartbeat skip decisions happen BEFORE we create a session — `createSession`
// lazily provisions a workspace, so creating one for a fire we're going to drop
// accretes a session row (and workspace) every interval. The agent's workspace is
// shared across its sessions, so we can read `heartbeat.md` without creating one.
if (isHeartbeat) {
if (config.heartbeat_enabled === false) {
logger.debug('Heartbeat skipped (disabled)', { agentId, scheduleId })
return { sessionId: null, result: 'Skipped (disabled)' }
}
const workspacePath = await sessionService.findAgentWorkspacePath(agentId)
if (!workspacePath) {
logger.debug('Heartbeat skipped (no workspace)', { agentId, scheduleId })
return { sessionId: null, result: 'Skipped (no file)' }
}
const content = await readHeartbeat(workspacePath)
if (!content) {
logger.debug('Heartbeat skipped (no heartbeat.md)', { agentId, scheduleId })
return { sessionId: null, result: 'Skipped (no file)' }
}
effectivePrompt = [
'[Heartbeat]',
'This is a periodic heartbeat. The instructions below are from your heartbeat.md file.',
'Process each item, take action where possible, and use the notify tool to alert the user of important results.',
'',
'---',
content
].join('\n')
}
// Always create a fresh session per fire. Scheduled tasks are discrete
// invocations; cross-fire session reuse would only carry stale model
// context. Persistent state lives in workspace files (heartbeat.md, etc.).
const session = await sessionService.createSession({ agentId, name: taskName ?? 'Scheduled task' })
const subscribedChannels = scheduleId ? await agentChannelService.getSubscribedChannels(scheduleId) : []
const channelManager = application.get('ChannelManager')
const channelListeners: StreamListener[] = subscribedChannels.flatMap((ch) => {
const adapter = channelManager.getAdapter(ch.id)
if (!adapter) return []
// Suppress the listener's generic `Error: …` — `notifyTaskError` below sends a richer
// `[Task failed]` summary to the same chats, so leaving it on would double-notify.
return adapter.notifyChatIds.map((chatId) => new ChannelAdapterListener(adapter, chatId, true))
})
const { signal: runSignal, dispose } = makeRunSignal(ctx.signal, timeoutMinutes)
const startTimeMs = Date.now()
let resolveExecution!: (text: string) => void
let rejectExecution!: (err: unknown) => void
const executionDone = new Promise<string>((resolve, reject) => {
resolveExecution = resolve
rejectExecution = reject
})
let accumulatedText = ''
const sentinel: StreamListener = {
id: `agent-task:${scheduleId ?? ctx.jobId}`,
onChunk(chunk) {
// `text-delta`'s field is `delta`, not `text` (AI SDK `UIMessageChunk`) — the
// previous `as { text }` cast silently never accumulated, so the persisted
// result was always the `'Completed'` fallback.
if (chunk.type === 'text-delta') accumulatedText += chunk.delta
},
onDone() {
resolveExecution(accumulatedText.trim())
},
onPaused() {
if (runSignal.aborted) {
const reason = runSignal.reason
rejectExecution(reason instanceof Error ? reason : new Error(String(reason ?? 'Task aborted')))
return
}
resolveExecution(accumulatedText.trim())
},
onError(result) {
rejectExecution(new Error(result.error.message ?? 'Execution failed'))
},
// Keep `true`: the manager prunes a listener whose `isAlive()` is false BEFORE
// firing its terminal callback, so gating on `runSignal` here would make an
// aborted run's terminal event never settle `executionDone`. Abort is handled
// explicitly via `onRunAbort` below.
isAlive: () => true
}
const topicId = buildAgentSessionTopicId(session.id)
// On JobManager cancel or per-task timeout, stop the upstream run: the execution's
// own controller never sees `runSignal`, so abort the live stream and settle
// `executionDone` here — otherwise the handler promise leaks until the JobManager's
// force-finalize timeout.
const onRunAbort = () => {
const reason = runSignal.reason
application
.get('AiStreamManager')
.abort(topicId, reason instanceof Error ? reason.message : String(reason ?? 'task-aborted'))
rejectExecution(reason instanceof Error ? reason : new Error(String(reason ?? 'Task aborted')))
}
if (runSignal.aborted) onRunAbort()
else runSignal.addEventListener('abort', onRunAbort, { once: true })
let runError: Error | null = null
let resultText = ''
try {
await startAgentSessionRun({
sessionId: session.id,
userParts: [{ type: 'text', text: effectivePrompt }],
listeners: [sentinel, ...channelListeners]
})
resultText = await executionDone
if (runSignal.aborted) {
const reason = runSignal.reason
throw reason instanceof Error ? reason : new Error(String(reason ?? 'Task aborted'))
}
} catch (err) {
runError = err instanceof Error ? err : new Error(String(err))
if (!runSignal.aborted && subscribedChannels.length > 0) {
await notifyTaskError(
{ id: scheduleId, name: taskName, durationMs: Date.now() - startTimeMs },
runError.message,
subscribedChannels
)
}
throw runError
} finally {
runSignal.removeEventListener('abort', onRunAbort)
dispose()
}
return {
sessionId: session.id,
result: resultText.slice(0, 200) || 'Completed'
}
}
async function notifyTaskError(
task: { id: string | null; name: string | null; durationMs: number },
error: string,
subscribedChannels: Array<{ id: string }>
): Promise<void> {
const channelManager = application.get('ChannelManager')
try {
const durationSec = Math.round(task.durationMs / 1000)
const label = task.name ?? task.id ?? '(unknown)'
const text = `[Task failed] ${label}\nDuration: ${durationSec}s\nError: ${error}`
for (const ch of subscribedChannels) {
const adapter = channelManager.getAdapter(ch.id)
if (!adapter) continue
for (const chatId of adapter.notifyChatIds) {
adapter.sendMessage(chatId, text).catch((err) => {
logger.warn('Failed to deliver task error notification', {
scheduleId: task.id,
channelId: ch.id,
chatId,
error: err instanceof Error ? err.message : String(err)
})
})
}
}
} catch (err) {
logger.warn('Error while building task error notification', {
scheduleId: task.id,
error: err instanceof Error ? err.message : String(err)
})
}
}

View File

@@ -1,94 +1,9 @@
import { loggerService } from '@logger'
import type { FileAttachment, ImageAttachment } from '@main/utils/downloadAsBase64'
import type { ChannelLogEntry, ChannelLogLevel, ChannelStatusEvent } from '@shared/config/types'
import { net } from 'electron'
import type { AgentChannelEntity, AgentChannelType } from '@shared/data/api/schemas/agentChannels'
import { EventEmitter } from 'events'
const logger = loggerService.withContext('ChannelAdapter')
/** Pre-downloaded, base64-encoded image ready for multimodal AI input. */
export type ImageAttachment = {
data: string // base64-encoded image bytes
media_type: string // e.g. 'image/png', 'image/jpeg', 'image/gif', 'image/webp'
}
/** Pre-downloaded, base64-encoded file attachment from an IM channel. */
export type FileAttachment = {
filename: string // original filename, e.g. 'report.pdf'
data: string // base64-encoded file bytes
media_type: string // MIME type, e.g. 'application/pdf', 'text/plain'
size: number // raw byte size (before base64 encoding)
}
/** Maximum file size we'll download from IM channels (20 MB). */
export const MAX_FILE_SIZE_BYTES = 20 * 1024 * 1024
/**
* Download an image URL via Electron's net.fetch (respects system proxy) and
* return base64-encoded data. Returns null on failure.
*/
export async function downloadImageAsBase64(url: string): Promise<ImageAttachment | null> {
try {
const response = await net.fetch(url)
if (!response.ok) {
logger.warn('Failed to download image', { url, status: response.status })
return null
}
const contentType = response.headers.get('content-type') || 'image/png'
const mediaType = contentType.split(';')[0].trim()
const buffer = Buffer.from(await response.arrayBuffer())
return { data: buffer.toString('base64'), media_type: mediaType }
} catch (error) {
logger.warn('Failed to fetch image', {
url,
error: error instanceof Error ? error.message : String(error)
})
return null
}
}
/**
* Download a file URL via Electron's net.fetch and return base64-encoded data.
* Enforces MAX_FILE_SIZE_BYTES. Returns null on failure or if the file is too large.
*/
export async function downloadFileAsBase64(url: string, filename: string): Promise<FileAttachment | null> {
try {
const response = await net.fetch(url)
if (!response.ok) {
logger.warn('Failed to download file', { url, filename, status: response.status })
return null
}
const contentLength = response.headers.get('content-length')
if (contentLength && parseInt(contentLength, 10) > MAX_FILE_SIZE_BYTES) {
logger.warn('File too large, skipping download', { filename, size: contentLength })
return null
}
const buffer = Buffer.from(await response.arrayBuffer())
if (buffer.length > MAX_FILE_SIZE_BYTES) {
logger.warn('File too large after download', { filename, size: buffer.length })
return null
}
const contentType = response.headers.get('content-type') || 'application/octet-stream'
const mediaType = contentType.split(';')[0].trim()
return {
filename,
data: buffer.toString('base64'),
media_type: mediaType,
size: buffer.length
}
} catch (error) {
logger.warn('Failed to fetch file', {
url,
filename,
error: error instanceof Error ? error.message : String(error)
})
return null
}
}
export type ChannelMessageEvent = {
chatId: string
userId: string
@@ -113,11 +28,19 @@ export type SendMessageOptions = {
replyToMessageId?: number
}
export type ChannelAdapterConfig = {
/** Channel type → its config payload, projected from the `AgentChannelEntity` discriminated union. */
type ChannelConfigByType = {
[T in AgentChannelType]: Extract<AgentChannelEntity, { type: T }>['config']
}
/** The config payload for a specific channel type. Indexing keeps `ChannelAdapterConfig` covariant in `T`. */
export type ChannelConfigForType<T extends AgentChannelType = AgentChannelType> = ChannelConfigByType[T]
export type ChannelAdapterConfig<T extends AgentChannelType = AgentChannelType> = {
channelId: string
channelType: string
channelType: T
agentId: string
channelConfig: Record<string, unknown>
channelConfig: ChannelConfigForType<T>
}
/**
@@ -139,7 +62,7 @@ export type ChannelAdapterConfig = {
*/
export abstract class ChannelAdapter extends EventEmitter {
readonly channelId: string
readonly channelType: string
readonly channelType: AgentChannelType
readonly agentId: string
/**
* Chat IDs that this adapter can send notifications/task results to.

View File

@@ -1,24 +1,34 @@
import { application } from '@application'
import { agentChannelService as channelService } from '@data/services/AgentChannelService'
import { loggerService } from '@logger'
import { BaseService, DependsOn, Injectable, Phase, ServicePhase } from '@main/core/lifecycle'
import { WindowType } from '@main/core/window/types'
import type { ChannelLogEntry, ChannelStatusEvent } from '@shared/config/types'
import type { AgentChannelEntity as ChannelRow } from '@shared/data/api/schemas/agentChannels'
import type { AgentChannelEntity as ChannelRow, AgentChannelType } from '@shared/data/api/schemas/agentChannels'
import type { ChannelConfig } from '@shared/data/types/channel'
import { IpcChannel } from '@shared/IpcChannel'
import type { ChannelAdapter } from './ChannelAdapter'
import type { ChannelConfig } from './channelConfig'
import { ChannelLogBuffer } from './ChannelLogBuffer'
import { channelMessageHandler } from './ChannelMessageHandler'
const logger = loggerService.withContext('ChannelManager')
// Adapter factory registry -- adapters register themselves here
type AdapterFactory = (channel: ChannelRow, agentId: string) => ChannelAdapter
const adapterFactories = new Map<string, AdapterFactory>()
// Adapter factory registry -- adapters register themselves here. The factory
// for a given channel type receives the matching variant of the discriminated
// `ChannelRow` union, so `channel.config` is strongly typed per adapter.
type AdapterFactory<T extends AgentChannelType = AgentChannelType> = (
channel: Extract<ChannelRow, { type: T }>,
agentId: string
) => ChannelAdapter
const adapterFactories = new Map<AgentChannelType, AdapterFactory>()
export function registerAdapterFactory(type: string, factory: AdapterFactory): void {
adapterFactories.set(type, factory)
export function registerAdapterFactory<T extends AgentChannelType>(type: T, factory: AdapterFactory<T>): void {
// A factory is always stored under, and looked up by, its own channel type
// (see `connectChannelFromRow`), so the row handed to it is guaranteed to be
// this variant. That invariant is the one thing the type system can't see, so
// we narrow the row to the factory's variant here — nothing wider is asserted.
adapterFactories.set(type, (channel, agentId) => factory(channel as Extract<ChannelRow, { type: T }>, agentId))
}
/**
@@ -26,7 +36,7 @@ export function registerAdapterFactory(type: string, factory: AdapterFactory): v
* Each module registers itself via `registerAdapterFactory()` as a side effect.
* This avoids eagerly importing all 6 heavy adapter modules at startup.
*/
const adapterImportMap: Record<string, () => Promise<unknown>> = {
const adapterImportMap: Record<AgentChannelType, () => Promise<unknown>> = {
discord: () => import('./adapters/discord/DiscordAdapter'),
feishu: () => import('./adapters/feishu/FeishuAdapter'),
qq: () => import('./adapters/qq/QqAdapter'),
@@ -36,15 +46,15 @@ const adapterImportMap: Record<string, () => Promise<unknown>> = {
}
/** Ensure the adapter factory for the given type is loaded (idempotent). */
async function ensureAdapterLoaded(type: string): Promise<void> {
async function ensureAdapterLoaded(type: AgentChannelType): Promise<void> {
if (adapterFactories.has(type)) return
const loader = adapterImportMap[type]
if (!loader) return
await loader()
await adapterImportMap[type]()
}
class ChannelManager {
private static instance: ChannelManager | null = null
@Injectable('ChannelManager')
@ServicePhase(Phase.WhenReady)
@DependsOn(['WindowManager'])
export class ChannelManager extends BaseService {
private readonly adapters = new Map<string, ChannelAdapter>() // key: `${agentId}:${channelId}`
private readonly qrWaiters = new Map<
string,
@@ -53,11 +63,12 @@ class ChannelManager {
private readonly channelLogs = new ChannelLogBuffer()
private readonly channelStatuses = new Map<string, ChannelStatusEvent>()
static getInstance(): ChannelManager {
if (!ChannelManager.instance) {
ChannelManager.instance = new ChannelManager()
}
return ChannelManager.instance
protected async onReady(): Promise<void> {
await this.start()
}
protected async onStop(): Promise<void> {
await this.stop()
}
async start(): Promise<void> {
@@ -394,5 +405,3 @@ class ChannelManager {
}
}
}
export const channelManager = ChannelManager.getInstance()

View File

@@ -4,27 +4,22 @@ import path from 'node:path'
import { agentChannelService as channelService } from '@data/services/AgentChannelService'
import { agentService } from '@data/services/AgentService'
import { agentSessionService as sessionService } from '@data/services/AgentSessionService'
import { sessionService } from '@data/services/SessionService'
import { loggerService } from '@logger'
import { sessionMessageOrchestrator } from '@main/services/agents/services/SessionMessageOrchestrator'
import type { GetAgentSessionResponse, PermissionMode } from '@types'
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
import { isAgentSessionWorkspaceError } from '@main/ai/runtime/claudeCode/settingsBuilder'
import { ChannelAdapterListener, type StreamListener } from '@main/ai/streamManager'
import { startAgentSessionRun } from '@main/ai/streamManager/api/startAgentSessionRun'
import { application } from '@main/core/application'
import type { FileAttachment, ImageAttachment } from '@main/utils/downloadAsBase64'
import type { AgentSessionEntity } from '@shared/data/api/schemas/sessions'
import { sanitizeChannelOutput, wrapExternalContent } from '../security'
import type {
ChannelAdapter,
ChannelCommandEvent,
ChannelMessageEvent,
FileAttachment,
ImageAttachment
} from './ChannelAdapter'
import type { ChannelAdapter, ChannelCommandEvent, ChannelMessageEvent } from './ChannelAdapter'
import { SLASH_COMMANDS } from './constants'
import { sessionStreamBus } from './SessionStreamBus'
import { broadcastSessionChanged } from './sessionStreamIpc'
import { splitMessage } from './utils'
import { wrapExternalContent } from './security'
const logger = loggerService.withContext('ChannelMessageHandler')
const MAX_MESSAGE_LENGTH = 4096
const TYPING_INTERVAL_MS = 4000
/** Max number of entries in the session tracker before evicting oldest entries. */
@@ -54,7 +49,7 @@ export class ChannelMessageHandler {
private static instance: ChannelMessageHandler | null = null
// TODO: in v2 use cacheService
private readonly sessionTracker = new Map<string, string>() // `${agentId}:${channelId}:${chatId}` -> sessionId
private readonly pendingResolutions = new Map<string, Promise<GetAgentSessionResponse | null>>()
private readonly pendingResolutions = new Map<string, Promise<AgentSessionEntity | null>>()
/** Per-chat debounce buffer — accumulates rapid messages before flushing */
private readonly pendingBatches = new Map<string, PendingBatch>()
/** Per-chat serial queue — ensures only one stream runs at a time per chat */
@@ -192,12 +187,24 @@ export class ChannelMessageHandler {
return
}
// Apply channel-level permission mode override on every message (not just session creation).
// This ensures changes to the channel's permission_mode take effect immediately,
// even for sessions created before the setting was changed.
await this.applyChannelPermissionMode(session, adapter.channelId)
// Resolve agent for cognitive config (model / configuration / mcps / allowedTools).
// Workspace is read from the session itself (CMA Environment binding).
// An orphan session (`agentId === null`) cannot run; skip it.
if (!session.agentId) {
logger.error('Channel message hit an orphan session', { sessionId: session.id })
return
}
const agent = await agentService.getAgent(session.agentId)
if (!agent) {
logger.error('Agent not found for session', { sessionId: session.id, agentId: session.agentId })
return
}
const workDir = session.accessiblePaths[0]
// TODO(channel-perm-override): channel-level permission_mode used to mutate
// session.configuration in-place; with config now living on agent, this
// override needs to flow as a per-dispatch option instead. Tracked separately.
const workDir = session.workspace?.path
// Save images to agent workspace so the agent can read them via the Read tool
let imagePaths: string[] = []
@@ -252,36 +259,6 @@ export class ChannelMessageHandler {
channelType: adapter.channelType
})
// Build display text: append filenames so the user can see them in the UI
let displayText = message.text
if (message.files && message.files.length > 0) {
const names = message.files.map((f) => `📎 ${f.filename}`).join('\n')
displayText = displayText ? `${displayText}\n${names}` : names
}
// Snapshot subscriber state ONCE — this single check drives:
// 1. Whether user-message is published to the renderer
// 2. The persist flag (renderer persistence vs headless persistence)
// 3. Whether stream chunks / complete events are forwarded
// Checking once eliminates the race where subscribe() IPC completes
// between the user-message publish and the persist decision.
const rendererIsWatching = sessionStreamBus.hasSubscribers(session.id)
if (rendererIsWatching) {
sessionStreamBus.publish(session.id, {
sessionId: session.id,
agentId: session.agentId,
type: 'user-message',
userMessage: {
chatId: message.chatId,
userId: message.userId,
userName: message.userName,
text: displayText,
images: message.images
}
})
}
const abortController = new AbortController()
this.activeAbortControllers.set(session.id, abortController)
@@ -293,30 +270,22 @@ export class ChannelMessageHandler {
)
try {
const responseText = await this.collectStreamResponse(
session,
securedContent,
abortController,
adapter,
message.chatId,
message.text,
message.images,
rendererIsWatching
)
if (responseText) {
// Sanitize output to prevent accidental secret leakage through channels
const { text: sanitizedText } = sanitizeChannelOutput(responseText)
const finalized = await adapter.onStreamComplete(message.chatId, sanitizedText).catch(() => false)
if (!finalized) {
await this.sendChunked(adapter, message.chatId, sanitizedText)
}
}
// Delivery (streaming updates + the sanitized finalize) is owned by the
// `ChannelAdapterListener` registered inside `collectStreamResponse`; we only await
// turn completion here. (The old post-hoc finalize was dead — the sentinel's `c.text`
// read never accumulated — and reviving it would double-send.)
await this.collectStreamResponse(session, securedContent, abortController, adapter, message.chatId)
} catch (streamError) {
// Notify adapter of the error so it can update streaming UI
adapter
.onStreamError(message.chatId, streamError instanceof Error ? streamError.message : String(streamError))
.catch(() => {})
const streamErrorMessage = streamError instanceof Error ? streamError.message : String(streamError)
if (isAgentSessionWorkspaceError(streamError)) {
// Thrown before streaming starts (validateSession), so no controller exists yet and
// onStreamError is a no-op on most adapters — send a plain message so the inbound
// message isn't silently dropped on Telegram/WeChat/QQ/Discord/Slack.
adapter.sendMessage(message.chatId, streamErrorMessage).catch(() => {})
} else {
// Mid-stream error: let the adapter update its streaming UI.
adapter.onStreamError(message.chatId, streamErrorMessage).catch(() => {})
}
throw streamError
} finally {
this.activeAbortControllers.delete(session.id)
@@ -336,28 +305,14 @@ export class ChannelMessageHandler {
try {
switch (command.command) {
case 'new': {
const agent = await agentService.getAgent(agentId)
const channelRow = await channelService.getChannel(adapter.channelId)
const permMode = channelRow?.permissionMode as PermissionMode | undefined
const newSession = await sessionService.createSession(agentId, {
...(agent?.configuration
? {
configuration: {
...agent.configuration,
...(permMode ? { permission_mode: permMode } : {})
}
}
: {})
})
if (newSession) {
// Update channel's session_id to point to the new session
await channelService.updateChannel(adapter.channelId, { sessionId: newSession.id })
const trackerKey = `${agentId}:${adapter.channelId}:${command.chatId}`
this.sessionTracker.set(trackerKey, newSession.id)
this.evictSessionTracker()
await adapter.sendMessage(command.chatId, 'New session created.')
}
// TODO(channel-perm-override): channel.permissionMode no longer
// applied here — config lives on agent now. Tracked separately.
const newSession = await sessionService.createSession({ agentId, name: 'Channel session' })
await channelService.updateChannel(adapter.channelId, { sessionId: newSession.id })
const trackerKey = `${agentId}:${adapter.channelId}:${command.chatId}`
this.sessionTracker.set(trackerKey, newSession.id)
this.evictSessionTracker()
await adapter.sendMessage(command.chatId, 'New session created.')
break
}
case 'compact': {
@@ -380,7 +335,12 @@ export class ChannelMessageHandler {
adapter,
command.chatId
)
await adapter.sendMessage(command.chatId, response || 'Session compacted.')
// The `ChannelAdapterListener` registered inside `collectStreamResponse` already
// delivered any non-empty output; only send an explicit fallback when compact
// produced no text, so we don't double-send.
if (!response) {
await adapter.sendMessage(command.chatId, 'Session compacted.')
}
} finally {
clearInterval(typingInterval)
}
@@ -431,22 +391,6 @@ export class ChannelMessageHandler {
}
}
/**
* Look up the channel's current permission_mode from the agent config and
* override the session's configuration in-place. This ensures that changes
* to the channel permission mode take effect immediately even for sessions
* that were created before the setting was changed.
*/
private async applyChannelPermissionMode(session: GetAgentSessionResponse, channelId: string): Promise<void> {
const channel = await channelService.getChannel(channelId)
if (channel?.permissionMode && session.configuration) {
session.configuration = {
...session.configuration,
permission_mode: channel.permissionMode
}
}
}
/** Evict oldest session tracker entries when the map exceeds the size limit. */
private evictSessionTracker(): void {
if (this.sessionTracker.size <= SESSION_TRACKER_MAX_SIZE) return
@@ -460,15 +404,27 @@ export class ChannelMessageHandler {
/** Clear session tracking for an agent (used when agent is deleted/updated) */
clearSessionTracker(agentId: string): void {
for (const key of this.sessionTracker.keys()) {
// Abort any in-flight stream owned by a tracked session of this agent
// before dropping the tracker entries — otherwise the stream keeps
// running on a deleted agent and `sendMessage` to a now-detached
// channel will throw.
const sessionIdsToAbort: string[] = []
for (const [key, sessionId] of this.sessionTracker.entries()) {
if (key.startsWith(`${agentId}:`)) {
sessionIdsToAbort.push(sessionId)
this.sessionTracker.delete(key)
}
}
for (const sessionId of sessionIdsToAbort) {
this.abortSessionStream(sessionId, 'agent-cleared')
}
for (const [key, batch] of this.pendingBatches.entries()) {
if (key.startsWith(`${agentId}:`)) {
clearTimeout(batch.timer)
this.pendingBatches.delete(key)
// Settle the discarded batch's callers so their .catch handlers fire
// instead of leaving handleIncoming promises hanging forever.
batch.resolvers.forEach((r) => r.reject(new Error('Agent removed; batch discarded')))
}
}
for (const key of this.chatQueues.keys()) {
@@ -478,14 +434,22 @@ export class ChannelMessageHandler {
}
}
/** Abort an active stream for the given session. Returns true if aborted. */
/** Abort an active stream for the given session. Returns true if a stream was in flight. */
abortSession(sessionId: string): boolean {
const controller = this.activeAbortControllers.get(sessionId)
if (controller) {
controller.abort()
return true
}
return false
if (!this.activeAbortControllers.has(sessionId)) return false
this.abortSessionStream(sessionId, 'channel-session-aborted')
return true
}
/**
* Stop the upstream agent-session turn for a session. The local `AbortController`
* is never passed to the running stream it only flips a listener's `isAlive()`,
* which (because the manager prunes dead listeners before firing their terminal
* callback) would strand the completion sentinel. So abort through the manager,
* which settles the turn as `paused` and lets the still-alive sentinel resolve.
*/
private abortSessionStream(sessionId: string, reason: string): void {
application.get('AiStreamManager').abort(buildAgentSessionTopicId(sessionId), reason)
}
private async resolveSession(
@@ -493,7 +457,7 @@ export class ChannelMessageHandler {
channelId: string,
channelType: string,
chatId: string
): Promise<GetAgentSessionResponse | null> {
): Promise<AgentSessionEntity | null> {
const trackerKey = `${agentId}:${channelId}:${chatId}`
// Coalesce concurrent resolutions for the same chat to avoid duplicate sessions
@@ -515,15 +479,15 @@ export class ChannelMessageHandler {
_channelType: string,
_chatId: string,
trackerKey: string
): Promise<GetAgentSessionResponse | null> {
): Promise<AgentSessionEntity | null> {
const channelRow = await channelService.getChannel(channelId)
const lookup = async (sessionId: string) => sessionService.getById(sessionId).catch(() => null)
// Check tracker first
const trackedId = this.sessionTracker.get(trackerKey)
if (trackedId) {
const session = await sessionService.getSession(agentId, trackedId)
if (session) {
// Ensure channel's session_id stays in sync
const session = await lookup(trackedId)
if (session && session.agentId === agentId) {
if (channelRow && channelRow.sessionId !== session.id) {
channelService
.updateChannel(channelId, { sessionId: session.id })
@@ -531,19 +495,18 @@ export class ChannelMessageHandler {
logger.warn('Failed to sync channel-session link', err instanceof Error ? err : new Error(String(err)))
)
}
return session as GetAgentSessionResponse
return session
}
// Tracked session gone, clear it
this.sessionTracker.delete(trackerKey)
}
// Look up existing session via channel's session_id
if (channelRow?.sessionId) {
const existingSession = await sessionService.getSession(agentId, channelRow.sessionId)
if (existingSession) {
const existingSession = await lookup(channelRow.sessionId)
if (existingSession && existingSession.agentId === agentId) {
this.sessionTracker.set(trackerKey, existingSession.id)
this.evictSessionTracker()
return existingSession as GetAgentSessionResponse
return existingSession
}
}
@@ -554,136 +517,57 @@ export class ChannelMessageHandler {
channelSessionId: channelRow?.sessionId ?? null,
trackerKey
})
const agent = await agentService.getAgent(agentId)
const channelPermissionMode = channelRow?.permissionMode as PermissionMode | undefined
const newSession = await sessionService.createSession(agentId, {
...(agent?.configuration
? {
configuration: {
...agent.configuration,
...(channelPermissionMode ? { permission_mode: channelPermissionMode } : {})
}
}
: {})
})
if (newSession) {
// Link channel to the new session
await channelService.updateChannel(channelId, { sessionId: newSession.id })
this.sessionTracker.set(trackerKey, newSession.id)
this.evictSessionTracker()
return newSession as GetAgentSessionResponse
}
return null
const newSession = await sessionService.createSession({ agentId, name: 'Channel session' })
await channelService.updateChannel(channelId, { sessionId: newSession.id })
this.sessionTracker.set(trackerKey, newSession.id)
this.evictSessionTracker()
return newSession
}
private async collectStreamResponse(
session: GetAgentSessionResponse,
session: AgentSessionEntity,
content: string,
abortController: AbortController,
adapter: ChannelAdapter,
chatId: string,
displayContent?: string,
images?: ImageAttachment[],
rendererIsWatching: boolean = false
chatId: string
): Promise<string> {
// Use the pre-computed rendererIsWatching flag from processIncoming.
// When renderer is watching: persist=false (renderer handles rich block persistence),
// stream chunks and events are forwarded to the renderer via the bus.
// When renderer is NOT watching: persist=true (main persists via persistHeadlessExchange),
// stream events are NOT forwarded (no subscriber or subscriber arrived late).
const { stream, completion } = await sessionMessageOrchestrator.createSessionMessage(
session,
{ content },
abortController,
{ persist: !rendererIsWatching, displayContent, images }
)
const reader = stream.getReader()
let completedText = '' // text from finished blocks/turns
let currentBlockText = '' // cumulative text within the current block
try {
while (true) {
const { done, value } = await reader.read()
if (done) break
// Only forward chunks to renderer when it was confirmed watching at stream start.
// This prevents late-subscribing renderers from receiving partial chunks
// while main process is also persisting (which would cause duplicates).
if (rendererIsWatching) {
sessionStreamBus.publish(session.id, {
sessionId: session.id,
agentId: session.agentId,
type: 'chunk',
chunk: value
})
}
// Skip user message echoes — only accumulate assistant text for the channel reply
const rawType = (value as any).providerMetadata?.raw?.type
if (rawType === 'user') continue
switch (value.type) {
case 'text-delta':
// text-delta values are cumulative within a block
if (value.text) {
currentBlockText = value.text
// Notify adapter of text update — adapter owns its own throttle/flush
const fullText = completedText + currentBlockText
adapter.onTextUpdate(chatId, fullText).catch(() => {})
}
break
case 'text-end':
// Block finished — commit current block text and reset for next turn
if (currentBlockText) {
completedText += currentBlockText + '\n\n'
currentBlockText = ''
}
break
}
}
await completion
if (rendererIsWatching) {
// Notify renderer that stream is complete and data is persisted
sessionStreamBus.publish(session.id, {
sessionId: session.id,
agentId: session.agentId,
type: 'complete'
})
}
// headless=true means main process persisted; renderer should force-reload from DB.
// headless=false means renderer handled persistence; no reload needed.
broadcastSessionChanged(session.agentId, session.id, !rendererIsWatching)
// Trim trailing separator
return (completedText + currentBlockText).replace(/\n+$/, '')
} catch (error) {
if (rendererIsWatching) {
sessionStreamBus.publish(session.id, {
sessionId: session.id,
agentId: session.agentId,
type: 'error',
error: { message: error instanceof Error ? error.message : String(error) }
})
}
throw error
}
}
private async sendChunked(adapter: ChannelAdapter, chatId: string, text: string): Promise<void> {
if (text.length <= MAX_MESSAGE_LENGTH) {
await adapter.sendMessage(chatId, text)
return
if (!session.agentId) {
throw new Error(`Cannot stream on orphan session ${session.id} — its agent was deleted`)
}
const chunks = splitMessage(text, MAX_MESSAGE_LENGTH)
for (const chunk of chunks) {
await adapter.sendMessage(chatId, chunk)
let resolveExecution!: (text: string) => void
let rejectExecution!: (err: unknown) => void
const executionDone = new Promise<string>((resolve, reject) => {
resolveExecution = resolve
rejectExecution = reject
})
let accumulatedText = ''
const sentinel: StreamListener = {
id: `channel-completion:${chatId}`,
onChunk(chunk) {
// `text-delta`'s field is `delta`, not `text` (AI SDK `UIMessageChunk`).
if (chunk.type === 'text-delta') accumulatedText += chunk.delta
},
onDone() {
resolveExecution(accumulatedText.trim())
},
onPaused() {
resolveExecution(accumulatedText.trim())
},
onError(result) {
rejectExecution(new Error(result.error.message ?? 'Execution failed'))
},
isAlive: () => !abortController.signal.aborted
}
await startAgentSessionRun({
sessionId: session.id,
userParts: [{ type: 'text', text: content }],
listeners: [sentinel, new ChannelAdapterListener(adapter, chatId)]
})
return executionDone
}
/**

View File

@@ -2,9 +2,11 @@ import { agentChannelService as channelService } from '@data/services/AgentChann
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { ChannelAdapter, type ChannelAdapterConfig } from '../ChannelAdapter'
import { channelManager, registerAdapterFactory } from '../ChannelManager'
import { ChannelManager, registerAdapterFactory } from '../ChannelManager'
import { channelMessageHandler } from '../ChannelMessageHandler'
const channelManager = new ChannelManager()
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() })

View File

@@ -0,0 +1,512 @@
import { agentChannelService as channelService } from '@data/services/AgentChannelService'
import { agentService } from '@data/services/AgentService'
import { sessionService } from '@data/services/SessionService'
import { buildAgentSessionTopicId } from '@main/ai/agentSession/topic'
import { AgentSessionWorkspaceError } from '@main/ai/runtime/claudeCode/settingsBuilder'
import { EventEmitter } from 'events'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { channelMessageHandler } from '../ChannelMessageHandler'
import { sanitizeChannelOutput } from '../security'
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() })
}
}))
vi.mock('../security', () => ({
wrapExternalContent: vi.fn((text: string) => text),
sanitizeChannelOutput: vi.fn((text: string) => ({ text, redacted: false }))
}))
// The global mock (tests/main.setup.ts) wires the default service set, which omits
// AiStreamManager; the abort path reads it, so override locally with a captured spy.
const { mockStreamAbort } = vi.hoisted(() => ({ mockStreamAbort: vi.fn() }))
vi.mock('@application', async () => {
const { mockApplicationFactory } = await import('@test-mocks/main/application')
return mockApplicationFactory({ AiStreamManager: { abort: mockStreamAbort } } as never)
})
vi.mock('@data/services/AgentService', () => ({
agentService: {
getAgent: vi.fn().mockResolvedValue({
id: 'agent-1',
configuration: {},
model: 'openai::gpt-4'
})
}
}))
vi.mock('@data/services/SessionService', () => ({
sessionService: {
getById: vi.fn(),
createSession: vi.fn()
}
}))
vi.mock('@shared/data/types/model', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
createUniqueModelId: vi.fn((providerId: string, modelId: string) => `${providerId}::${modelId}`)
}
})
const { mockStartAgentSessionRun } = vi.hoisted(() => ({ mockStartAgentSessionRun: vi.fn() }))
vi.mock('@main/ai/streamManager/api/startAgentSessionRun', () => ({
startAgentSessionRun: (...args: unknown[]) => mockStartAgentSessionRun(...args)
}))
vi.mock('@data/services/AgentChannelService', () => ({
agentChannelService: {
getChannel: vi.fn().mockResolvedValue({ id: 'channel-1', sessionId: null, permissionMode: null }),
updateChannel: vi.fn().mockResolvedValue(null),
findBySessionId: vi.fn().mockResolvedValue(null)
}
}))
/**
* Helper: configure mockStartAgentSessionRun to simulate streaming chunks to ALL
* registered listeners (both the `channel-completion:` sentinel and the
* `ChannelAdapterListener` that owns delivery), then call onDone on each so the
* `executionDone` promise inside `collectStreamResponse` resolves and the listener
* finalizes delivery. `text-delta` chunks carry the payload on `delta` (AI SDK
* `UIMessageChunk`), not `text`.
*/
function simulateStream(parts: Array<{ type: string; delta?: string }>) {
mockStartAgentSessionRun.mockImplementationOnce(
async ({
listeners
}: {
listeners: Array<{
id: string
onChunk: (chunk: unknown) => void
onDone: (result: { status: string }) => void | Promise<void>
}>
}) => {
for (const listener of listeners) {
for (const part of parts) {
listener.onChunk(part)
}
await listener.onDone({ status: 'success' })
}
}
)
}
function createMockAdapter(overrides: Record<string, unknown> = {}) {
const adapter = new EventEmitter() as any
adapter.agentId = overrides.agentId ?? 'agent-1'
adapter.channelId = overrides.channelId ?? 'channel-1'
adapter.channelType = overrides.channelType ?? 'telegram'
adapter.connected = true
adapter.sendMessage = vi.fn().mockResolvedValue(undefined)
adapter.sendTypingIndicator = vi.fn().mockResolvedValue(undefined)
adapter.onTextUpdate = vi.fn().mockResolvedValue(undefined)
adapter.onStreamComplete = vi.fn().mockResolvedValue(false)
adapter.onStreamError = vi.fn().mockResolvedValue(undefined)
adapter.notifyChatIds = []
return adapter
}
/**
* Helper: call handleIncoming and advance fake timers so the debounce fires,
* then await the returned promise to wait for processing to complete.
*/
async function handleIncomingAndFlush(
adapter: ReturnType<typeof createMockAdapter>,
message: { chatId: string; userId: string; userName: string; text: string }
) {
const promise = channelMessageHandler.handleIncoming(adapter, message)
// Advance past the MESSAGE_BATCH_DELAY_MS debounce (10 000 ms)
await vi.advanceTimersByTimeAsync(10500)
return promise
}
describe('ChannelMessageHandler', () => {
beforeEach(() => {
vi.useFakeTimers()
vi.clearAllMocks()
// Restore default agent mock after clearAllMocks
vi.mocked(agentService.getAgent).mockResolvedValue({
id: 'agent-1',
configuration: {},
model: 'openai::gpt-4'
} as any)
// Clear session tracker to ensure clean state
channelMessageHandler.clearSessionTracker('agent-1')
})
afterEach(() => {
vi.useRealTimers()
})
it('collectStreamResponse accumulates text across turns and sends via adapter', async () => {
const adapter = createMockAdapter()
const session = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
simulateStream([
{ type: 'text-delta', delta: 'Hello ' },
{ type: 'text-delta', delta: 'world!' },
{ type: 'text-end' },
{ type: 'text-delta', delta: '\n\nDone.' },
{ type: 'text-end' }
])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'Hi'
})
// Delivery is owned by ChannelAdapterListener (the handler no longer post-sends);
// it accumulates all text-delta chunks via `.delta`, trims, and sends once.
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'Hello world!\n\nDone.')
})
// channels-core-3: the streaming delivery path (real ChannelAdapterListener) must route
// output through the OutputSanitizer before sending — otherwise secrets in the model reply
// leak to the IM platform. simulateStream drives the real listener, so a redacting sanitizer
// must be reflected in what the adapter sends.
it('routes channel output through the OutputSanitizer before delivery (REGRESSION channels-core-3)', async () => {
const adapter = createMockAdapter()
const session = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
vi.mocked(sanitizeChannelOutput).mockImplementation((text: string) => ({
text: text.replace('sk-SECRET', '<redacted>'),
redacted: text.includes('sk-SECRET')
}))
simulateStream([{ type: 'text-delta', delta: 'the key is sk-SECRET' }])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'Hi'
})
expect(sanitizeChannelOutput).toHaveBeenCalled()
// The redacted text — not the raw secret — is what reaches the adapter.
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'the key is <redacted>')
// Restore the identity default so later tests are unaffected.
vi.mocked(sanitizeChannelOutput).mockImplementation((text: string) => ({ text, redacted: false }))
})
// stream-context-5: a workspace error is thrown before streaming starts, so onStreamError
// (a no-op without a live controller on most adapters) can't surface it. The handler must
// fall back to a plain sendMessage so the inbound message isn't silently dropped.
it('surfaces a pre-stream workspace error as a plain message (REGRESSION stream-context-5)', async () => {
const adapter = createMockAdapter()
const session = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
mockStartAgentSessionRun.mockRejectedValueOnce(new AgentSessionWorkspaceError('workspace is missing'))
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'Hi'
})
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'workspace is missing')
expect(adapter.onStreamError).not.toHaveBeenCalled()
})
it('skips final send when adapter handles stream completion', async () => {
const adapter = createMockAdapter()
const session = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
adapter.onStreamComplete.mockResolvedValueOnce(true)
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
simulateStream([{ type: 'text-delta', delta: 'Hello world!' }])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'Hi'
})
expect(adapter.onStreamComplete).toHaveBeenCalledWith('chat-1', 'Hello world!')
expect(adapter.sendMessage).not.toHaveBeenCalled()
})
it('delivers a long response in a single send (platform splitting is the adapter concern)', async () => {
const adapter = createMockAdapter()
const session = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
const longText = 'A'.repeat(5000)
simulateStream([{ type: 'text-delta', delta: longText }])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'Hi'
})
// The handler-level 4096-char chunking was dead code (post-hoc path never ran)
// and has been removed; ChannelAdapterListener delivers the full text once and
// each adapter splits per its own platform limit.
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', longText)
})
it('handleCommand /new creates a new session', async () => {
const adapter = createMockAdapter()
vi.mocked(sessionService.createSession).mockResolvedValueOnce({ id: 'new-session' } as any)
await channelMessageHandler.handleCommand(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
command: 'new'
})
expect(sessionService.createSession).toHaveBeenCalledWith({
agentId: 'agent-1',
name: 'Channel session'
})
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'New session created.')
})
it('handleCommand /compact sends /compact as message content', async () => {
const adapter = createMockAdapter()
const session = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session as any)
simulateStream([{ type: 'text-delta', delta: 'Compacted.' }])
await channelMessageHandler.handleCommand(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
command: 'compact'
})
expect(mockStartAgentSessionRun).toHaveBeenCalledWith(
expect.objectContaining({
sessionId: 'session-1',
userParts: [{ type: 'text', text: '/compact' }],
listeners: expect.arrayContaining([
expect.objectContaining({ id: expect.stringContaining('channel-completion:') })
])
})
)
// ChannelAdapterListener delivers the compact output once; the handler no longer
// also sends it (would have been a double-send once the `.delta` read was fixed).
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
expect(adapter.sendMessage).toHaveBeenCalledWith('chat-1', 'Compacted.')
})
it('handleCommand /help sends help text with agent info', async () => {
const adapter = createMockAdapter()
vi.mocked(agentService.getAgent).mockResolvedValueOnce({
name: 'TestAgent',
description: 'A test agent'
} as any)
await channelMessageHandler.handleCommand(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
command: 'help'
})
expect(adapter.sendMessage).toHaveBeenCalledTimes(1)
const helpText = adapter.sendMessage.mock.calls[0][1] as string
expect(helpText).toContain('*TestAgent*')
expect(helpText).toContain('_A test agent_')
expect(helpText).toContain('/new')
expect(helpText).toContain('/compact')
expect(helpText).toContain('/help')
expect(helpText).toContain('/whoami')
})
it('handleCommand /whoami sends the current chat ID', async () => {
const adapter = createMockAdapter()
await channelMessageHandler.handleCommand(adapter, {
chatId: 'oc_123',
userId: 'user-1',
userName: 'User',
command: 'whoami'
})
expect(adapter.sendMessage).toHaveBeenCalledWith(
'oc_123',
'Current chat ID: `oc_123`\n\nAdd this value to `allow_ids` in settings to receive notifications.'
)
})
it('resolveSession tracks sessions after /new', async () => {
const adapter = createMockAdapter()
const newSession = {
id: 'new-session',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
vi.mocked(sessionService.createSession).mockResolvedValueOnce(newSession as any)
await channelMessageHandler.handleCommand(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
command: 'new'
})
// Now send a message — should use the tracked session
vi.mocked(sessionService.getById).mockResolvedValueOnce(newSession as any)
simulateStream([{ type: 'text-delta', delta: 'OK' }])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'test'
})
expect(sessionService.getById).toHaveBeenCalledWith('new-session')
})
it('clearSessionTracker causes fresh session resolution', async () => {
const adapter = createMockAdapter()
const session1 = {
id: 'session-1',
agentId: 'agent-1',
agentType: 'claude-code',
model: 'openai::gpt-4',
workspace: { path: '/tmp/test-workspace' },
configuration: {}
}
// First interaction creates a session
vi.mocked(sessionService.createSession).mockResolvedValueOnce(session1 as any)
simulateStream([{ type: 'text-delta', delta: 'R1' }])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'msg1'
})
// Clear session tracker
channelMessageHandler.clearSessionTracker('agent-1')
// Next interaction should find existing session via channel's session_id
vi.mocked(channelService.getChannel).mockResolvedValueOnce({
id: 'channel-1',
sessionId: 'session-1',
permissionMode: null
} as any)
vi.mocked(sessionService.getById).mockResolvedValueOnce(session1 as any)
simulateStream([{ type: 'text-delta', delta: 'R2' }])
await handleIncomingAndFlush(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'msg2'
})
// After clearing tracker, should look up channel then getSession instead of creating new session
expect(channelService.getChannel).toHaveBeenCalledWith('channel-1')
// Only 1 createSession call (the first one), not 2
expect(sessionService.createSession).toHaveBeenCalledTimes(1)
})
// channels-core-3: discarding a pending (un-flushed) batch must settle its callers'
// handleIncoming promises instead of leaving them hanging forever, so .catch fires.
it('clearSessionTracker rejects pending-batch handleIncoming promises', async () => {
const adapter = createMockAdapter()
// Start a batch but do NOT advance timers — it stays pending in pendingBatches.
const pending = channelMessageHandler.handleIncoming(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
text: 'Hi'
})
const rejection = expect(pending).rejects.toThrow('Agent removed; batch discarded')
// Clearing the agent's tracker discards the pending batch.
channelMessageHandler.clearSessionTracker('agent-1')
await rejection
expect(mockStartAgentSessionRun).not.toHaveBeenCalled()
})
// channels-core-2: a local AbortController only flips a listener's isAlive() — clearing
// a tracked session must stop the upstream agent-session turn via the manager.
it('clearSessionTracker aborts the upstream agent-session turn via the manager', async () => {
const adapter = createMockAdapter()
vi.mocked(sessionService.createSession).mockResolvedValueOnce({ id: 'sess-x' } as any)
await channelMessageHandler.handleCommand(adapter, {
chatId: 'chat-1',
userId: 'user-1',
userName: 'User',
command: 'new'
})
mockStreamAbort.mockClear()
channelMessageHandler.clearSessionTracker('agent-1')
expect(mockStreamAbort).toHaveBeenCalledWith(buildAgentSessionTopicId('sess-x'), 'agent-cleared')
})
})

View File

@@ -0,0 +1,79 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({ info: vi.fn(), error: vi.fn(), warn: vi.fn(), debug: vi.fn(), silly: vi.fn() })
}
}))
vi.mock('../../ChannelManager', () => ({
registerAdapterFactory: vi.fn()
}))
const mockNetFetch = vi.fn()
vi.mock('electron', () => ({
app: { getPath: () => '/mock/userData' },
net: { fetch: (...args: unknown[]) => mockNetFetch(...args) }
}))
vi.mock('ws', () => {
const Ctor = vi.fn()
Object.assign(Ctor, { OPEN: 1, CONNECTING: 0, CLOSED: 3, CLOSING: 2 })
return { default: Ctor, WebSocket: Ctor }
})
import '../qq/QqAdapter'
import { registerAdapterFactory } from '../../ChannelManager'
// Capture the factory at module load — `registerAdapterFactory('qq', …)` runs once on import,
// and afterEach's restoreAllMocks would otherwise wipe that call history before later tests.
const qqCall = vi.mocked(registerAdapterFactory).mock.calls.find((c) => c[0] === 'qq')
if (!qqCall) throw new Error('registerAdapterFactory was not called for qq')
const qqFactory = qqCall[1] as (channel: any, agentId: string) => any
function mockBinaryResponse(buf: Buffer, contentType = 'image/png'): Response {
return {
ok: true,
status: 200,
headers: new Headers({ 'content-type': contentType }),
arrayBuffer: () => Promise.resolve(buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength))
} as unknown as Response
}
function createAdapter() {
return qqFactory(
{ id: 'ch-qq-1', type: 'qq', enabled: true, config: { app_id: 'app', client_secret: 'sec', allowed_chat_ids: [] } },
'agent-1'
)
}
describe('QqAdapter.downloadAttachments', () => {
beforeEach(() => mockNetFetch.mockReset())
afterEach(() => vi.restoreAllMocks())
it('rejects an SSRF target before any (token-bearing) fetch (C8)', async () => {
const adapter = createAdapter()
vi.spyOn(adapter, 'getAccessToken').mockResolvedValue('tok')
const result = await adapter.downloadAttachments([
{ url: 'http://169.254.169.254/latest/meta-data/', content_type: 'image/png', filename: 'meta' }
])
expect(result).toEqual({})
expect(mockNetFetch).not.toHaveBeenCalled()
})
it('downloads a public attachment URL', async () => {
const adapter = createAdapter()
vi.spyOn(adapter, 'getAccessToken').mockResolvedValue('tok')
mockNetFetch.mockResolvedValue(mockBinaryResponse(Buffer.from('img'), 'image/png'))
const result = await adapter.downloadAttachments([
{ url: 'https://gchat.qpic.cn/a.png', content_type: 'image/png', filename: 'a.png' }
])
expect(result.images).toHaveLength(1)
expect(mockNetFetch).toHaveBeenCalled()
})
})

View File

@@ -349,6 +349,50 @@ describe('SlackAdapter', () => {
expect(messageSpy.mock.calls[0][0].images).toHaveLength(1)
})
it('skips an oversized image download (content-length over the cap)', async () => {
const adapter = await connectAdapter()
const messageSpy = vi.fn()
adapter.on('message', messageSpy)
const oversize = {
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'image/png', 'content-length': String(500 * 1024 * 1024) }),
json: () => Promise.reject(new Error('not json')),
text: () => Promise.resolve(''),
arrayBuffer: () => Promise.resolve(Buffer.from('x').buffer)
} as unknown as Response
mockNetFetch.mockImplementation((url: string) => {
if (typeof url === 'string' && url.includes('files.slack.com')) {
return Promise.resolve(oversize)
}
if (url.includes('users.info')) {
return Promise.resolve(mockJsonResponse({ ok: true, user: { real_name: 'Test User' } }))
}
return Promise.resolve(mockJsonResponse({ ok: true }))
})
simulateMessageEvent({
channel: 'C0ALLOWED',
user: USER1_ID,
text: 'see attached',
subtype: 'file_share',
files: [
{
id: 'F1',
name: 'huge.png',
mimetype: 'image/png',
size: 1000,
url_private: 'https://files.slack.com/huge.png'
}
]
})
await vi.waitFor(() => expect(messageSpy).toHaveBeenCalledTimes(1))
expect(messageSpy.mock.calls[0][0].images).toBeUndefined()
})
it('ignores messages from the bot itself', async () => {
const adapter = await connectAdapter()
const messageSpy = vi.fn()

View File

@@ -99,6 +99,55 @@ describe('TelegramAdapter', () => {
expect(mockBot.stop).toHaveBeenCalledTimes(1)
})
// channel-adapters-2: grammY rethrows a fatal 409/Conflict out of bot.start(); the adapter
// must reconnect with backoff instead of staying permanently down.
it('reconnects with backoff when polling rejects (REGRESSION channel-adapters-2)', async () => {
vi.useFakeTimers()
const adapter = createAdapter()
mockBot.start.mockReset()
// First polling attempt fails (recoverable 409); the reconnect attempt succeeds.
mockBot.start.mockRejectedValueOnce(new Error('409: Conflict')).mockResolvedValue(undefined)
await adapter.connect()
await vi.advanceTimersByTimeAsync(0) // let the rejection handler schedule the reconnect
expect(mockBot.start).toHaveBeenCalledTimes(1)
await vi.advanceTimersByTimeAsync(1000) // first backoff delay
expect(mockBot.start).toHaveBeenCalledTimes(2) // reconnected
})
it('resets the reconnect budget after a stable polling window (REGRESSION channel-adapters-2)', async () => {
vi.useFakeTimers()
const adapter = createAdapter()
mockBot.start.mockReset()
// One transient failure bumps the attempt counter, then the reconnect stays up.
mockBot.start.mockRejectedValueOnce(new Error('409: Conflict')).mockResolvedValue(undefined)
await adapter.connect()
await vi.advanceTimersByTimeAsync(1000) // reconnect fires and succeeds
expect(adapter.reconnectAttempts).toBe(1)
// After the stability window the counter resets, so lifetime-cumulative transient
// failures can't monotonically exhaust maxReconnectAttempts.
await vi.advanceTimersByTimeAsync(60_000)
expect(adapter.reconnectAttempts).toBe(0)
})
it('does not reconnect after disconnect() (REGRESSION channel-adapters-2)', async () => {
vi.useFakeTimers()
const adapter = createAdapter()
mockBot.start.mockReset()
mockBot.start.mockRejectedValue(new Error('409: Conflict'))
await adapter.connect()
await vi.advanceTimersByTimeAsync(0) // a reconnect is now pending
await adapter.disconnect() // shouldStop + clear the pending reconnect timer
const callsAfterDisconnect = mockBot.start.mock.calls.length
await vi.advanceTimersByTimeAsync(60_000)
expect(mockBot.start.mock.calls.length).toBe(callsAfterDisconnect) // no further reconnect
})
it('sendMessage() sends text with MarkdownV2 by default', async () => {
const adapter = createAdapter()
await adapter.connect()

View File

@@ -1,16 +1,14 @@
import { net } from 'electron'
import WebSocket from 'ws'
import {
ChannelAdapter,
type ChannelAdapterConfig,
downloadFileAsBase64,
downloadImageAsBase64,
type FileAttachment,
type ImageAttachment,
MAX_FILE_SIZE_BYTES,
type SendMessageOptions
} from '../../ChannelAdapter'
MAX_FILE_SIZE_BYTES
} from '@main/utils/downloadAsBase64'
import { net } from 'electron'
import WebSocket from 'ws'
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
import { registerAdapterFactory } from '../../ChannelManager'
import { isSlashCommand, SLASH_COMMANDS } from '../../constants'
import { FlushController } from '../../FlushController'
@@ -234,12 +232,11 @@ class DiscordAdapter extends ChannelAdapter {
/** Per-chat streaming controller. One stream at a time per chat. */
private readonly streamingControllers = new Map<string, DiscordStreamingController>()
constructor(config: ChannelAdapterConfig) {
constructor(config: ChannelAdapterConfig<'discord'>) {
super(config)
const { bot_token, allowed_channel_ids } = config.channelConfig
this.botToken = (bot_token as string) ?? ''
const rawIds = allowed_channel_ids as string[] | undefined
this.allowedChannelIds = Array.isArray(rawIds) ? rawIds.map(String) : []
this.botToken = bot_token
this.allowedChannelIds = allowed_channel_ids ?? []
this.notifyChatIds = [...this.allowedChannelIds]
}

View File

@@ -4,17 +4,11 @@ import type { ReadableStream as NodeReadableStream } from 'node:stream/web'
import { application } from '@application'
import * as Lark from '@larksuiteoapi/node-sdk'
import { WindowType } from '@main/core/window/types'
import { type FileAttachment, type ImageAttachment, MAX_FILE_SIZE_BYTES } from '@main/utils/downloadAsBase64'
import type { FeishuDomain } from '@shared/data/types/channel'
import { IpcChannel } from '@shared/IpcChannel'
import {
ChannelAdapter,
type ChannelAdapterConfig,
type FileAttachment,
type ImageAttachment,
MAX_FILE_SIZE_BYTES,
type SendMessageOptions
} from '../../ChannelAdapter'
import type { FeishuDomain } from '../../channelConfig'
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
import { registerAdapterFactory } from '../../ChannelManager'
import { isSlashCommand } from '../../constants'
import { FlushController } from '../../FlushController'
@@ -476,16 +470,15 @@ class FeishuAdapter extends ChannelAdapter {
/** Active status reaction per chat, so we can swap or remove it. */
private readonly chatReactions = new Map<string, ChatReaction>()
constructor(config: ChannelAdapterConfig) {
constructor(config: ChannelAdapterConfig<'feishu'>) {
super(config)
const { app_id, app_secret, encrypt_key, verification_token, allowed_chat_ids, domain } = config.channelConfig
this.appId = (app_id as string) ?? ''
this.appSecret = (app_secret as string) ?? ''
this.encryptKey = (encrypt_key as string) ?? ''
this.verificationToken = (verification_token as string) ?? ''
const rawIds = allowed_chat_ids as string[] | undefined
this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : []
this.domain = ((domain as string) ?? 'feishu') as FeishuDomain
this.appId = app_id
this.appSecret = app_secret
this.encryptKey = encrypt_key
this.verificationToken = verification_token
this.allowedChatIds = allowed_chat_ids ?? []
this.domain = domain
this.notifyChatIds = [...this.allowedChatIds]
}

View File

@@ -7,10 +7,9 @@
* Flow: init -> begin (returns QR URL) -> poll (returns client_id + client_secret)
*/
import { loggerService } from '@logger'
import type { FeishuDomain } from '@shared/data/types/channel'
import { net } from 'electron'
import type { FeishuDomain } from '../../channelConfig'
const logger = loggerService.withContext('FeishuAppRegistration')
const BASE_URLS: Record<FeishuDomain, string> = {

View File

@@ -1,14 +1,9 @@
import { type FileAttachment, type ImageAttachment, MAX_FILE_SIZE_BYTES } from '@main/utils/downloadAsBase64'
import { sanitizeRemoteUrl } from '@main/utils/remoteUrlSafety'
import { net } from 'electron'
import WebSocket from 'ws'
import {
ChannelAdapter,
type ChannelAdapterConfig,
type FileAttachment,
type ImageAttachment,
MAX_FILE_SIZE_BYTES,
type SendMessageOptions
} from '../../ChannelAdapter'
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
import { registerAdapterFactory } from '../../ChannelManager'
import { isSlashCommand } from '../../constants'
import { splitMessage } from '../../utils'
@@ -88,13 +83,12 @@ class QqAdapter extends ChannelAdapter {
/** Number of rapid disconnects before invalidating session */
private readonly maxRapidDisconnects = 3
constructor(config: ChannelAdapterConfig) {
constructor(config: ChannelAdapterConfig<'qq'>) {
super(config)
const { app_id, client_secret, allowed_chat_ids } = config.channelConfig
this.appId = (app_id as string) ?? ''
this.clientSecret = (client_secret as string) ?? ''
const rawIds = allowed_chat_ids as string[] | undefined
this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : []
this.appId = app_id
this.clientSecret = client_secret
this.allowedChatIds = allowed_chat_ids ?? []
this.notifyChatIds = [...this.allowedChatIds]
}
@@ -240,7 +234,7 @@ class QqAdapter extends ChannelAdapter {
return
}
if (payload.s !== undefined) {
if (payload.s !== undefined && payload.s !== null) {
this.lastSeq = payload.s
}
@@ -427,17 +421,23 @@ class QqAdapter extends ChannelAdapter {
.map(async (att) => {
try {
const url = att.url.startsWith('http') ? att.url : `https://${att.url}`
const response = await net.fetch(url, {
// SSRF guard: reject local/private/credentialed/non-http(s) targets from the
// inbound payload before we fetch with the bot token (and before the retry).
const safeUrl = sanitizeRemoteUrl(url)
const response = await net.fetch(safeUrl, {
headers: { Authorization: `QQBot ${token}`, 'X-Union-Appid': this.appId }
})
if (!response.ok) {
// Retry without auth header (some CDN URLs are public)
const retry = await net.fetch(url)
const retry = await net.fetch(safeUrl)
if (!retry.ok) return
const buffer = Buffer.from(await retry.arrayBuffer())
// `att.size` is attacker-supplied metadata; cap on the real downloaded bytes.
if (buffer.length > MAX_FILE_SIZE_BYTES) return
this.pushAttachment(att, buffer, images, files)
} else {
const buffer = Buffer.from(await response.arrayBuffer())
if (buffer.length > MAX_FILE_SIZE_BYTES) return
this.pushAttachment(att, buffer, images, files)
}
} catch {

View File

@@ -1,14 +1,8 @@
import { type FileAttachment, type ImageAttachment, MAX_FILE_SIZE_BYTES } from '@main/utils/downloadAsBase64'
import { net } from 'electron'
import WebSocket from 'ws'
import {
ChannelAdapter,
type ChannelAdapterConfig,
type FileAttachment,
type ImageAttachment,
MAX_FILE_SIZE_BYTES,
type SendMessageOptions
} from '../../ChannelAdapter'
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
import { registerAdapterFactory } from '../../ChannelManager'
import { isSlashCommand } from '../../constants'
import { FlushController } from '../../FlushController'
@@ -211,13 +205,12 @@ class SlackAdapter extends ChannelAdapter {
/** Track the latest incoming message ts per chatId for reaction acknowledgment */
private readonly pendingReactions = new Map<string, string>()
constructor(config: ChannelAdapterConfig) {
constructor(config: ChannelAdapterConfig<'slack'>) {
super(config)
const { bot_token, app_token, allowed_channel_ids } = config.channelConfig
this.botToken = (bot_token as string) ?? ''
this.appToken = (app_token as string) ?? ''
const rawIds = allowed_channel_ids as string[] | undefined
this.allowedChannelIds = Array.isArray(rawIds) ? rawIds.map(String) : []
this.botToken = bot_token
this.appToken = app_token
this.allowedChannelIds = allowed_channel_ids ?? []
this.notifyChatIds = [...this.allowedChannelIds]
}
@@ -483,9 +476,12 @@ class SlackAdapter extends ChannelAdapter {
headers: { Authorization: `Bearer ${this.botToken}` }
})
if (!response.ok) return null
const contentLength = response.headers.get('content-length')
if (contentLength && Number.parseInt(contentLength, 10) > MAX_FILE_SIZE_BYTES) return null
const contentType = response.headers.get('content-type') || 'image/png'
const mediaType = contentType.split(';')[0].trim()
const buffer = Buffer.from(await response.arrayBuffer())
if (buffer.length > MAX_FILE_SIZE_BYTES) return null
return { data: buffer.toString('base64'), media_type: mediaType }
} catch {
return null

View File

@@ -1,19 +1,26 @@
import { Bot } from 'grammy'
import { convert as toMarkdownV2 } from 'telegram-markdown-v2'
import {
ChannelAdapter,
type ChannelAdapterConfig,
downloadFileAsBase64,
downloadImageAsBase64,
type FileAttachment,
type ImageAttachment,
MAX_FILE_SIZE_BYTES,
type SendMessageOptions
} from '../../ChannelAdapter'
MAX_FILE_SIZE_BYTES
} from '@main/utils/downloadAsBase64'
import { Bot } from 'grammy'
import { convert as toMarkdownV2 } from 'telegram-markdown-v2'
import { ChannelAdapter, type ChannelAdapterConfig, type SendMessageOptions } from '../../ChannelAdapter'
import { registerAdapterFactory } from '../../ChannelManager'
const TELEGRAM_MAX_LENGTH = 4096
/**
* Plain-text chunk budget under MarkdownV2. We split the *plain* text (so each
* chunk has an index-aligned plain fallback) and then escape it; escaping only
* grows length, so this headroom keeps the formatted chunk within the 4096 hard
* limit for normal prose. A pathological all-special-char chunk could still
* overflow Telegram then rejects it and the catch sends the plain chunk, which
* is always within budget.
*/
const TELEGRAM_MARKDOWN_CHUNK_BUDGET = 3200
import { splitMessage } from '../../utils'
@@ -22,12 +29,26 @@ class TelegramAdapter extends ChannelAdapter {
private readonly botToken: string
private readonly allowedChatIds: string[]
constructor(config: ChannelAdapterConfig) {
// Long-polling reconnect with backoff. grammY rethrows fatal polling errors (401/409) out of
// `bot.start()`; without this a recoverable 409/Conflict left the bot permanently down.
// Mirrors the WebSocket adapters (Slack/Discord/QQ).
private shouldStop = false
private reconnectAttempts = 0
private reconnectTimer: ReturnType<typeof setTimeout> | null = null
private stabilityTimer: ReturnType<typeof setTimeout> | null = null
private readonly reconnectDelays = [1000, 2000, 5000, 10000, 30000, 60000]
private readonly maxReconnectAttempts = 50
// Long polling has no "ready" event (a fatal 409 surfaces *after* `markConnected`), so
// resetting the backoff budget on connect would let a persistent failure loop forever and
// never hit the cap. Instead reset only after the bot has polled cleanly for this window —
// so transient failures spread over the adapter's lifetime don't monotonically exhaust it.
private readonly stabilityResetMs = 60_000
constructor(config: ChannelAdapterConfig<'telegram'>) {
super(config)
const { bot_token, allowed_chat_ids } = config.channelConfig
this.botToken = (bot_token as string) ?? ''
const rawIds = allowed_chat_ids as string[] | undefined
this.allowedChatIds = Array.isArray(rawIds) ? rawIds.map(String) : []
this.botToken = bot_token
this.allowedChatIds = allowed_chat_ids ?? []
this.notifyChatIds = [...this.allowedChatIds]
}
@@ -39,7 +60,12 @@ class TelegramAdapter extends ChannelAdapter {
if (!this.botToken) {
throw new Error('Telegram bot token is required')
}
this.shouldStop = false
this.reconnectAttempts = 0
await this.startBot()
}
private async startBot(): Promise<void> {
const bot = new Bot(this.botToken)
this.bot = bot
@@ -161,18 +187,70 @@ class TelegramAdapter extends ChannelAdapter {
this.log.error(`Bot error: ${msg}`)
})
// Start long polling (fire-and-forget)
// Start long polling (fire-and-forget). `bot.start()` only resolves when the bot stops;
// a fatal polling error (e.g. 409 Conflict) rejects here — schedule a backoff reconnect.
bot.start().catch((err) => {
const msg = err instanceof Error ? err.message : String(err)
this.clearStabilityTimer()
this.markDisconnected(msg)
this.log.error(`Polling stopped: ${msg}`)
this.scheduleReconnect()
})
this.markConnected()
this.log.info('Telegram bot polling started')
// Reset the reconnect budget once this connection has stayed up for the stability window.
// The `this.bot === bot` guard ensures a stale timer from a superseded connection no-ops.
this.clearStabilityTimer()
this.stabilityTimer = setTimeout(() => {
this.stabilityTimer = null
if (!this.shouldStop && this.bot === bot) this.reconnectAttempts = 0
}, this.stabilityResetMs)
}
private clearStabilityTimer(): void {
if (this.stabilityTimer) {
clearTimeout(this.stabilityTimer)
this.stabilityTimer = null
}
}
private scheduleReconnect(): void {
if (this.shouldStop || this.reconnectTimer) return
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
this.log.error('Telegram max reconnect attempts reached, giving up')
return
}
const delay = this.reconnectDelays[Math.min(this.reconnectAttempts, this.reconnectDelays.length - 1)]
this.reconnectAttempts++
this.log.info(`Telegram reconnecting in ${delay}ms (attempt ${this.reconnectAttempts})`)
this.reconnectTimer = setTimeout(() => {
this.reconnectTimer = null
if (this.shouldStop) return
this.startBot().catch((err) => {
const msg = err instanceof Error ? err.message : String(err)
this.markDisconnected(msg)
this.log.error(`Telegram reconnect failed: ${msg}`)
this.scheduleReconnect()
})
}, delay)
}
private clearReconnectTimer(): void {
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer)
this.reconnectTimer = null
}
}
protected override async performDisconnect(): Promise<void> {
this.shouldStop = true
this.clearReconnectTimer()
this.clearStabilityTimer()
if (this.bot) {
await this.bot.stop()
this.bot = null
@@ -228,33 +306,39 @@ class TelegramAdapter extends ChannelAdapter {
}
const parseMode = opts?.parseMode ?? 'MarkdownV2'
const formatted = parseMode === 'MarkdownV2' ? toMarkdownV2(text).trimEnd() : text
const chunks = splitMessage(formatted, TELEGRAM_MAX_LENGTH)
const isMarkdown = parseMode === 'MarkdownV2'
// Split the PLAIN text first and escape each chunk, so the MarkdownV2 send and
// its plain-text fallback share one chunk boundary. (Splitting the *formatted*
// text and then re-splitting the *raw* text by the same index misaligns — escaping
// changes lengths/boundaries — dropping, duplicating, or passing `undefined` chunks.)
const plainChunks = splitMessage(text, isMarkdown ? TELEGRAM_MARKDOWN_CHUNK_BUDGET : TELEGRAM_MAX_LENGTH)
for (let i = 0; i < chunks.length; i++) {
for (let i = 0; i < plainChunks.length; i++) {
const plain = plainChunks[i]
const formatted = isMarkdown ? toMarkdownV2(plain).trimEnd() : plain
const replyParams =
opts?.replyToMessageId && i === 0 ? { reply_parameters: { message_id: opts.replyToMessageId } } : {}
try {
await this.bot.api.sendMessage(chatId, chunks[i], {
await this.bot.api.sendMessage(chatId, formatted, {
parse_mode: parseMode,
...replyParams
})
} catch (error) {
// Fallback to plain text if MarkdownV2 parsing fails
if (parseMode === 'MarkdownV2') {
// Fallback to plain text if MarkdownV2 parsing fails — same chunk content.
if (isMarkdown) {
this.log.warn('MarkdownV2 send failed, falling back to plain text', {
chatId,
error: error instanceof Error ? error.message : String(error)
})
await this.bot.api.sendMessage(chatId, splitMessage(text, TELEGRAM_MAX_LENGTH)[i], replyParams)
await this.bot.api.sendMessage(chatId, plain, replyParams)
} else {
throw error
}
}
// Small delay between chunks to avoid rate limiting
if (i < chunks.length - 1) {
if (i < plainChunks.length - 1) {
await new Promise((resolve) => setTimeout(resolve, 100))
}
}

Some files were not shown because too many files have changed in this diff Show More