Merge branch 'main' into fix/image-generation-empty-panel

This commit is contained in:
Gu JiaMing
2026-07-02 19:34:20 +08:00
committed by GitHub
79 changed files with 2215 additions and 958 deletions

View File

@@ -305,6 +305,28 @@ describe('runAgentTask', () => {
})
})
// Regular tasks carry the workspace bound at creation time (system by
// default, since the picker defaults there) straight through to the session.
it('binds a non-heartbeat task to the workspace bound on the task', async () => {
vi.mocked(jobService.getById).mockReturnValueOnce(makeJobSnapshot())
vi.mocked(jobScheduleService.getById).mockReturnValueOnce(makeSchedule('daily-summary'))
vi.mocked(agentService.getAgent).mockReturnValueOnce(makeAgent())
vi.mocked(agentSessionService.create).mockReturnValueOnce(makeSession('/ws/a'))
const promise = runAgentTask(
makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0, workspace: { type: 'system' } } })
)
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
captured.listeners[0].onDone({ status: 'completed' })
await promise
expect(agentSessionService.create).toHaveBeenCalledWith({
agentId: 'a1',
name: 'daily-summary',
workspace: { type: 'system' }
})
})
// C1 (agents-jobs-3): a `text-delta` chunk's payload is on `.delta`, not `.text`.
// The previous `as { text }` cast silently accumulated nothing, so every run
// persisted the `'Completed'` fallback instead of the model's reply.

View File

@@ -138,6 +138,9 @@ export async function runAgentTask(ctx: JobContext<AgentTaskInput>): Promise<Age
// Always create a fresh session per fire. Scheduled tasks are discrete
// invocations; cross-fire session reuse would only carry stale model
// context. Persistent state lives in workspace files (heartbeat.md, etc.).
// The session inherits the workspace bound on the task at creation time —
// system for regular tasks (the picker defaults there), the validated user
// workspace for heartbeats.
const session = agentSessionService.create({
agentId,
name: taskName ?? 'Scheduled task',

View File

@@ -9,9 +9,14 @@
* mutation itself lives in the shared `knowledgeLookup` core so the Claude Code
* MCP bridge runs identical logic (gated there by Claude Code's own permission
* prompt); this file is just the AI-SDK `tool()` wrapper.
*
* `defer: 'never'` (kept inline, never behind `tool_search`/`tool_invoke`): the same rule
* `mcp/mcpTools.ts` applies to force-prompt MCP tools. Deferring an approval-gated tool would strip
* it from the SDK's tool-set, so the SDK's native `needsApproval` gate never fires and `tool_invoke`
* refuses it too (it never runs an approval-gated tool blind) — an unreachable tool either way.
*/
import { KB_MANAGE_TOOL_NAME, kbManageInputSchema, kbManageOutputSchema } from '@shared/ai/builtinTools'
import { KB_MANAGE_TOOL_NAME, kbManageOutputSchema, kbManageStrictInputSchema } from '@shared/ai/builtinTools'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import * as z from 'zod'
@@ -31,7 +36,7 @@ const knowledgeManageResultSchema = z.union([kbManageOutputSchema, knowledgeLook
const kbManageTool = tool({
description: KNOWLEDGE_MANAGE_DESCRIPTION,
inputSchema: kbManageInputSchema,
inputSchema: kbManageStrictInputSchema,
outputSchema: knowledgeManageResultSchema,
strict: true,
// Every action (add / delete / refresh) modifies the base; gate on explicit user approval.
@@ -48,7 +53,7 @@ export function createKbManageToolEntry(): ToolEntry {
name: KB_MANAGE_TOOL_NAME,
namespace: 'kb',
description: 'Add, delete, or re-index documents in a knowledge base (requires approval)',
defer: 'always',
defer: 'never',
tool: kbManageTool,
applies: (scope) => scope.hasAnyKnowledgeBase === true && (scope.assistant?.knowledgeBaseIds?.length ?? 0) > 0
}

View File

@@ -69,7 +69,10 @@ describe('kb_manage', () => {
it('builds an entry with the agreed namespace + defer policy and is approval-gated', () => {
expect(entry.name).toBe(KB_MANAGE_TOOL_NAME)
expect(entry.namespace).toBe('kb')
expect(entry.defer).toBe('always')
// Approval-gated tools must stay inline (never deferred) — see applyDeferExposition/toolInvoke:
// a deferred approval-gated tool is unreachable (stripped from the inline set, and tool_invoke
// refuses to run an approval-gated tool blind).
expect(entry.defer).toBe('never')
// Every action mutates the base, so the tool must require user approval.
expect(entry.tool.needsApproval).toBe(true)
})

View File

@@ -34,4 +34,20 @@ describe('registerBuiltinTools', () => {
expect(readFile?.applies?.({ mcpToolIds: new Set(), hasFileAttachments: false })).toBe(false)
expect(readFile?.applies?.({ mcpToolIds: new Set(), hasFileAttachments: true })).toBe(true)
})
it(
'never defers an approval-gated entry (would strip it from the inline set with no way back — ' +
'see mcp/mcpTools.ts and toolInvoke.ts for the same rule on MCP force-prompt tools)',
() => {
const reg = new ToolRegistry()
registerBuiltinTools(reg)
for (const entry of reg.getAll()) {
if (entry.tool.needsApproval) {
expect(entry.defer).toBe('never')
}
}
// Sanity: this loop is only meaningful while at least one builtin entry is approval-gated.
expect(reg.getAll().some((e) => e.tool.needsApproval)).toBe(true)
}
)
})

View File

@@ -25,7 +25,6 @@ import type {
KbGrepOutput,
KbListOutput,
KbListOutputItem,
KbManageInput,
KbManageOutput,
KbReadInput,
KbReadOutput,
@@ -415,6 +414,18 @@ export async function listOrOutlineKnowledge(
/** Longest a derived note title (its first line) may be before it is truncated. */
const NOTE_TITLE_MAX_CHARS = 80
/** kb_manage input shape shared by both callers: MCP omits an unused field, AI-SDK strict passes null. */
type ManageKnowledgeInput = {
baseId: string
action: 'add' | 'delete' | 'refresh'
type?: 'file' | 'url' | 'note' | null
path?: string | null
url?: string | null
content?: string | null
title?: string | null
conceptIds?: string[] | null
}
/**
* Apply a destructive knowledge-base change (add / delete / refresh). Like the
* read cores it never throws: an out-of-scope base, a missing required field, an
@@ -425,7 +436,7 @@ const NOTE_TITLE_MAX_CHARS = 80
* executes the mutation unconditionally once invoked.
*/
export async function manageKnowledge(
input: KbManageInput,
input: ManageKnowledgeInput,
allowedIds: string[]
): Promise<KnowledgeManageResultOrError> {
if (allowedIds.length > 0 && !allowedIds.includes(input.baseId)) {
@@ -493,7 +504,7 @@ type AddInputResult = { ok: true; input: KnowledgeAddItemInput; source: string }
* invalid value (e.g. a non-absolute file path) is rejected before it reaches the
* filesystem boundary. `source` is the identifier reported back as `added`.
*/
function buildAddInput(input: KbManageInput): AddInputResult {
function buildAddInput(input: ManageKnowledgeInput): AddInputResult {
switch (input.type) {
case 'file': {
if (!input.path) {
@@ -540,7 +551,7 @@ function firstNonEmptyLine(content: string): string | undefined {
}
/** A note's display source: the caller-supplied title, else its first non-empty line (truncated), else a placeholder. */
function deriveNoteSource(content: string, title?: string): string {
function deriveNoteSource(content: string, title?: string | null): string {
const explicit = title?.trim()
if (explicit) return explicit
// Truncation here differs by role from deriveSampleSource's note branch (a stored id, plain-clipped;

View File

@@ -982,10 +982,10 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
// fresh uuid dir, the runtime never opens a store mid-migration, and the catch below wipes a
// partial on a caught failure. A crash-orphaned dir is never referenced by a knowledge_base row
// so it is never mounted (it is dead disk, the same as the rename path produced).
const driver = await openBetterSqlite3IndexDriver(plan.targetDbPath)
const driver = openBetterSqlite3IndexDriver(plan.targetDbPath)
try {
await createKnowledgeIndexSchema(driver)
await ensureIndexMeta(driver, { baseId: plan.baseId })
createKnowledgeIndexSchema(driver)
ensureIndexMeta(driver, { baseId: plan.baseId })
const store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
for (const material of plan.materials) {
@@ -998,11 +998,11 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
// Fold the WAL back into the main db file so the committed pages are durable in index.sqlite
// itself (the WAL is not guaranteed to be checkpointed on close); the runtime then opens a
// self-contained store.
await driver.execute('PRAGMA wal_checkpoint(TRUNCATE)')
driver.execute('PRAGMA wal_checkpoint(TRUNCATE)')
} finally {
// Close so the file handle is released (a leaked handle would block a re-run's
// removeIndexStoreFiles and the later base-dir deletion on Windows).
await driver.close()
driver.close()
}
// Build + close succeeded: the complete store sits at its runtime path. A later failure
// (snapshot-pin) leaves it present and searchable, so it must NOT be wiped or marked failed.
@@ -1166,11 +1166,11 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
// busy_timeout=5000, so this re-read of the just-built store waits out a transient Windows lock
// (Defender / indexer scanning the freshly-written file) instead of throwing SQLITE_BUSY /
// EACCES and failing validation for an already-correct store.
const driver = await openBetterSqlite3IndexDriver(plan.targetDbPath)
const driver = openBetterSqlite3IndexDriver(plan.targetDbPath)
try {
const materialCount = await this.tableCount(driver, 'material')
const unitCount = await this.tableCount(driver, 'search_unit')
const embeddingCount = await this.tableCount(driver, 'embedding')
const materialCount = this.tableCount(driver, 'material')
const unitCount = this.tableCount(driver, 'search_unit')
const embeddingCount = this.tableCount(driver, 'embedding')
targetCount += unitCount
this.pushCountMismatch(errors, plan.baseId, 'material', plan.materials.length, materialCount)
@@ -1180,7 +1180,7 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
// Every unit's body search_text must resolve to a stored embedding, or that
// unit is silently absent from vector search. This is the migration-time
// form of the rebuild self-heal invariant (knowledge-technical-design.md §10).
const uncovered = await driver.execute(
const uncovered = driver.execute(
`SELECT count(*) AS count FROM search_text st
LEFT JOIN embedding e ON e.embedding_text_hash = st.embedding_text_hash
WHERE e.embedding_text_hash IS NULL`
@@ -1195,7 +1195,7 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
})
}
} finally {
await driver.close()
driver.close()
}
}
@@ -1234,8 +1234,8 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
}
}
private async tableCount(driver: BetterSqlite3Driver, table: string): Promise<number> {
const result = await driver.execute(`SELECT count(*) AS count FROM ${table}`)
private tableCount(driver: BetterSqlite3Driver, table: string): number {
const result = driver.execute(`SELECT count(*) AS count FROM ${table}`)
return Number(result.rows[0]?.count ?? 0)
}

View File

@@ -59,10 +59,11 @@ export function createIndexDocumentsJobHandler(
const { base, item } = input
// Mark reading before file/network IO so the UI reflects the current long-running phase.
// No base mutation lock: this only writes the main app DB (knowledge_item), not the
// per-base index.sqlite the lock protects, and updateStatus's own 'deleting' guard
// (KnowledgeItemService.updateStatus) already covers the race the lock would.
reportKnowledgeProgress(ctx, 0, { stage: 'reading', currentFile: 0, totalFiles: 1 })
await knowledgeLockManager.withBaseMutationLock(ctx.input.baseId, () => {
knowledgeItemService.updateStatus(ctx.input.itemId, 'reading')
})
knowledgeItemService.updateStatus(ctx.input.itemId, 'reading')
// Capture a url's or note's snapshot on first index (a url fetches outside
// the lock, a note writes its in-hand content; both persist a relativePath
@@ -83,10 +84,9 @@ export function createIndexDocumentsJobHandler(
}
// Mark embedding separately so the UI reflects the current long-running phase.
// No base mutation lock here either — same reasoning as the 'reading' status above.
reportKnowledgeProgress(ctx, 40, { stage: 'embedding', currentFile: 0, totalFiles: 1 })
await knowledgeLockManager.withBaseMutationLock(ctx.input.baseId, () =>
knowledgeItemService.updateStatus(ctx.input.itemId, 'embedding')
)
knowledgeItemService.updateStatus(ctx.input.itemId, 'embedding')
// Use readableItem, not item: for a freshly captured url it carries the snapshot
// relativePath, so the material's relative_path is the real `raw/` snapshot path

View File

@@ -16,12 +16,14 @@ export async function deleteKnowledgeItemVectors(base: KnowledgeBase, itemIds: s
return
}
// Delete every id in ONE batched transaction with a single collectIndexGarbage pass.
// Delete every id with a single collectIndexGarbage pass at the end (each material's
// row delete is its own short transaction; see KnowledgeIndexStore.deleteMaterials).
// The old per-id Promise.allSettled loop ran the two full-table GC scans once per item,
// so deleting a folder of N files scanned the whole embedding/content table N times —
// the multi-second main-process freeze on large (PDF-heavy) folders. deleteMaterials
// rolls the whole batch back on failure (throwing the root cause), so a retry
// re-discovers every affected id; no per-item failure aggregation is needed.
// the multi-second main-process freeze on large (PDF-heavy) folders. A failure partway
// leaves whatever was already deleted committed; that is safe because this call always
// precedes the knowledge_item DB row deletion (see subtreePurge.ts), so a retry
// re-discovers exactly the materials still left.
await store.deleteMaterials(uniqueItemIds)
}

View File

@@ -43,12 +43,12 @@ export function getKnowledgeBaseMetaDir(baseId: string): FilePath {
return path.join(getKnowledgeBaseDir(baseId), CHERRY_META_DIR) as FilePath
}
export async function getKnowledgeVectorStoreFilePath(baseId: string): Promise<FilePath> {
const metaDir = getKnowledgeBaseMetaDir(baseId)
await ensureDir(metaDir)
return getKnowledgeVectorStoreFilePathSync(baseId)
}
/**
* The base's index.sqlite path. No `ensureDir` here — `openBetterSqlite3IndexDriver`
* (BetterSqlite3Driver.ts) already `mkdirSync`s the parent `.cherry/` dir itself before
* opening, so a caller that only needs the path (not a guaranteed-existing directory)
* does not need to duplicate that work.
*/
export function getKnowledgeVectorStoreFilePathSync(baseId: string): FilePath {
const metaDir = getKnowledgeBaseMetaDir(baseId)
return path.join(metaDir, VECTOR_STORE_FILE) as FilePath

View File

@@ -8,11 +8,7 @@ import type { CompletedKnowledgeBase, KnowledgeBase } from '@shared/data/types/k
import { isCompletedKnowledgeBase } from '@shared/data/types/knowledge'
import { isIndexableKnowledgeItem } from '../utils/items'
import {
deleteKnowledgeBaseDir,
getKnowledgeVectorStoreFilePath,
getKnowledgeVectorStoreFilePathSync
} from '../utils/storage/pathStorage'
import { deleteKnowledgeBaseDir, getKnowledgeVectorStoreFilePathSync } from '../utils/storage/pathStorage'
import { openBetterSqlite3IndexDriver } from './indexStore/BetterSqlite3Driver'
import { betterSqlite3VectorIndex } from './indexStore/BetterSqlite3VectorIndex'
import { ensureIndexMeta, hasAnyMaterial, readIndexSchemaVersion } from './indexStore/indexMeta'
@@ -47,11 +43,11 @@ function assertVectorStoreReadyBase(base: KnowledgeBase): asserts base is Comple
@Injectable('KnowledgeVectorStoreService')
@ServicePhase(Phase.WhenReady)
export class KnowledgeVectorStoreService extends BaseService {
// Caches the in-flight open promise, not the resolved store, so concurrent
// getIndexStore calls for the same base share one open (one better-sqlite3
// connection) instead of racing — the loser of a "resolve then set" race would
// otherwise leak an orphaned store that no one ever closes.
private instanceCache = new Map<string, Promise<KnowledgeIndexStore>>()
// Opening a store (better-sqlite3 connect + schema + meta) is fully synchronous
// (see openIndexStore), so it runs to completion in one JS turn — no concurrent
// getIndexStore call for the same base can ever observe an in-flight open, and a
// failed open never gets cached (the throw happens before .set() below runs).
private instanceCache = new Map<string, KnowledgeIndexStore>()
/** Open (or reuse) the base's index store, ensuring its schema exists. */
async getIndexStore(base: KnowledgeBase): Promise<KnowledgeIndexStore> {
@@ -63,20 +59,10 @@ export class KnowledgeVectorStoreService extends BaseService {
return cached
}
const opening = this.openIndexStore(base)
this.instanceCache.set(base.id, opening)
try {
const store = await opening
logger.info('Opened knowledge index store', { baseId: base.id, cacheSize: this.instanceCache.size })
return store
} catch (error) {
// Evict the rejected promise so a later call retries the open instead of
// forever re-awaiting the same failure (only if it is still the cached one).
if (this.instanceCache.get(base.id) === opening) {
this.instanceCache.delete(base.id)
}
throw error
}
const store = this.openIndexStore(base)
this.instanceCache.set(base.id, store)
logger.info('Opened knowledge index store', { baseId: base.id, cacheSize: this.instanceCache.size })
return store
}
/** Reuse or open the store only if its file already exists on disk; used by cleanup paths that must not create one. */
@@ -105,12 +91,12 @@ export class KnowledgeVectorStoreService extends BaseService {
* and `index.sqlite` alike. Only safe when deleting the whole base.
*/
async deleteStore(baseId: string): Promise<void> {
const opening = this.instanceCache.get(baseId)
const store = this.instanceCache.get(baseId)
try {
await this.closeStoreInstance(opening)
await this.closeStoreInstance(store)
await deleteKnowledgeBaseDir(baseId)
logger.info('Deleted knowledge index store', { baseId, hadCachedStore: Boolean(opening) })
logger.info('Deleted knowledge index store', { baseId, hadCachedStore: Boolean(store) })
} finally {
this.instanceCache.delete(baseId)
}
@@ -121,9 +107,9 @@ export class KnowledgeVectorStoreService extends BaseService {
logger.info('Stopping knowledge index stores', { storeCount })
try {
for (const [baseId, opening] of this.instanceCache.entries()) {
for (const [baseId, store] of this.instanceCache.entries()) {
try {
await this.closeStoreInstance(opening)
await this.closeStoreInstance(store)
} catch (error) {
logger.error('Failed to close knowledge index store', error as Error, { baseId })
}
@@ -134,9 +120,9 @@ export class KnowledgeVectorStoreService extends BaseService {
}
}
private async openIndexStore(base: CompletedKnowledgeBase): Promise<KnowledgeIndexStore> {
const dbPath = await getKnowledgeVectorStoreFilePath(base.id)
const driver = await openBetterSqlite3IndexDriver(dbPath)
private openIndexStore(base: CompletedKnowledgeBase): KnowledgeIndexStore {
const dbPath = getKnowledgeVectorStoreFilePathSync(base.id)
const driver = openBetterSqlite3IndexDriver(dbPath)
try {
// An index.sqlite from an older schema layout cannot be migrated in place —
// `CREATE ... IF NOT EXISTS` never retrofits a new column/trigger onto an
@@ -147,28 +133,28 @@ export class KnowledgeVectorStoreService extends BaseService {
// (A stale-version file swapped in from another base is rebuilt here rather than
// refused by the base_id check below — but the reset drops its rows, so no other
// base's data is ever served; only the explicit refusal diagnostic is skipped.)
const storedVersion = await readIndexSchemaVersion(driver)
const storedVersion = readIndexSchemaVersion(driver)
if (storedVersion !== null && storedVersion !== KNOWLEDGE_INDEX_SCHEMA_VERSION) {
logger.warn('Knowledge index schema version mismatch — rebuilding the derived index', {
baseId: base.id,
storedVersion,
expectedVersion: KNOWLEDGE_INDEX_SCHEMA_VERSION
})
await resetKnowledgeIndexSchema(driver)
resetKnowledgeIndexSchema(driver)
} else {
await createKnowledgeIndexSchema(driver)
createKnowledgeIndexSchema(driver)
}
// Stamp + verify the meta identity row before handing out the store,
// so an index.sqlite swapped in from another base is rejected here (§4.1).
// That is the only refusal — a blank/recreated file is stamped as fresh and
// mounts empty; reportInvisibleIndexContents below makes that state loud.
await ensureIndexMeta(driver, { baseId: base.id })
await this.reportInvisibleIndexContents(driver, base.id)
ensureIndexMeta(driver, { baseId: base.id })
this.reportInvisibleIndexContents(driver, base.id)
return new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
} catch (error) {
// Close the driver opened above so a failed open never leaks the index file
// handle (which on Windows would later block deleting the base dir).
await driver.close()
driver.close()
throw error
}
}
@@ -186,8 +172,8 @@ export class KnowledgeVectorStoreService extends BaseService {
* NOT_FOUND is what turns that into a loud failure instead of a cached
* forever-empty store).
*/
private async reportInvisibleIndexContents(driver: SqliteDriver, baseId: string): Promise<void> {
if (await hasAnyMaterial(driver)) {
private reportInvisibleIndexContents(driver: SqliteDriver, baseId: string): void {
if (hasAnyMaterial(driver)) {
return
}
@@ -212,14 +198,7 @@ export class KnowledgeVectorStoreService extends BaseService {
}
}
private async closeStoreInstance(opening: Promise<KnowledgeIndexStore> | undefined): Promise<void> {
if (!opening) {
return
}
// A store that never opened needs no close (the open path already closed its
// driver on failure) — swallow the rejection here instead of re-throwing the
// open error into an unrelated delete/shutdown operation.
const store = await opening.catch(() => undefined)
private async closeStoreInstance(store: KnowledgeIndexStore | undefined): Promise<void> {
if (!store) {
return
}

View File

@@ -19,7 +19,6 @@ const {
hasAnyMaterialMock,
getItemsByBaseIdMock,
indexStoreCtorMock,
getPathMock,
getPathSyncMock,
deleteDirMock,
statMock
@@ -36,7 +35,6 @@ const {
hasAnyMaterialMock: vi.fn(),
getItemsByBaseIdMock: vi.fn(),
indexStoreCtorMock: vi.fn(),
getPathMock: vi.fn(),
getPathSyncMock: vi.fn(),
deleteDirMock: vi.fn(),
statMock: vi.fn()
@@ -101,7 +99,6 @@ vi.mock('@data/services/KnowledgeItemService', () => ({
}))
vi.mock('../../utils/storage/pathStorage', () => ({
getKnowledgeVectorStoreFilePath: getPathMock,
getKnowledgeVectorStoreFilePathSync: getPathSyncMock,
deleteKnowledgeBaseDir: deleteDirMock
}))
@@ -136,22 +133,24 @@ function lastStore() {
describe('KnowledgeVectorStoreService', () => {
beforeEach(() => {
vi.clearAllMocks()
getPathMock.mockImplementation(async (baseId: string) => `/tmp/${baseId}/index.sqlite`)
getPathSyncMock.mockImplementation((baseId: string) => `/tmp/${baseId}/index.sqlite`)
// Each open returns a fresh closeable driver so failure paths can assert close().
openDriverMock.mockImplementation(async () => ({
// The real driver port is synchronous (see BetterSqlite3Driver) — these mocks
// must return plain values, not promises, or the store under test would try to
// call .execute() on a Promise object.
openDriverMock.mockImplementation(() => ({
kind: 'driver',
close: vi.fn().mockResolvedValue(undefined)
close: vi.fn()
}))
createSchemaMock.mockResolvedValue(undefined)
resetSchemaMock.mockResolvedValue(undefined)
ensureIndexMetaMock.mockResolvedValue(undefined)
createSchemaMock.mockReturnValue(undefined)
resetSchemaMock.mockReturnValue(undefined)
ensureIndexMetaMock.mockReturnValue(undefined)
// Default: a fresh/blank file has no stored version → the open path takes the normal
// create branch (no rebuild). Mismatch tests override this per-case.
readIndexSchemaVersionMock.mockResolvedValue(null)
readIndexSchemaVersionMock.mockReturnValue(null)
// A non-empty material probe keeps the invisible-contents diagnostic quiet
// unless a test opts in.
hasAnyMaterialMock.mockResolvedValue(true)
hasAnyMaterialMock.mockReturnValue(true)
getItemsByBaseIdMock.mockReturnValue([])
deleteDirMock.mockResolvedValue(undefined)
indexStoreCtorMock.mockImplementation(() => ({ close: vi.fn().mockResolvedValue(undefined) }))
@@ -189,7 +188,9 @@ describe('KnowledgeVectorStoreService', () => {
it('evicts a failed open so a later call retries instead of re-awaiting the failure', async () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
openDriverMock.mockRejectedValueOnce(new Error('open failed'))
openDriverMock.mockImplementationOnce(() => {
throw new Error('open failed')
})
await expect(service.getIndexStore(base)).rejects.toThrow('open failed')
@@ -214,7 +215,7 @@ describe('KnowledgeVectorStoreService', () => {
it('creates the schema normally when the stored version matches (no rebuild)', async () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
readIndexSchemaVersionMock.mockResolvedValueOnce(MOCK_SCHEMA_VERSION)
readIndexSchemaVersionMock.mockReturnValueOnce(MOCK_SCHEMA_VERSION)
await service.getIndexStore(base)
@@ -225,7 +226,7 @@ describe('KnowledgeVectorStoreService', () => {
it('rebuilds the derived index when an existing index.sqlite is at a stale schema version', async () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
readIndexSchemaVersionMock.mockResolvedValueOnce(MOCK_SCHEMA_VERSION - 1)
readIndexSchemaVersionMock.mockReturnValueOnce(MOCK_SCHEMA_VERSION - 1)
await service.getIndexStore(base)
@@ -252,7 +253,7 @@ describe('KnowledgeVectorStoreService', () => {
// future refactor to `<` does not silently start mounting newer files.
const service = new KnowledgeVectorStoreService()
const base = createBase()
readIndexSchemaVersionMock.mockResolvedValueOnce(MOCK_SCHEMA_VERSION + 1)
readIndexSchemaVersionMock.mockReturnValueOnce(MOCK_SCHEMA_VERSION + 1)
await service.getIndexStore(base)
@@ -264,11 +265,13 @@ describe('KnowledgeVectorStoreService', () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
let openedDriver: { close: ReturnType<typeof vi.fn> } | undefined
openDriverMock.mockImplementationOnce(async () => {
openDriverMock.mockImplementationOnce(() => {
openedDriver = { kind: 'driver', close: vi.fn().mockResolvedValue(undefined) } as never
return openedDriver
})
ensureIndexMetaMock.mockRejectedValueOnce(new Error('belongs to a different base'))
ensureIndexMetaMock.mockImplementationOnce(() => {
throw new Error('belongs to a different base')
})
await expect(service.getIndexStore(base)).rejects.toThrow('belongs to a different base')
@@ -280,11 +283,13 @@ describe('KnowledgeVectorStoreService', () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
let openedDriver: { close: ReturnType<typeof vi.fn> } | undefined
openDriverMock.mockImplementationOnce(async () => {
openDriverMock.mockImplementationOnce(() => {
openedDriver = { kind: 'driver', close: vi.fn().mockResolvedValue(undefined) } as never
return openedDriver
})
createSchemaMock.mockRejectedValueOnce(new Error('disk full'))
createSchemaMock.mockImplementationOnce(() => {
throw new Error('disk full')
})
await expect(service.getIndexStore(base)).rejects.toThrow('disk full')
@@ -345,23 +350,19 @@ describe('KnowledgeVectorStoreService', () => {
expect(indexStoreCtorMock).toHaveBeenCalledTimes(2)
})
it('deleteStore proceeds past a rejected in-flight open instead of re-throwing it', async () => {
it('deleteStore removes the directory even when no store was ever opened for the base', async () => {
// Opening a store (see openIndexStore) is fully synchronous — it either completes and
// caches a store, or throws before caching anything. There is no longer an in-flight
// state deleteStore could observe mid-open, so this covers the remaining "nothing
// cached" case: deleteStore must still close-if-present (a no-op here) and remove
// the directory.
const service = new KnowledgeVectorStoreService()
const base = createBase()
let rejectOpen: (error: Error) => void = () => {}
openDriverMock.mockImplementationOnce(() => new Promise((_, reject) => (rejectOpen = reject)))
// deleteStore grabs the still-pending open; when that open later fails, the
// delete must not inherit the open error — a store that never opened needs
// no close, and the directory removal has to go ahead.
const opening = service.getIndexStore(base)
const deleting = service.deleteStore(base.id)
await vi.waitFor(() => expect(openDriverMock).toHaveBeenCalled())
rejectOpen(new Error('open failed'))
await service.deleteStore(base.id)
await expect(opening).rejects.toThrow('open failed')
await expect(deleting).resolves.toBeUndefined()
expect(deleteDirMock).toHaveBeenCalledWith(base.id)
expect(indexStoreCtorMock).not.toHaveBeenCalled()
})
it('evicts the cached store even when directory removal fails', async () => {
@@ -452,7 +453,7 @@ describe('KnowledgeVectorStoreService', () => {
it('logs an error when an empty index mounts under a base with completed items', async () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
hasAnyMaterialMock.mockResolvedValueOnce(false)
hasAnyMaterialMock.mockReturnValueOnce(false)
getItemsByBaseIdMock.mockReturnValueOnce([
{ id: 'item-1', type: 'directory', status: 'completed' },
{ id: 'item-2', type: 'file', status: 'completed' }
@@ -470,7 +471,7 @@ describe('KnowledgeVectorStoreService', () => {
it('stays quiet when an empty index mounts under a base with no completed indexable items', async () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
hasAnyMaterialMock.mockResolvedValueOnce(false)
hasAnyMaterialMock.mockReturnValueOnce(false)
// A completed empty directory is legitimate without materials; in-flight leaves are too.
getItemsByBaseIdMock.mockReturnValueOnce([
{ id: 'item-1', type: 'directory', status: 'completed' },
@@ -486,11 +487,11 @@ describe('KnowledgeVectorStoreService', () => {
const service = new KnowledgeVectorStoreService()
const base = createBase()
let openedDriver: { close: ReturnType<typeof vi.fn> } | undefined
openDriverMock.mockImplementationOnce(async () => {
openDriverMock.mockImplementationOnce(() => {
openedDriver = { kind: 'driver', close: vi.fn().mockResolvedValue(undefined) } as never
return openedDriver
})
hasAnyMaterialMock.mockResolvedValueOnce(false)
hasAnyMaterialMock.mockReturnValueOnce(false)
getItemsByBaseIdMock.mockImplementationOnce(() => {
throw new Error('app database unavailable')
})

View File

@@ -1,15 +1,12 @@
import { mkdirSync } from 'node:fs'
import { dirname } from 'node:path'
import { loggerService } from '@logger'
import { toAsarUnpackedPath } from '@main/utils/asar'
import Database from 'better-sqlite3'
import { getLoadablePath } from 'sqlite-vec'
import type { SqliteDriver, SqliteReclaimOutcome, SqliteTransaction, SqlQueryResult, SqlValue } from './types'
const logger = loggerService.withContext('BetterSqlite3Driver')
/**
* VACUUM in {@link BetterSqlite3Driver.reclaim} only when the freelist is BOTH a large
* fraction of the file AND past an absolute floor. The fraction skips rewriting a
@@ -44,11 +41,15 @@ function toBindable(value: SqlValue): Bindable {
* No internal write mutex: every writer (rebuildMaterial / deleteMaterials /
* reclaimSpace) runs under `KnowledgeLockManager.withBaseMutationLock(baseId)`, a
* per-base async-mutex that already serializes all writes to a base's index — so the
* driver never self-serializes. better-sqlite3's single connection makes a manual
* BEGIN IMMEDIATE transaction inherently atomic; reads run on the same connection and
* ride WAL / busy_timeout. (Holding BEGIN across a transaction callback's own
* event-loop yield is safe here precisely because the base lock blocks any other
* writer from issuing a nested BEGIN.)
* driver never self-serializes. better-sqlite3's single connection makes
* `transaction` inherently atomic; reads run on the same connection and ride WAL /
* busy_timeout. `transaction` uses better-sqlite3's native `db.transaction(fn).immediate`
* (BEGIN IMMEDIATE, matching the old libsql 'write' tx) rather than hand-rolled
* BEGIN/COMMIT/ROLLBACK: it runs `fn` synchronously to completion in one JS turn, so
* BEGIN and COMMIT can never straddle an event-loop yield — no unrelated read on this
* connection can ever observe a transaction mid-flight. It also throws `TypeError` if
* `fn` returns a Promise, turning an accidental async callback into a loud failure
* instead of a silent early commit.
*
* NOTE (future KB-owner cleanup): the driver trusts callers to hold the base lock and does not
* self-check it. A cheap `transactionActive` reentrancy assertion in {@link transaction} could
@@ -60,44 +61,30 @@ export class BetterSqlite3Driver implements SqliteDriver {
constructor(private readonly db: Database.Database) {}
async execute(sql: string, args: SqlValue[] = []): Promise<SqlQueryResult> {
execute(sql: string, args: SqlValue[] = []): SqlQueryResult {
this.assertOpen()
const stmt = this.db.prepare(sql)
const bound = args.map(toBindable)
// `reader` is true for statements that yield rows (SELECT, row-returning PRAGMA).
// run() throws on those and all() throws on non-row statements, so split on it.
if (stmt.reader) {
return { rows: stmt.all(...bound) as Array<Record<string, SqlValue>> }
return { rows: stmt.all(...bound) as Array<Record<string, SqlValue>>, changes: 0 }
}
stmt.run(...bound)
return { rows: [] }
const result = stmt.run(...bound)
return { rows: [], changes: result.changes }
}
async transaction<T>(fn: (tx: SqliteTransaction) => Promise<T>): Promise<T> {
transaction<T>(fn: (tx: SqliteTransaction) => T): T {
this.assertOpen()
const handle: SqliteTransaction = {
execute: (sql, args) => this.execute(sql, args)
}
// BEGIN IMMEDIATE acquires the write lock up front (so a read-then-write transaction
// never fails mid-way after the read already ran), matching the old libsql 'write' tx.
this.db.exec('BEGIN IMMEDIATE')
try {
const result = await fn(handle)
this.db.exec('COMMIT')
return result
} catch (error) {
// Roll back, but never let a rollback failure mask the original error that
// triggered it — that original is what callers need to diagnose the write.
try {
this.db.exec('ROLLBACK')
} catch (rollbackError) {
logger.warn('Failed to roll back knowledge index store transaction after an error', rollbackError as Error)
}
throw error
}
// .immediate acquires the write lock up front (BEGIN IMMEDIATE), so a
// read-then-write transaction never fails mid-way after the read already ran.
return this.db.transaction((tx: SqliteTransaction) => fn(tx)).immediate(handle)
}
async reclaim(preVacuumStatements: readonly string[] = []): Promise<SqliteReclaimOutcome> {
reclaim(preVacuumStatements: readonly string[] = []): SqliteReclaimOutcome {
this.assertOpen()
// Checkpoint first: frees the WAL (cheap) and folds committed frees from the
// delete into the main file so freelist_count reflects them. The base lock
@@ -144,7 +131,7 @@ export class BetterSqlite3Driver implements SqliteDriver {
}
/** Idempotent: a second close() (e.g. shutdown after an explicit deleteStore) is a no-op. */
async close(): Promise<void> {
close(): void {
if (this.closed) {
return
}
@@ -169,7 +156,7 @@ export class BetterSqlite3Driver implements SqliteDriver {
* DbService PRAGMA setup. better-sqlite3 keeps one connection, so each PRAGMA is set
* once and holds for the connection's lifetime — no replay machinery is needed.
*/
export async function openBetterSqlite3IndexDriver(filePath: string): Promise<BetterSqlite3Driver> {
export function openBetterSqlite3IndexDriver(filePath: string): BetterSqlite3Driver {
// better-sqlite3 creates the database FILE but not its parent directory (unlike libsql's
// file: URL client), so ensure the base's index dir exists before opening.
mkdirSync(dirname(filePath), { recursive: true })

View File

@@ -17,11 +17,11 @@ const RRF_K = 60
const EMBEDDING_HASH_QUERY_BATCH = 500
/**
* How long {@link KnowledgeIndexStore.deleteMaterials} may run its per-material
* row deletes before handing the main-process event loop back to the OS message
* pump (see the method doc for why). Tuned well under the multi-second window
* that surfaces the macOS beachball, while large enough that the yields add no
* measurable overhead to a small delete.
* How long {@link KnowledgeIndexStore.deleteMaterials} may run consecutive
* per-material transactions before handing the main-process event loop back to
* the OS message pump (see the method doc for why). Tuned well under the
* multi-second window that surfaces the macOS beachball, while large enough
* that the yields add no measurable overhead to a small delete.
*/
const DELETE_YIELD_BUDGET_MS = 50
@@ -70,16 +70,23 @@ export class KnowledgeIndexStore {
}
})
await this.driver.transaction(async (tx) => {
this.driver.transaction((tx) => {
// 0. Capture the material's prior content hash (undefined if it doesn't exist
// yet) so step 8 can tell whether this rebuild could possibly have orphaned
// anything, without an extra full-table scan.
const priorRow = tx.execute(`SELECT current_content_hash FROM material WHERE material_id = ?`, [materialId])
.rows[0]
const priorContentHash = priorRow === undefined ? undefined : (priorRow.current_content_hash as string | null)
// 1. Content is immutable by hash — keep the existing row if present.
await tx.execute(`INSERT OR IGNORE INTO content (content_hash, text, created_at) VALUES (?, ?, ?)`, [
tx.execute(`INSERT OR IGNORE INTO content (content_hash, text, created_at) VALUES (?, ?, ?)`, [
contentHash,
input.content.text,
now
])
// 2. Upsert the material (current_content_hash set in step 7).
await tx.execute(
tx.execute(
`INSERT INTO material (material_id, relative_path, created_at, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(material_id) DO UPDATE SET
@@ -92,12 +99,12 @@ export class KnowledgeIndexStore {
// FK to search_unit (its target_id is polymorphic), so it is deleted
// explicitly while search_unit still exists to resolve the targets; the
// FTS index is kept in sync by the search_text delete trigger.
await this.deleteMaterialSearchText(tx, materialId)
await tx.execute(`DELETE FROM search_unit WHERE material_id = ?`, [materialId])
this.deleteMaterialSearchText(tx, materialId)
const deletedUnits = tx.execute(`DELETE FROM search_unit WHERE material_id = ?`, [materialId]).changes
// 4 & 5. Insert new units and their body search_text (FTS synced by trigger).
for (const unit of units) {
await tx.execute(
tx.execute(
`INSERT INTO search_unit
(unit_id, material_id, content_hash, unit_type, unit_index, title, char_start, char_end, locator_json, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
@@ -116,7 +123,7 @@ export class KnowledgeIndexStore {
now
]
)
await tx.execute(
tx.execute(
`INSERT INTO search_text (search_text_id, target_type, target_id, kind, text, embedding_text_hash, created_at)
VALUES (?, 'search_unit', ?, 'body', ?, ?, ?)`,
[
@@ -131,10 +138,11 @@ export class KnowledgeIndexStore {
// 6. Insert missing embeddings; existing hashes are reused (decision A4).
for (const embedding of input.embeddings) {
await tx.execute(
`INSERT OR IGNORE INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`,
[embedding.embeddingTextHash, encodeVectorBlob(embedding.vector), now]
)
tx.execute(`INSERT OR IGNORE INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`, [
embedding.embeddingTextHash,
encodeVectorBlob(embedding.vector),
now
])
}
// 6b. Coverage check: every unit's re-derived embedding hash must resolve to a
@@ -147,70 +155,89 @@ export class KnowledgeIndexStore {
// drop a hash it reported present before this rebuild writes, and the job
// then skips re-embedding it. Failing loud rolls back; the job's retry
// re-reads (the hash is now absent), re-embeds it, and converges.
await this.assertEmbeddingCoverage(tx, materialId, [...new Set(units.map((unit) => unit.embeddingTextHash))])
this.assertEmbeddingCoverage(tx, materialId, [...new Set(units.map((unit) => unit.embeddingTextHash))])
// 7. Mark the material's current content (failure/lifecycle state is the
// authority of knowledge_item, not this derived index).
await tx.execute(`UPDATE material SET current_content_hash = ?, updated_at = ? WHERE material_id = ?`, [
tx.execute(`UPDATE material SET current_content_hash = ?, updated_at = ? WHERE material_id = ?`, [
contentHash,
now,
materialId
])
// 8. Sweep rows this rebuild orphaned (old units' embeddings, old content the
// new revision no longer references). Safe under the base mutation lock.
await this.collectIndexGarbage(tx)
// new revision no longer references) — but only when it actually could have
// orphaned something. A first-time create (no prior row) or a rebuild that
// replaced zero old units AND kept the same content hash touches nothing an
// earlier revision referenced, so the GC's full-table anti-join scans would
// find nothing; skipping them turns a bulk index of K materials from
// O(K × table) into O(K). Checking unit deletions alone is not sound — a
// material that previously had zero units but a different content hash would
// slip through and leave that old content row an orphan — so both conditions
// are required.
const contentChanged = priorContentHash !== undefined && priorContentHash !== contentHash
if (deletedUnits > 0 || contentChanged) {
this.collectIndexGarbage(tx)
}
})
}
/**
* Delete many materials in ONE transaction, sweeping orphaned `embedding` /
* `content` rows with a SINGLE {@link collectIndexGarbage} pass at the end.
* Delete many materials — each in its OWN short transaction — then sweep
* orphaned `embedding` / `content` rows with a SINGLE {@link collectIndexGarbage}
* pass in a final transaction.
*
* Removing each material row cascades to its `search_unit`; the units' body
* `search_text` is deleted explicitly first (no FK), which also clears the FTS
* index via the delete trigger.
*
* collectIndexGarbage runs two FULL-TABLE anti-join scans, so calling it once
* per material (the old per-material delete loop) made a bulk delete
* per material (an old per-material delete+GC loop) made a bulk delete
* O(materials × table): deleting a folder of N files scanned the whole
* `embedding`/`content` table N times. With a large index (e.g. a folder of
* PDFs chunked into tens of thousands of rows) that blocked the main-process
* event loop for seconds — the folder-delete UI freeze. Deleting the rows up
* front and GCing once makes it O(N + table), and the single transaction makes
* the bulk delete atomic (a failure rolls the whole batch back so a retry
* re-discovers every affected id).
* front and GCing once makes it O(N + table).
*
* Batching the GC removes the super-linear cost, but the per-material row
* deletes are still linear in chunks: each `search_text` delete fires the FTS
* delete trigger, which the driver runs synchronously on the main process.
* Tens of thousands of rows still sum to a multi-second block, and
* because Electron drives the window from this same loop that block IS the
* macOS beachball (the renderer thread never stalls). So the loop yields to the
* OS message pump whenever it has run for {@link DELETE_YIELD_BUDGET_MS}: the
* total work is unchanged, but no single uninterrupted block is long enough to
* freeze the window. Yielding mid-transaction is safe — the caller holds the
* base mutation lock, so no other writer is waiting on this base's index.
* Tens of thousands of rows still sum to a multi-second block, and because
* Electron drives the window from this same loop that block IS the macOS
* beachball (the renderer thread never stalls). A driver transaction must run
* fully synchronously (no event-loop yield inside `BEGIN`..`COMMIT` — see
* {@link SqliteDriver.transaction}), so each material gets its own transaction
* and the loop yields to the OS message pump BETWEEN them whenever it has run
* for {@link DELETE_YIELD_BUDGET_MS}: the total work is unchanged, but no single
* uninterrupted block is long enough to freeze the window.
*
* This is no longer one all-or-nothing batch — a failure partway leaves the
* materials deleted so far committed. That is safe: every caller (subtreePurge.ts)
* deletes vectors before the corresponding `knowledge_item` DB rows, so those rows
* still exist after a partial failure and a retry re-discovers exactly the
* materials still left (re-deleting an already-gone one is a harmless no-op).
*/
async deleteMaterials(materialIds: string[]): Promise<void> {
const uniqueMaterialIds = [...new Set(materialIds)]
if (uniqueMaterialIds.length === 0) {
return
}
await this.driver.transaction(async (tx) => {
// performance.now() is monotonic — a wall-clock step (NTP/manual) mid-batch
// must not make the delta negative and silently disable the yields for the
// rest of a large delete, reintroducing the freeze this loop prevents.
let lastYieldAt = performance.now()
for (const materialId of uniqueMaterialIds) {
await this.deleteMaterialSearchText(tx, materialId)
await tx.execute(`DELETE FROM material WHERE material_id = ?`, [materialId])
if (performance.now() - lastYieldAt >= DELETE_YIELD_BUDGET_MS) {
await new Promise<void>((resolve) => setImmediate(resolve))
lastYieldAt = performance.now()
}
// performance.now() is monotonic — a wall-clock step (NTP/manual) mid-batch
// must not make the delta negative and silently disable the yields for the
// rest of a large delete, reintroducing the freeze this loop prevents.
let lastYieldAt = performance.now()
for (const materialId of uniqueMaterialIds) {
this.driver.transaction((tx) => {
this.deleteMaterialSearchText(tx, materialId)
tx.execute(`DELETE FROM material WHERE material_id = ?`, [materialId])
})
if (performance.now() - lastYieldAt >= DELETE_YIELD_BUDGET_MS) {
await new Promise<void>((resolve) => setImmediate(resolve))
lastYieldAt = performance.now()
}
await this.collectIndexGarbage(tx)
}
this.driver.transaction((tx) => {
this.collectIndexGarbage(tx)
})
}
@@ -224,12 +251,12 @@ export class KnowledgeIndexStore {
* `search_unit.content_hash` (FK CASCADE) reference it — both referrers are
* excluded, so the delete never violates either constraint.
*/
private async collectIndexGarbage(tx: SqliteTransaction): Promise<void> {
await tx.execute(
private collectIndexGarbage(tx: SqliteTransaction): void {
tx.execute(
`DELETE FROM embedding
WHERE NOT EXISTS (SELECT 1 FROM search_text st WHERE st.embedding_text_hash = embedding.embedding_text_hash)`
)
await tx.execute(
tx.execute(
`DELETE FROM content
WHERE NOT EXISTS (SELECT 1 FROM material m WHERE m.current_content_hash = content.content_hash)
AND NOT EXISTS (SELECT 1 FROM search_unit su WHERE su.content_hash = content.content_hash)`
@@ -256,7 +283,7 @@ export class KnowledgeIndexStore {
for (let i = 0; i < hashes.length; i += EMBEDDING_HASH_QUERY_BATCH) {
const batch = hashes.slice(i, i + EMBEDDING_HASH_QUERY_BATCH)
const placeholders = batch.map(() => '?').join(', ')
const result = await this.driver.execute(
const result = this.driver.execute(
`SELECT embedding_text_hash FROM embedding WHERE embedding_text_hash IN (${placeholders})`,
batch
)
@@ -269,7 +296,7 @@ export class KnowledgeIndexStore {
/** Read back a material's units (with body text), ordered by unit index. */
async listMaterialUnits(materialId: string): Promise<KnowledgeSearchUnit[]> {
const result = await this.driver.execute(
const result = this.driver.execute(
`SELECT su.unit_id, su.material_id, su.unit_type, su.unit_index, su.title, su.char_start, su.char_end, st.text AS body
FROM search_unit su
LEFT JOIN search_text st
@@ -307,10 +334,9 @@ export class KnowledgeIndexStore {
* the resolved material against the visible knowledge_item before reading.
*/
async getMaterialByRelativePath(relativePath: string): Promise<KnowledgeMaterialRef | null> {
const result = await this.driver.execute(
`SELECT material_id, relative_path FROM material WHERE relative_path = ?`,
[relativePath]
)
const result = this.driver.execute(`SELECT material_id, relative_path FROM material WHERE relative_path = ?`, [
relativePath
])
const row = result.rows[0]
if (!row) {
return null
@@ -329,7 +355,7 @@ export class KnowledgeIndexStore {
* holds (the same invariant rebuildMaterial enforces at write time).
*/
async readMaterialContent(materialId: string): Promise<string | null> {
const result = await this.driver.execute(
const result = this.driver.execute(
`SELECT c.text AS text
FROM material m
JOIN content c ON c.content_hash = m.current_content_hash
@@ -360,10 +386,10 @@ export class KnowledgeIndexStore {
const alpha = input.alpha ?? 0.5
const prefetch = input.topK * 5
const [vector, bm25] = await Promise.all([
this.vectorSearch(this.requireQueryEmbedding(input), prefetch),
this.bm25Search(input.queryText, prefetch)
])
// Both lanes are synchronous SQL over the same connection — sequential, not
// Promise.all: the driver has no true concurrency to parallelize.
const vector = this.vectorSearch(this.requireQueryEmbedding(input), prefetch)
const bm25 = this.bm25Search(input.queryText, prefetch)
return fuseWithRrf(vector, bm25, alpha, input.topK)
}
@@ -385,7 +411,7 @@ export class KnowledgeIndexStore {
}
async close(): Promise<void> {
await this.driver.close()
this.driver.close()
}
/** Whether the backing driver has been closed (see {@link SqliteDriver.isClosed}). */
@@ -401,14 +427,14 @@ export class KnowledgeIndexStore {
}
/** Brute-force cosine scan over the plain-BLOB embedding column (no ANN index). */
private async vectorSearch(queryEmbedding: number[], topK: number): Promise<KnowledgeIndexSearchMatch[]> {
private vectorSearch(queryEmbedding: number[], topK: number): KnowledgeIndexSearchMatch[] {
// Invariant, not a check: a base's embedding model and dimensions are immutable
// (changing them means migrating to a new base), so `queryEmbedding` and every
// stored `vector_blob` share one dimension — cosine never compares mismatched lengths.
// `WHERE dist IS NOT NULL` drops degenerate (zero-norm) vectors: cosine distance is
// undefined for them, and SQLite coerces that NULL/NaN to NULL — which would otherwise
// sort first under `ORDER BY dist` and score `1 - Number(null) = 1`, outranking real hits.
const result = await this.driver.execute(
const result = this.driver.execute(
`SELECT su.unit_id, su.material_id, su.unit_index, st.text AS body,
${this.vectorIndex.buildDistanceExpression('e.vector_blob')} AS dist
FROM embedding e
@@ -423,7 +449,7 @@ export class KnowledgeIndexStore {
return result.rows.map((row) => toMatch(row, 1 - Number(row.dist)))
}
private async bm25Search(queryText: string, topK: number): Promise<KnowledgeIndexSearchMatch[]> {
private bm25Search(queryText: string, topK: number): KnowledgeIndexSearchMatch[] {
// Short tokens (notably 12 char CJK words) produce no trigram, so MATCH would
// silently return nothing — route those queries to the LIKE fallback instead.
if (needsLikeFallback(queryText)) {
@@ -433,7 +459,7 @@ export class KnowledgeIndexStore {
if (!matchQuery) {
return []
}
const result = await this.driver.execute(
const result = this.driver.execute(
`SELECT su.unit_id, su.material_id, su.unit_index, st.text AS body, bm25(search_text_fts) AS score
FROM search_text_fts
JOIN search_text st
@@ -455,13 +481,13 @@ export class KnowledgeIndexStore {
* unit fully about the term) ranks first — and expose it as a higher-is-better
* score so it fuses sanely with the vector lane in hybrid mode.
*/
private async bm25LikeSearch(tokens: string[], topK: number): Promise<KnowledgeIndexSearchMatch[]> {
private bm25LikeSearch(tokens: string[], topK: number): KnowledgeIndexSearchMatch[] {
if (tokens.length === 0) {
return []
}
const likeClauses = tokens.map(() => `st.text LIKE ? ESCAPE '\\'`).join(' AND ')
const args: SqlValue[] = [...tokens.map(toFtsLikePattern), topK]
const result = await this.driver.execute(
const result = this.driver.execute(
`SELECT su.unit_id, su.material_id, su.unit_index, st.text AS body, length(st.text) AS len
FROM search_text st
JOIN search_unit su ON su.unit_id = st.target_id
@@ -475,12 +501,12 @@ export class KnowledgeIndexStore {
}
/** Throw (rolling back the surrounding rebuild) if any unit hash has no embedding row. */
private async assertEmbeddingCoverage(tx: SqliteTransaction, materialId: string, hashes: string[]): Promise<void> {
private assertEmbeddingCoverage(tx: SqliteTransaction, materialId: string, hashes: string[]): void {
const missing = new Set(hashes)
for (let i = 0; i < hashes.length; i += EMBEDDING_HASH_QUERY_BATCH) {
const batch = hashes.slice(i, i + EMBEDDING_HASH_QUERY_BATCH)
const placeholders = batch.map(() => '?').join(', ')
const result = await tx.execute(
const result = tx.execute(
`SELECT embedding_text_hash FROM embedding WHERE embedding_text_hash IN (${placeholders})`,
batch
)
@@ -495,8 +521,8 @@ export class KnowledgeIndexStore {
}
}
private async deleteMaterialSearchText(tx: SqliteTransaction, materialId: string): Promise<void> {
await tx.execute(
private deleteMaterialSearchText(tx: SqliteTransaction, materialId: string): void {
tx.execute(
`DELETE FROM search_text
WHERE target_type = 'search_unit'
AND target_id IN (SELECT unit_id FROM search_unit WHERE material_id = ?)`,

View File

@@ -2,127 +2,114 @@ import { mkdtempSync, rmSync } from 'node:fs'
import { tmpdir } from 'node:os'
import { join } from 'node:path'
import type Database from 'better-sqlite3'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
import { BetterSqlite3Driver, openBetterSqlite3IndexDriver } from '../BetterSqlite3Driver'
const loggerWarnMock = vi.hoisted(() => vi.fn())
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({ warn: loggerWarnMock })
}
}))
import type { BetterSqlite3Driver } from '../BetterSqlite3Driver'
import { openBetterSqlite3IndexDriver } from '../BetterSqlite3Driver'
describe('BetterSqlite3Driver', () => {
let tempDir: string
let driver: BetterSqlite3Driver
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-driver-'))
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await driver.execute('CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)')
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
driver.execute('CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)')
})
afterEach(async () => {
await driver.close()
afterEach(() => {
driver.close()
rmSync(tempDir, { recursive: true, force: true })
})
it('enables foreign keys on open', async () => {
const result = await driver.execute('PRAGMA foreign_keys')
it('enables foreign keys on open', () => {
const result = driver.execute('PRAGMA foreign_keys')
expect(result.rows[0].foreign_keys).toBe(1)
})
it('opens in WAL journal mode with a busy timeout so reads survive a concurrent write', async () => {
const journal = await driver.execute('PRAGMA journal_mode')
it('opens in WAL journal mode with a busy timeout so reads survive a concurrent write', () => {
const journal = driver.execute('PRAGMA journal_mode')
expect(String(journal.rows[0].journal_mode).toLowerCase()).toBe('wal')
const timeout = await driver.execute('PRAGMA busy_timeout')
const timeout = driver.execute('PRAGMA busy_timeout')
expect(Number(timeout.rows[0].timeout)).toBeGreaterThan(0)
})
it('maps rows to plain objects', async () => {
await driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
it('maps rows to plain objects', () => {
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
const select = await driver.execute('SELECT id, v FROM t WHERE id = ?', [1])
const select = driver.execute('SELECT id, v FROM t WHERE id = ?', [1])
expect(select.rows).toEqual([{ id: 1, v: 'a' }])
})
it('commits a successful transaction', async () => {
await driver.transaction(async (tx) => {
await tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
await tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'y'])
it('reports rows changed by a write statement', () => {
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'b'])
const result = driver.execute('DELETE FROM t WHERE id = ?', [1])
expect(result.changes).toBe(1)
const select = driver.execute('SELECT id FROM t')
expect(select.changes).toBe(0)
})
it('commits a successful transaction', () => {
driver.transaction((tx) => {
tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'y'])
})
const count = await driver.execute('SELECT COUNT(*) AS n FROM t')
const count = driver.execute('SELECT COUNT(*) AS n FROM t')
expect(count.rows[0].n).toBe(2)
})
it('rolls back a failed transaction', async () => {
await expect(
driver.transaction(async (tx) => {
await tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
it('rolls back a failed transaction', () => {
expect(() =>
driver.transaction((tx) => {
tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
throw new Error('boom')
})
).rejects.toThrow('boom')
).toThrow('boom')
const count = await driver.execute('SELECT COUNT(*) AS n FROM t')
const count = driver.execute('SELECT COUNT(*) AS n FROM t')
expect(count.rows[0].n).toBe(0)
})
it('rethrows the original error when rollback also fails, instead of masking it', async () => {
const originalError = new Error('insert failed')
const rollbackError = new Error('rollback failed')
// A fake better-sqlite3 connection: the bracket's BEGIN IMMEDIATE succeeds, the
// body's statement fails with originalError, then the ROLLBACK that the catch
// issues fails with rollbackError. The driver must surface originalError (what the
// caller needs to diagnose the write) and only log the rollback failure.
const fakeDb = {
prepare: () => {
throw originalError
},
exec: (sql: string) => {
if (sql === 'ROLLBACK') {
throw rollbackError
}
},
pragma: () => undefined,
close: () => undefined
} as unknown as Database.Database
const isolatedDriver = new BetterSqlite3Driver(fakeDb)
it('throws if the transaction callback returns a promise, instead of silently committing early', () => {
// better-sqlite3's native transaction() rejects an async callback outright (see
// BetterSqlite3Driver.transaction doc) — an accidental async fn must fail loud
// rather than commit before its awaited work actually ran.
expect(() =>
driver.transaction((tx) => {
return Promise.resolve(tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x']))
})
).toThrow(/promise/i)
await expect(isolatedDriver.transaction(async (tx) => tx.execute('INSERT INTO t (id) VALUES (1)'))).rejects.toBe(
originalError
)
expect(loggerWarnMock).toHaveBeenCalledWith(
'Failed to roll back knowledge index store transaction after an error',
rollbackError
)
const count = driver.execute('SELECT COUNT(*) AS n FROM t')
expect(count.rows[0].n).toBe(0)
})
it('checkpoints but skips the VACUUM when the freed space is below the reclaim threshold', async () => {
it('checkpoints but skips the VACUUM when the freed space is below the reclaim threshold', () => {
// A small delete leaves a freelist far below the size/ratio thresholds, so reclaim
// only truncates the WAL and reports that no whole-file rewrite ran.
await driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
await driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'b'])
await driver.execute('DELETE FROM t')
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'b'])
driver.execute('DELETE FROM t')
const outcome = await driver.reclaim()
const outcome = driver.reclaim()
expect(outcome).toEqual({ vacuumed: false, reclaimedBytes: 0 })
})
it('reports closed state and rejects use after close with a deterministic error', async () => {
it('reports closed state and rejects use after close with a deterministic error', () => {
expect(driver.isClosed()).toBe(false)
await driver.close()
driver.close()
expect(driver.isClosed()).toBe(true)
await expect(driver.execute('SELECT 1')).rejects.toThrow(/closed/)
await expect(driver.transaction(async (tx) => tx.execute('SELECT 1'))).rejects.toThrow(/closed/)
expect(() => driver.execute('SELECT 1')).toThrow(/closed/)
expect(() => driver.transaction((tx) => tx.execute('SELECT 1'))).toThrow(/closed/)
// A second close (e.g. app shutdown after an explicit deleteStore) is a no-op.
await expect(driver.close()).resolves.toBeUndefined()
expect(driver.close()).toBeUndefined()
})
})

View File

@@ -14,14 +14,14 @@ describe('BetterSqlite3VectorIndex', () => {
let driver: BetterSqlite3Driver
const vectorIndex = new BetterSqlite3VectorIndex()
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-vindex-'))
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
})
afterEach(async () => {
await driver.close()
afterEach(() => {
driver.close()
rmSync(tempDir, { recursive: true, force: true })
})
@@ -39,23 +39,23 @@ describe('BetterSqlite3VectorIndex', () => {
[vectorIndex.bindQueryVector(query), k]
)
it('brute-force ranks nearest vectors first over a plain-BLOB column', async () => {
await insertEmbedding('near', [1, 0, 0])
await insertEmbedding('mid', [0.7, 0.7, 0])
await insertEmbedding('far', [0, 0, 1])
it('brute-force ranks nearest vectors first over a plain-BLOB column', () => {
insertEmbedding('near', [1, 0, 0])
insertEmbedding('mid', [0.7, 0.7, 0])
insertEmbedding('far', [0, 0, 1])
const result = await topK([1, 0, 0], 3)
const result = topK([1, 0, 0], 3)
expect(result.rows.map((row) => row.h)).toEqual(['near', 'mid', 'far'])
expect(result.rows[0].dist as number).toBeLessThan(0.001)
expect(result.rows[2].dist as number).toBeGreaterThan(0.9)
})
it('respects the LIMIT k bound', async () => {
await insertEmbedding('a', [1, 0, 0])
await insertEmbedding('b', [0, 1, 0])
it('respects the LIMIT k bound', () => {
insertEmbedding('a', [1, 0, 0])
insertEmbedding('b', [0, 1, 0])
const result = await topK([1, 0, 0], 1)
const result = topK([1, 0, 0], 1)
expect(result.rows).toHaveLength(1)
expect(result.rows[0].h).toBe('a')

View File

@@ -24,10 +24,10 @@ describe('KnowledgeIndexStore integration (real better-sqlite3)', () => {
let tempDir: string
let store: KnowledgeIndexStore
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-integration-'))
const driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
const driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
})

View File

@@ -33,21 +33,21 @@ const bm25Row = (unitId: string, materialId: string, score: number): Record<stri
*/
function createFakeDriver(vectorRows: Array<Record<string, SqlValue>>, bm25Rows: Array<Record<string, SqlValue>>) {
const limits: { vector?: number; bm25?: number } = {}
const execute = vi.fn(async (sql: string, args: SqlValue[] = []): Promise<SqlQueryResult> => {
const execute = vi.fn((sql: string, args: SqlValue[] = []): SqlQueryResult => {
const limit = Number(args[args.length - 1])
if (sql.includes('search_text_fts MATCH')) {
limits.bm25 = limit
return { rows: bm25Rows }
return { rows: bm25Rows, changes: 0 }
}
limits.vector = limit
return { rows: vectorRows }
return { rows: vectorRows, changes: 0 }
})
const driver: SqliteDriver = {
execute,
transaction: async (fn) => fn({ execute }),
reclaim: async () => ({ vacuumed: false, reclaimedBytes: 0 }),
transaction: (fn) => fn({ execute }),
reclaim: () => ({ vacuumed: false, reclaimedBytes: 0 }),
isClosed: () => false,
close: async () => undefined
close: () => undefined
}
return { driver, limits }
}

View File

@@ -16,10 +16,10 @@ describe('KnowledgeIndexStore.search', () => {
let driver: BetterSqlite3Driver
let store: KnowledgeIndexStore
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-search-'))
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
})

View File

@@ -38,10 +38,10 @@ describe('KnowledgeIndexStore', () => {
let driver: BetterSqlite3Driver
let store: KnowledgeIndexStore
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-store-'))
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
})
@@ -50,16 +50,14 @@ describe('KnowledgeIndexStore', () => {
rmSync(tempDir, { recursive: true, force: true })
})
const count = async (table: string) => Number((await driver.execute(`SELECT COUNT(*) AS n FROM ${table}`)).rows[0].n)
const count = async (table: string) => Number(driver.execute(`SELECT COUNT(*) AS n FROM ${table}`).rows[0].n)
const ftsMatchCount = async (term: string) =>
(
await driver.execute(
`SELECT st.search_text_id AS id
FROM search_text_fts JOIN search_text st ON st.fts_rowid = search_text_fts.rowid
WHERE search_text_fts MATCH ?`,
[term]
)
driver.execute(
`SELECT st.search_text_id AS id
FROM search_text_fts JOIN search_text st ON st.fts_rowid = search_text_fts.rowid
WHERE search_text_fts MATCH ?`,
[term]
).rows.length
// The reliable external-content desync detector: the default `integrity-check` does NOT compare
@@ -86,7 +84,7 @@ describe('KnowledgeIndexStore', () => {
expect(await count('search_unit')).toBe(2)
expect(await count('search_text')).toBe(2)
const material = await driver.execute(`SELECT current_content_hash FROM material WHERE material_id = ?`, ['m1'])
const material = driver.execute(`SELECT current_content_hash FROM material WHERE material_id = ?`, ['m1'])
expect(material.rows[0].current_content_hash).not.toBeNull()
})
@@ -112,7 +110,7 @@ describe('KnowledgeIndexStore', () => {
// Corrupt the store: drop the body row out from under the unit. The same
// damage silently excludes the unit from search (INNER JOIN); the list lane
// must fail loudly rather than add a third symptom (existing-but-empty chunk).
await driver.execute(`DELETE FROM search_text WHERE target_type = 'search_unit' AND kind = 'body'`)
driver.execute(`DELETE FROM search_text WHERE target_type = 'search_unit' AND kind = 'body'`)
await expect(store.listMaterialUnits('m1')).rejects.toThrow('missing the body text for unit')
})
@@ -321,7 +319,7 @@ describe('KnowledgeIndexStore', () => {
// single GC pass — the path a folder delete takes (one deleteMaterials over N files).
await store.deleteMaterials(['m1', 'm2', 'm1'])
expect((await driver.execute(`SELECT material_id FROM material`)).rows.map((r) => r.material_id)).toEqual(['m3'])
expect(driver.execute(`SELECT material_id FROM material`).rows.map((r) => r.material_id)).toEqual(['m3'])
expect(await store.listMaterialUnits('m1')).toEqual([])
expect(await store.listMaterialUnits('m2')).toEqual([])
// The single end-of-batch GC must sweep m2's now-orphaned body/embedding/content while
@@ -417,7 +415,7 @@ describe('KnowledgeIndexStore', () => {
expect(outcome.vacuumed).toBe(true)
expect(outcome.reclaimedBytes).toBeGreaterThan(0)
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
expect(ftsIntegrityCheck()).toBeDefined()
// The survivor stays both keyword- and vector-reachable after the rewrite.
expect(await ftsMatchCount('knowledge')).toBe(1)
expect((await store.search({ queryText: 'knowledge', mode: 'bm25', topK: 10 })).map((h) => h.materialId)).toEqual([
@@ -466,7 +464,7 @@ describe('KnowledgeIndexStore', () => {
// VACUUM renumbers implicit rowids; assert the external-content FTS did NOT desync. Keyed on the
// stable fts_rowid it stays aligned by construction — verified with the reliable integrity check
// (the default integrity-check would pass even on a desync).
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
expect(ftsIntegrityCheck()).toBeDefined()
})
it('keeps search_text_fts aligned after a rowid-reshuffling rebuild (fts_rowid keying)', async () => {
@@ -485,23 +483,23 @@ describe('KnowledgeIndexStore', () => {
// Reshuffle the implicit rowid exactly as a table rebuild / VACUUM does: copy the table (new
// contiguous rowids, dropping the hole) while carrying fts_rowid verbatim, then re-assert the
// dropped indexes + triggers. The FTS vtable is untouched, so it stays keyed by fts_rowid.
await driver.execute('PRAGMA foreign_keys=OFF')
await driver.execute('CREATE TABLE __new_search_text AS SELECT * FROM search_text')
await driver.execute('DROP TABLE search_text')
await driver.execute('ALTER TABLE __new_search_text RENAME TO search_text')
await driver.execute('PRAGMA foreign_keys=ON')
await createKnowledgeIndexSchema(driver)
driver.execute('PRAGMA foreign_keys=OFF')
driver.execute('CREATE TABLE __new_search_text AS SELECT * FROM search_text')
driver.execute('DROP TABLE search_text')
driver.execute('ALTER TABLE __new_search_text RENAME TO search_text')
driver.execute('PRAGMA foreign_keys=ON')
createKnowledgeIndexSchema(driver)
// fts_rowid was copied through the rebuild, so the index stays aligned: the reliable check passes
// and the survivors resolve to the right materials (a rowid-keyed index would return wrong/empty).
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
expect(ftsIntegrityCheck()).toBeDefined()
expect((await store.search({ queryText: 'cherry', mode: 'bm25', topK: 10 })).map((h) => h.materialId)).toEqual([
'm3'
])
expect((await store.search({ queryText: 'date', mode: 'bm25', topK: 10 })).map((h) => h.materialId)).toEqual(['m4'])
// The reshuffle reassigned implicit rowids while leaving fts_rowid untouched and unique.
expect(await count('search_text')).toBe(3)
expect(Number((await driver.execute(`SELECT COUNT(DISTINCT fts_rowid) AS n FROM search_text`)).rows[0].n)).toBe(3)
expect(Number(driver.execute(`SELECT COUNT(DISTINCT fts_rowid) AS n FROM search_text`).rows[0].n)).toBe(3)
})
it('integrity-check,1 catches a NULL fts_rowid desync (and proves the detector is live)', async () => {
@@ -513,8 +511,8 @@ describe('KnowledgeIndexStore', () => {
// now references a key the content row no longer carries. This both guards that hazard AND is the
// positive control that ftsIntegrityCheck() really throws on a desync — every other use of it in
// this suite asserts it resolves, so without this case those assertions could pass vacuously.
await driver.execute(`UPDATE search_text SET fts_rowid = NULL`)
await expect(ftsIntegrityCheck()).rejects.toThrow()
driver.execute(`UPDATE search_text SET fts_rowid = NULL`)
expect(() => ftsIntegrityCheck()).toThrow()
})
it('reuses a freed fts_rowid without leaving a stale FTS entry', async () => {
@@ -527,14 +525,14 @@ describe('KnowledgeIndexStore', () => {
await store.rebuildMaterial('m3', buildInput('charlie cherry body', [[0, 19]], 'c.md'))
// m3 holds the current MAX fts_rowid. Delete it to free that rowid.
const maxBefore = Number((await driver.execute(`SELECT MAX(fts_rowid) AS m FROM search_text`)).rows[0].m)
const maxBefore = Number(driver.execute(`SELECT MAX(fts_rowid) AS m FROM search_text`).rows[0].m)
await store.deleteMaterials(['m3'])
expect(await ftsMatchCount('cherry')).toBe(0)
// The next insert's MAX(fts_rowid)+1 reuses the value m3 just vacated.
await store.rebuildMaterial('m4', buildInput('delta dragon body', [[0, 17]], 'd.md'))
const reused = Number(
(await driver.execute(`SELECT fts_rowid FROM search_text WHERE text LIKE 'delta%'`)).rows[0].fts_rowid
driver.execute(`SELECT fts_rowid FROM search_text WHERE text LIKE 'delta%'`).rows[0].fts_rowid
)
expect(reused).toBe(maxBefore) // landed on the exact physical rowid m3 had held
@@ -544,7 +542,7 @@ describe('KnowledgeIndexStore', () => {
'm4'
])
expect(await ftsMatchCount('cherry')).toBe(0)
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
expect(ftsIntegrityCheck()).toBeDefined()
})
it('keeps a shared embedding reachable for the other material after rebuilding the one that introduced it', async () => {

View File

@@ -17,21 +17,21 @@ describe('ensureIndexMeta', () => {
let tempDir: string
let driver: BetterSqlite3Driver
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-meta-'))
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
})
afterEach(async () => {
await driver.close()
afterEach(() => {
driver.close()
rmSync(tempDir, { recursive: true, force: true })
})
it('writes the single meta identity row with the schema version and base id on first open', async () => {
await ensureIndexMeta(driver, META_INPUT)
it('writes the single meta identity row with the schema version and base id on first open', () => {
ensureIndexMeta(driver, META_INPUT)
const result = await driver.execute('SELECT * FROM meta')
const result = driver.execute('SELECT * FROM meta')
expect(result.rows).toHaveLength(1)
const row = result.rows[0]
expect(row.id).toBe(1)
@@ -39,24 +39,22 @@ describe('ensureIndexMeta', () => {
expect(row.base_id).toBe('kb-1')
})
it('is idempotent across re-opens: the original row is kept, not duplicated or rewritten', async () => {
await ensureIndexMeta(driver, META_INPUT)
const first = await driver.execute('SELECT created_at FROM meta WHERE id = 1')
it('is idempotent across re-opens: the original row is kept, not duplicated or rewritten', () => {
ensureIndexMeta(driver, META_INPUT)
const first = driver.execute('SELECT created_at FROM meta WHERE id = 1')
await ensureIndexMeta(driver, META_INPUT)
const second = await driver.execute('SELECT created_at FROM meta WHERE id = 1')
ensureIndexMeta(driver, META_INPUT)
const second = driver.execute('SELECT created_at FROM meta WHERE id = 1')
const count = await driver.execute('SELECT COUNT(*) AS n FROM meta')
const count = driver.execute('SELECT COUNT(*) AS n FROM meta')
expect(count.rows[0].n).toBe(1)
expect(second.rows[0].created_at).toBe(first.rows[0].created_at)
})
it('rejects opening an index that belongs to a different base (anti-mismount guard, §4.1)', async () => {
await ensureIndexMeta(driver, META_INPUT)
it('rejects opening an index that belongs to a different base (anti-mismount guard, §4.1)', () => {
ensureIndexMeta(driver, META_INPUT)
await expect(ensureIndexMeta(driver, { ...META_INPUT, baseId: 'kb-OTHER' })).rejects.toThrow(
/belongs to a different base/
)
expect(() => ensureIndexMeta(driver, { ...META_INPUT, baseId: 'kb-OTHER' })).toThrow(/belongs to a different base/)
})
})
@@ -67,25 +65,25 @@ describe('index content diagnostics', () => {
let tempDir: string
let driver: BetterSqlite3Driver
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-meta-'))
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
})
afterEach(async () => {
await driver.close()
afterEach(() => {
driver.close()
rmSync(tempDir, { recursive: true, force: true })
})
it('hasAnyMaterial is false on a fresh index and true once a material row exists', async () => {
expect(await hasAnyMaterial(driver)).toBe(false)
it('hasAnyMaterial is false on a fresh index and true once a material row exists', () => {
expect(hasAnyMaterial(driver)).toBe(false)
await driver.execute(
driver.execute(
`INSERT INTO material (material_id, relative_path, created_at, updated_at)
VALUES ('m1', 'doc.md', 1, 1)`
)
expect(await hasAnyMaterial(driver)).toBe(true)
expect(hasAnyMaterial(driver)).toBe(true)
})
})

View File

@@ -20,16 +20,16 @@ describe('knowledge index schema', () => {
let tempDir: string
let driver: BetterSqlite3Driver
beforeEach(async () => {
beforeEach(() => {
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-index-'))
// openBetterSqlite3IndexDriver enables foreign keys per-connection (for CASCADE)
// and loads sqlite-vec (for vec_distance_cosine).
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
await createKnowledgeIndexSchema(driver)
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
createKnowledgeIndexSchema(driver)
})
afterEach(async () => {
await driver?.close()
afterEach(() => {
driver?.close()
rmSync(tempDir, { recursive: true, force: true })
})
@@ -58,15 +58,15 @@ describe('knowledge index schema', () => {
)
/** Every schema object (tables, triggers, indexes, FTS shadow tables), as stable `type:name` keys. */
const listSchemaObjects = async () => {
const result = await driver.execute(`SELECT type, name FROM sqlite_master ORDER BY type, name`)
const listSchemaObjects = () => {
const result = driver.execute(`SELECT type, name FROM sqlite_master ORDER BY type, name`)
return result.rows.map((row) => `${row.type}:${row.name}`)
}
describe('schema creation', () => {
it('creates all 7 schema objects', async () => {
it('creates all 7 schema objects', () => {
const expected = ['meta', 'content', 'material', 'search_unit', 'search_text', 'embedding', 'search_text_fts']
const result = await driver.execute(
const result = driver.execute(
`SELECT name FROM sqlite_master WHERE name IN (${expected.map(() => '?').join(', ')})`,
expected
)
@@ -76,10 +76,10 @@ describe('knowledge index schema', () => {
}
})
it('is idempotent: re-applying through the driver leaves the object set unchanged', async () => {
const objectsBefore = await listSchemaObjects()
await expect(createKnowledgeIndexSchema(driver)).resolves.toBeUndefined()
expect(await listSchemaObjects()).toEqual(objectsBefore)
it('is idempotent: re-applying through the driver leaves the object set unchanged', () => {
const objectsBefore = listSchemaObjects()
expect(createKnowledgeIndexSchema(driver)).toBeUndefined()
expect(listSchemaObjects()).toEqual(objectsBefore)
})
it('exposes a static, parameterless statement list', () => {
@@ -98,60 +98,60 @@ describe('knowledge index schema', () => {
[id, TS, TS]
)
it('accepts the single id = 1 row', async () => {
await expect(insertMeta(1)).resolves.toBeDefined()
it('accepts the single id = 1 row', () => {
expect(insertMeta(1)).toBeDefined()
})
it('rejects id != 1', async () => {
await expect(insertMeta(2)).rejects.toThrow()
it('rejects id != 1', () => {
expect(() => insertMeta(2)).toThrow()
})
it('rejects a second row', async () => {
await insertMeta(1)
await expect(insertMeta(1)).rejects.toThrow()
it('rejects a second row', () => {
insertMeta(1)
expect(() => insertMeta(1)).toThrow()
})
})
describe('material constraints', () => {
it('accepts a valid material', async () => {
await expect(insertMaterial('m1', 'docs/paper.md')).resolves.toBeDefined()
it('accepts a valid material', () => {
expect(insertMaterial('m1', 'docs/paper.md')).toBeDefined()
})
it('rejects an absolute relative_path', async () => {
await expect(insertMaterial('m1', '/abs/paper.md')).rejects.toThrow()
it('rejects an absolute relative_path', () => {
expect(() => insertMaterial('m1', '/abs/paper.md')).toThrow()
})
it('rejects a reserved .cherry relative_path', async () => {
await expect(insertMaterial('m1', '.cherry/index.sqlite')).rejects.toThrow()
it('rejects a reserved .cherry relative_path', () => {
expect(() => insertMaterial('m1', '.cherry/index.sqlite')).toThrow()
})
it('enforces unique relative_path', async () => {
await insertMaterial('m1', 'a.md')
await expect(insertMaterial('m2', 'a.md')).rejects.toThrow()
it('enforces unique relative_path', () => {
insertMaterial('m1', 'a.md')
expect(() => insertMaterial('m2', 'a.md')).toThrow()
})
})
describe('foreign keys', () => {
it('cascades search_unit deletion when its material is deleted', async () => {
await insertContent('h1', 'hello')
await insertMaterial('m1', 'a.md', 'h1')
await insertSearchUnit('u1', 'm1', 'h1')
it('cascades search_unit deletion when its material is deleted', () => {
insertContent('h1', 'hello')
insertMaterial('m1', 'a.md', 'h1')
insertSearchUnit('u1', 'm1', 'h1')
await driver.execute(`DELETE FROM material WHERE material_id = ?`, ['m1'])
driver.execute(`DELETE FROM material WHERE material_id = ?`, ['m1'])
const remaining = await driver.execute(`SELECT COUNT(*) AS n FROM search_unit`)
const remaining = driver.execute(`SELECT COUNT(*) AS n FROM search_unit`)
expect(remaining.rows[0].n).toBe(0)
})
it('rejects a search_unit referencing a missing material', async () => {
await insertContent('h1', 'hello')
await expect(insertSearchUnit('u1', 'missing-material', 'h1')).rejects.toThrow()
it('rejects a search_unit referencing a missing material', () => {
insertContent('h1', 'hello')
expect(() => insertSearchUnit('u1', 'missing-material', 'h1')).toThrow()
})
})
describe('FTS5 (trigram, external content)', () => {
const matchBody = async (term: string) => {
const result = await driver.execute(
const matchBody = (term: string) => {
const result = driver.execute(
`SELECT st.search_text_id AS id
FROM search_text_fts
JOIN search_text st ON st.fts_rowid = search_text_fts.rowid
@@ -161,52 +161,50 @@ describe('knowledge index schema', () => {
return result.rows.map((row) => row.id as string)
}
beforeEach(async () => {
await insertContent('h1', 'body content')
await insertMaterial('m1', 'a.md', 'h1')
await insertSearchUnit('u1', 'm1', 'h1')
beforeEach(() => {
insertContent('h1', 'body content')
insertMaterial('m1', 'a.md', 'h1')
insertSearchUnit('u1', 'm1', 'h1')
})
it('indexes inserted search_text and matches by term', async () => {
await insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
expect(await matchBody('knowledge')).toEqual(['st1'])
it('indexes inserted search_text and matches by term', () => {
insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
expect(matchBody('knowledge')).toEqual(['st1'])
})
it('removes the FTS entry when search_text is deleted (ad trigger)', async () => {
await insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
expect(await matchBody('knowledge')).toEqual(['st1'])
it('removes the FTS entry when search_text is deleted (ad trigger)', () => {
insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
expect(matchBody('knowledge')).toEqual(['st1'])
await driver.execute(`DELETE FROM search_text WHERE search_text_id = ?`, ['st1'])
expect(await matchBody('knowledge')).toEqual([])
driver.execute(`DELETE FROM search_text WHERE search_text_id = ?`, ['st1'])
expect(matchBody('knowledge')).toEqual([])
})
it('re-syncs the FTS entry when search_text.text is updated (au trigger)', async () => {
it('re-syncs the FTS entry when search_text.text is updated (au trigger)', () => {
// Production rebuilds are delete + insert, so this UPDATE path has no caller
// today; the trigger is kept defensively and this test pins its behavior.
await insertSearchText('st1', 'u1', 'alpha knowledge base', 'eh1')
expect(await matchBody('knowledge')).toEqual(['st1'])
const ftsRowidBefore = (
await driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
).rows[0].fts_rowid
insertSearchText('st1', 'u1', 'alpha knowledge base', 'eh1')
expect(matchBody('knowledge')).toEqual(['st1'])
const ftsRowidBefore = driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
.rows[0].fts_rowid
await driver.execute(`UPDATE search_text SET text = ? WHERE search_text_id = ?`, ['beta wisdom corpus', 'st1'])
driver.execute(`UPDATE search_text SET text = ? WHERE search_text_id = ?`, ['beta wisdom corpus', 'st1'])
expect(await matchBody('knowledge')).toEqual([])
expect(await matchBody('wisdom')).toEqual(['st1'])
expect(matchBody('knowledge')).toEqual([])
expect(matchBody('wisdom')).toEqual(['st1'])
// fts_rowid is stable across a text edit (the au trigger re-keys the FTS row by NEW.fts_rowid,
// it does not reassign it) — so the external-content index stays aligned.
const ftsRowidAfter = (
await driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
).rows[0].fts_rowid
const ftsRowidAfter = driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
.rows[0].fts_rowid
expect(ftsRowidAfter).toBe(ftsRowidBefore)
await expect(
expect(
driver.execute(`INSERT INTO search_text_fts(search_text_fts, rank) VALUES('integrity-check', 1)`)
).resolves.toBeDefined()
).toBeDefined()
})
it('exposes a bm25 rank for matches', async () => {
await insertSearchText('st1', 'u1', 'knowledge retrieval', 'eh1')
const result = await driver.execute(
it('exposes a bm25 rank for matches', () => {
insertSearchText('st1', 'u1', 'knowledge retrieval', 'eh1')
const result = driver.execute(
`SELECT bm25(search_text_fts) AS score
FROM search_text_fts
WHERE search_text_fts MATCH ?`,
@@ -218,15 +216,15 @@ describe('knowledge index schema', () => {
})
describe('embedding vector (engine-portability spike, §5.6)', () => {
it('computes vec_distance_cosine directly over a plain BLOB column', async () => {
it('computes vec_distance_cosine directly over a plain BLOB column', () => {
const vector = [0.1, 0.2, 0.3]
await driver.execute(`INSERT INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`, [
driver.execute(`INSERT INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`, [
'eh_vec',
encodeVectorBlob(vector),
TS
])
const result = await driver.execute(
const result = driver.execute(
`SELECT vec_distance_cosine(vector_blob, ?) AS dist
FROM embedding
WHERE embedding_text_hash = ?`,
@@ -241,29 +239,29 @@ describe('knowledge index schema', () => {
})
describe('schema version & rebuild (open-time migration)', () => {
it('reports null until a meta row exists, then the current constant', async () => {
it('reports null until a meta row exists, then the current constant', () => {
// beforeEach created the schema (meta table exists) but no id=1 row yet.
expect(await readIndexSchemaVersion(driver)).toBeNull()
await ensureIndexMeta(driver, { baseId: 'base-1' })
expect(await readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
expect(readIndexSchemaVersion(driver)).toBeNull()
ensureIndexMeta(driver, { baseId: 'base-1' })
expect(readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
})
it('reports null on a file with no meta table at all', async () => {
const fresh = await openBetterSqlite3IndexDriver(join(tempDir, 'no-meta.sqlite'))
it('reports null on a file with no meta table at all', () => {
const fresh = openBetterSqlite3IndexDriver(join(tempDir, 'no-meta.sqlite'))
try {
expect(await readIndexSchemaVersion(fresh)).toBeNull()
expect(readIndexSchemaVersion(fresh)).toBeNull()
} finally {
await fresh.close()
fresh.close()
}
})
it('reports null for a malformed (non-numeric) schema_version cell', async () => {
await ensureIndexMeta(driver, { baseId: 'base-1' })
it('reports null for a malformed (non-numeric) schema_version cell', () => {
ensureIndexMeta(driver, { baseId: 'base-1' })
// A blanked/corrupt version (stored as text under the column's INTEGER affinity) must read as
// "unknown" → null, so the open path treats it as fresh and creates rather than silently
// mistaking it for a real version. Covers the `typeof === 'number'` guard.
await driver.execute(`UPDATE meta SET schema_version = 'corrupt'`)
expect(await readIndexSchemaVersion(driver)).toBeNull()
driver.execute(`UPDATE meta SET schema_version = 'corrupt'`)
expect(readIndexSchemaVersion(driver)).toBeNull()
})
it('pins the current schema version (a deliberate bump-me tripwire)', () => {
@@ -273,25 +271,25 @@ describe('knowledge index schema', () => {
expect(KNOWLEDGE_INDEX_SCHEMA_VERSION).toBe(2)
})
it('resetKnowledgeIndexSchema wipes data, rebuilds every object, and lets meta restamp the version', async () => {
const freshObjects = await listSchemaObjects()
it('resetKnowledgeIndexSchema wipes data, rebuilds every object, and lets meta restamp the version', () => {
const freshObjects = listSchemaObjects()
// Seed a populated index stamped at an older layout version.
await ensureIndexMeta(driver, { baseId: 'base-1' })
await insertContent('h1', 'hello world')
await insertMaterial('m1', 'a.md', 'h1')
await driver.execute(`UPDATE meta SET schema_version = 1`)
expect(await readIndexSchemaVersion(driver)).toBe(1)
ensureIndexMeta(driver, { baseId: 'base-1' })
insertContent('h1', 'hello world')
insertMaterial('m1', 'a.md', 'h1')
driver.execute(`UPDATE meta SET schema_version = 1`)
expect(readIndexSchemaVersion(driver)).toBe(1)
await resetKnowledgeIndexSchema(driver)
resetKnowledgeIndexSchema(driver)
// Same object set as a fresh schema, but the derived data is gone (rebuildable artifact).
expect(await listSchemaObjects()).toEqual(freshObjects)
expect((await driver.execute(`SELECT COUNT(*) AS n FROM material`)).rows[0].n).toBe(0)
expect((await driver.execute(`SELECT COUNT(*) AS n FROM content`)).rows[0].n).toBe(0)
expect(listSchemaObjects()).toEqual(freshObjects)
expect(driver.execute(`SELECT COUNT(*) AS n FROM material`).rows[0].n).toBe(0)
expect(driver.execute(`SELECT COUNT(*) AS n FROM content`).rows[0].n).toBe(0)
// The reset drops meta too, so the version is null until the open path restamps it.
expect(await readIndexSchemaVersion(driver)).toBeNull()
await ensureIndexMeta(driver, { baseId: 'base-1' })
expect(await readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
expect(readIndexSchemaVersion(driver)).toBeNull()
ensureIndexMeta(driver, { baseId: 'base-1' })
expect(readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
})
})
})

View File

@@ -20,15 +20,15 @@ export interface IndexMetaInput {
* the store-open path logs an error when that happens under a base that already
* has completed items (see KnowledgeVectorStoreService).
*/
export async function ensureIndexMeta(executor: SqliteExecutor, input: IndexMetaInput): Promise<void> {
export function ensureIndexMeta(executor: SqliteExecutor, input: IndexMetaInput): void {
const now = Date.now()
await executor.execute(
executor.execute(
`INSERT OR IGNORE INTO meta (id, schema_version, base_id, created_at, updated_at)
VALUES (1, ?, ?, ?, ?)`,
[KNOWLEDGE_INDEX_SCHEMA_VERSION, input.baseId, now, now]
)
const stored = await executor.execute(`SELECT base_id FROM meta WHERE id = 1`)
const stored = executor.execute(`SELECT base_id FROM meta WHERE id = 1`)
const storedBaseId = stored.rows[0]?.base_id as string | undefined
if (storedBaseId !== input.baseId) {
throw new Error(
@@ -44,18 +44,18 @@ export async function ensureIndexMeta(executor: SqliteExecutor, input: IndexMeta
* The store-open path compares this to {@link KNOWLEDGE_INDEX_SCHEMA_VERSION}: a
* non-null mismatch means an old layout that must be rebuilt before the DDL is applied.
*/
export async function readIndexSchemaVersion(executor: SqliteExecutor): Promise<number | null> {
const hasMeta = await executor.execute(`SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'meta'`)
export function readIndexSchemaVersion(executor: SqliteExecutor): number | null {
const hasMeta = executor.execute(`SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'meta'`)
if (hasMeta.rows.length === 0) {
return null
}
const result = await executor.execute(`SELECT schema_version FROM meta WHERE id = 1`)
const result = executor.execute(`SELECT schema_version FROM meta WHERE id = 1`)
const version = result.rows[0]?.schema_version
return typeof version === 'number' ? version : null
}
/** Whether the index database holds at least one material row (store-open diagnostics probe). */
export async function hasAnyMaterial(executor: SqliteExecutor): Promise<boolean> {
const result = await executor.execute(`SELECT 1 FROM material LIMIT 1`)
export function hasAnyMaterial(executor: SqliteExecutor): boolean {
const result = executor.execute(`SELECT 1 FROM material LIMIT 1`)
return result.rows.length > 0
}

View File

@@ -214,9 +214,9 @@ export const KNOWLEDGE_INDEX_SCHEMA_STATEMENTS: readonly string[] = [
* Does NOT insert the `meta` row — that requires a runtime value (the base id)
* and is owned by the store-open path.
*/
export async function createKnowledgeIndexSchema(executor: SqliteExecutor): Promise<void> {
export function createKnowledgeIndexSchema(executor: SqliteExecutor): void {
for (const statement of KNOWLEDGE_INDEX_SCHEMA_STATEMENTS) {
await executor.execute(statement)
executor.execute(statement)
}
}
@@ -249,9 +249,9 @@ export const KNOWLEDGE_INDEX_DROP_STATEMENTS: readonly string[] = [
* the base simply re-indexes. Drops run first (autocommit per statement), then the
* current DDL is reapplied. Callers must restamp `meta` afterwards (the DROP removes it).
*/
export async function resetKnowledgeIndexSchema(executor: SqliteExecutor): Promise<void> {
export function resetKnowledgeIndexSchema(executor: SqliteExecutor): void {
for (const statement of KNOWLEDGE_INDEX_DROP_STATEMENTS) {
await executor.execute(statement)
executor.execute(statement)
}
await createKnowledgeIndexSchema(executor)
createKnowledgeIndexSchema(executor)
}

View File

@@ -6,6 +6,14 @@
* sqlite-vec today) with zero user migration. Only the driver and VectorIndex
* adapters are engine-specific; everything above them — the schema DDL
* (schema.ts) and the store's queries — is shared, engine-neutral SQL.
*
* Synchronous by design, mirroring `DbService.withWriteTx` (src/main/data/db/DbService.ts):
* better-sqlite3 is a single synchronous connection with no I/O wait, so an `async`
* surface here would be pure libsql-era residue. It is not just cosmetic — a
* transaction callback that actually awaits real async work would yield the event
* loop while `BEGIN`..`COMMIT` is open, letting an unrelated read on the same
* connection observe uncommitted rows. A synchronous `transaction<T>(fn: (tx) => T): T`
* makes that categorically impossible: `fn` runs to completion in one JS turn.
*/
/** A value bindable to a statement parameter or read back from a result column. */
@@ -13,11 +21,13 @@ export type SqlValue = string | number | bigint | boolean | Uint8Array | ArrayBu
export interface SqlQueryResult {
rows: Array<Record<string, SqlValue>>
/** Rows inserted/updated/deleted by this statement (0 for a read). */
changes: number
}
/** Runs a single statement. Implemented by both the driver and a transaction handle. */
export interface SqliteExecutor {
execute(sql: string, args?: SqlValue[]): Promise<SqlQueryResult>
execute(sql: string, args?: SqlValue[]): SqlQueryResult
}
/** A handle valid only inside SqliteDriver.transaction(); same surface as the driver. */
@@ -33,11 +43,13 @@ export interface SqliteReclaimOutcome {
export interface SqliteDriver extends SqliteExecutor {
/**
* Run `fn` inside a single write transaction. Commits when `fn` resolves,
* rolls back and rethrows when it rejects — preserving the atomic-replace
* Run `fn` inside a single write transaction. Commits when `fn` returns,
* rolls back and rethrows when it throws — preserving the atomic-replace
* semantics rebuildMaterial relies on (no mixed old/new rows ever visible).
* `fn` MUST be synchronous — the better-sqlite3 backing implementation throws
* if it returns a Promise (see BetterSqlite3Driver).
*/
transaction<T>(fn: (tx: SqliteTransaction) => Promise<T>): Promise<T>
transaction<T>(fn: (tx: SqliteTransaction) => T): T
/**
* Return free space left by deletes to the OS: always checkpoint+truncate the
* WAL, and VACUUM the main file when its freelist has grown large (both a big
@@ -49,7 +61,7 @@ export interface SqliteDriver extends SqliteExecutor {
* driver's writes; the VACUUM blocks the calling thread for the whole-file
* rewrite, which is why the threshold gates it to large deletes.
*/
reclaim(preVacuumStatements?: readonly string[]): Promise<SqliteReclaimOutcome>
reclaim(preVacuumStatements?: readonly string[]): SqliteReclaimOutcome
/**
* Whether {@link close} has been called. Lets a caller tell an operation that
* failed because the store was closed mid-flight (concurrent base deletion or
@@ -57,7 +69,7 @@ export interface SqliteDriver extends SqliteExecutor {
* instead of leaking an opaque driver error.
*/
isClosed(): boolean
close(): Promise<void>
close(): void
}
/** One brute-force vector match: an embedding row and its distance to the query. */

View File

@@ -173,6 +173,16 @@ const MessageList = () => {
messageListRef.current?.scrollToBottom('instant')
}, [])
// Navigation buttons scroll through the virtua-aware runtime handle (smooth,
// remeasure-safe) rather than a raw scrollTo on the virtualized scroller.
const navigateToTop = useCallback(() => {
messageListRef.current?.scrollToTop('smooth')
}, [])
const navigateToBottom = useCallback(() => {
messageListRef.current?.scrollToBottom('smooth')
}, [])
const scrollToMessageById = useCallback((messageId: string) => {
const target = messageByIdRef.current.get(messageId)
if (!target) return
@@ -559,7 +569,13 @@ const MessageList = () => {
<MessageOutline message={activeOutlineMessage} multiModelMessageStyle={activeOutline.multiModelMessageStyle} />
)}
{messageNavigation === 'buttons' && (
<MessageNavigation containerId="messages" messages={messages} scrollToMessageId={scrollToMessageById} />
<MessageNavigation
containerId="messages"
messages={messages}
scrollToMessageId={scrollToMessageById}
scrollToTop={navigateToTop}
scrollToBottom={navigateToBottom}
/>
)}
{meta.selectionLayer && (
<SelectionBox

View File

@@ -18,6 +18,7 @@ import {
} from '../types'
const scrollToBottom = vi.fn()
const scrollToTop = vi.fn()
const scrollToKey = vi.fn()
const messageVirtualListMocks = vi.hoisted(() => ({
deferScrollContainerReady: false,
@@ -179,6 +180,7 @@ vi.mock('../list/MessageVirtualList', async () => {
handleRef as Ref<MessageVirtualListHandle>,
() => ({
scrollToBottom,
scrollToTop,
scrollToKey,
isAtBottom: () => false,
getScrollElement: () => messageVirtualListMocks.scrollElement
@@ -259,6 +261,7 @@ const renderMessageList = (messages: MessageListItem[]) =>
describe('MessageList', () => {
beforeEach(() => {
scrollToBottom.mockClear()
scrollToTop.mockClear()
scrollToKey.mockClear()
vi.mocked(captureScrollable).mockReset()
vi.mocked(captureScrollableAsDataURL).mockReset()

View File

@@ -30,24 +30,14 @@ describe('useMessageLeafCapabilities', () => {
mockSafeOpen.mockResolvedValue(undefined)
})
it('loads external apps for ordinary text parts that mention inline absolute paths', () => {
const partsByMessageId: Record<string, CherryMessagePart[]> = {
message: [{ type: 'text', text: 'Open `/Users/example/project/App.tsx`.' } as CherryMessagePart]
}
renderHook(() => useMessageLeafCapabilities({ partsByMessageId }))
expect(mockUseExternalApps).toHaveBeenCalledWith({ enabled: true })
})
it('does not load external apps for text parts without local path hints', () => {
it('loads external apps for the message list regardless of inline path hints', () => {
const partsByMessageId: Record<string, CherryMessagePart[]> = {
message: [{ type: 'text', text: 'plain response' } as CherryMessagePart]
}
renderHook(() => useMessageLeafCapabilities({ partsByMessageId }))
expect(mockUseExternalApps).toHaveBeenCalledWith({ enabled: false })
expect(mockUseExternalApps).toHaveBeenCalledWith()
})
it('opens shared attachment files through safeOpen', async () => {

View File

@@ -1,6 +1,5 @@
import { useQuery } from '@data/hooks/useDataApi'
import type { MessageListActions, MessageListState } from '@renderer/components/chat/messages/types'
import { containsInlineFilePath } from '@renderer/components/chat/messages/utils/filePath'
import { useAttachment } from '@renderer/hooks/useAttachment'
import { useExternalApps } from '@renderer/hooks/useExternalApps'
import type { FileMetadata } from '@renderer/types/file'
@@ -51,14 +50,6 @@ function isMcpToolPart(part: CherryMessagePart): boolean {
return tool?.type === 'mcp'
}
function hasExternalEditorPathHint(part: CherryMessagePart): boolean {
const partType = (part as { type?: string }).type
if (partType === 'dynamic-tool' || !!partType?.startsWith('tool-')) return true
if (partType !== 'text') return false
return containsInlineFilePath((part as { text?: string }).text)
}
function fileMetadataToHandle(file: FileMetadata): FileHandle {
if (file.path) {
try {
@@ -111,12 +102,8 @@ export function useMessageLeafCapabilities({
() => Object.values(partsByMessageId).some((parts) => parts.some(isMcpToolPart)),
[partsByMessageId]
)
const hasExternalEditorPathHints = useMemo(
() => Object.values(partsByMessageId).some((parts) => parts.some(hasExternalEditorPathHint)),
[partsByMessageId]
)
const { data: mcpServersData } = useQuery('/mcp-servers', { enabled: hasMcpToolParts })
const { data: externalApps } = useExternalApps({ enabled: hasExternalEditorPathHints })
const { data: externalApps } = useExternalApps()
const mcpServers = useMemo(() => mcpServersData?.items ?? [], [mcpServersData])
const externalCodeEditors = useMemo(
() => externalApps?.filter((app) => app.tags.includes('code-editor')) ?? [],

View File

@@ -35,6 +35,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
const userName = renderConfig.userName
const assistantProfile = meta.assistantProfile
const avatar = meta.userProfile?.avatar ?? ''
const { updateMessageUiState } = actions
const { setTimeoutTimer } = useTimer()
const messagesListRef = useRef<HTMLDivElement>(null)
@@ -101,7 +102,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
const groupMessages = messages.filter((m) => m.parentId === message.parentId)
if (groupMessages.length > 1) {
for (const m of groupMessages) {
actions.updateMessageUiState?.(m.id, { foldSelected: m.id === message.id })
updateMessageUiState?.(m.id, { foldSelected: m.id === message.id })
}
setTimeoutTimer(
@@ -116,7 +117,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
)
}
},
[actions.updateMessageUiState, messages, setTimeoutTimer]
[messages, setTimeoutTimer, updateMessageUiState]
)
const scrollToMessage = useCallback(
@@ -125,7 +126,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
const siblings = messages.filter((m) => m.role === 'assistant' && m.parentId === message.parentId)
if (siblings.length > 1) {
for (const sibling of siblings) {
actions.updateMessageUiState?.(sibling.id, { foldSelected: sibling.id === message.id })
updateMessageUiState?.(sibling.id, { foldSelected: sibling.id === message.id })
}
}
}
@@ -146,7 +147,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
}
scrollIntoView(messageElement, { behavior: 'smooth', block: 'start', container: 'nearest' })
},
[actions, messages, scrollToMessageId, setSelectedMessage]
[messages, scrollToMessageId, setSelectedMessage, updateMessageUiState]
)
const scrollToBottom = useCallback(() => {
@@ -328,14 +329,14 @@ const MessageLineContainer = ({
<div
ref={ref}
className={[
'group fixed right-[13px] z-999 flex w-[14px] translate-y-[-50%] select-none items-center justify-end overflow-hidden text-[5px] hover:w-[500px] hover:overflow-y-hidden hover:overflow-x-visible',
'group absolute right-3.25 z-20 flex w-3.5 translate-y-[-50%] select-none items-center justify-end overflow-hidden text-[5px] hover:w-125 hover:overflow-y-hidden hover:overflow-x-visible',
className
]
.filter(Boolean)
.join(' ')}
style={{
top: 'calc(50% - var(--status-bar-height) - 10px)',
maxHeight: $height ? `${$height - 20}px` : 'calc(100% - var(--status-bar-height) * 2 - 20px)',
top: '50%',
maxHeight: $height ? `${$height - 20}px` : 'calc(100% - 20px)',
...style
}}
{...props}

View File

@@ -23,13 +23,21 @@ interface MessageNavigationProps {
containerId: string
messages: MessageListItem[]
scrollToMessageId: (messageId: string) => void
scrollToTop: () => void
scrollToBottom: () => void
}
const getScrollContainer = (container: HTMLElement | null): HTMLElement | null => {
return container?.querySelector<HTMLElement>('[data-message-virtual-list-scroller]') ?? container
}
const MessageNavigation: FC<MessageNavigationProps> = ({ containerId, messages, scrollToMessageId }) => {
const MessageNavigation: FC<MessageNavigationProps> = ({
containerId,
messages,
scrollToMessageId,
scrollToTop,
scrollToBottom
}) => {
const { t } = useTranslation()
const [isVisible, setIsVisible] = useState(false)
const timerKey = 'hide'
@@ -73,16 +81,6 @@ const MessageNavigation: FC<MessageNavigationProps> = ({ containerId, messages,
scheduleHide(500)
}, [scheduleHide])
const scrollToTop = () => {
const scrollContainer = getScrollContainer(document.getElementById(containerId))
scrollContainer?.scrollTo({ top: 0, behavior: 'smooth' })
}
const scrollToBottom = () => {
const scrollContainer = getScrollContainer(document.getElementById(containerId))
scrollContainer?.scrollTo({ top: scrollContainer.scrollHeight, behavior: 'smooth' })
}
const getCurrentVisibleIndex = (direction: 'up' | 'down') => {
const userMessages = messages.filter((message) => message.role === 'user')
const assistantMessages = messages.filter((message) => message.role === 'assistant')

View File

@@ -0,0 +1,76 @@
// @vitest-environment jsdom
import '@testing-library/jest-dom/vitest'
import { render } from '@testing-library/react'
import type { ReactNode } from 'react'
import { describe, expect, it, vi } from 'vitest'
import type * as MessageTypes from '../../types'
import type { MessageListItem } from '../../types'
import MessageAnchorLine from '../MessageAnchorLine'
vi.mock('@cherrystudio/ui', () => ({
Avatar: ({ children, ...props }: { children?: ReactNode }) => <div {...props}>{children}</div>,
AvatarFallback: ({ children }: { children?: ReactNode }) => <span>{children}</span>,
AvatarImage: () => null,
EmojiAvatar: ({ children, ...props }: { children?: ReactNode }) => <span {...props}>{children}</span>
}))
vi.mock('@renderer/hooks/useTheme', () => ({
useTheme: () => ({ theme: 'light' })
}))
vi.mock('@renderer/hooks/useTimer', () => ({
useTimer: () => ({
setTimeoutTimer: (_key: string, callback: () => void) => callback()
})
}))
vi.mock('@renderer/utils/model', () => ({
getModelLogo: () => null
}))
vi.mock('@renderer/utils/naming', () => ({
firstLetter: (value: string) => value[0] ?? '',
isEmoji: () => false,
removeLeadingEmoji: (value: string) => value
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({ t: (key: string) => key })
}))
vi.mock('../../blocks', () => ({
usePartsMap: () => ({})
}))
vi.mock('../../MessageListProvider', async () => {
const { defaultMessageRenderConfig } = await vi.importActual<typeof MessageTypes>('../../types')
return {
useMessageListActions: () => ({}),
useMessageListMeta: () => ({}),
useMessageRenderConfig: () => defaultMessageRenderConfig
}
})
const messages: MessageListItem[] = [
{
id: 'message-1',
role: 'user',
topicId: 'topic-1',
parentId: null,
createdAt: '2026-07-02T00:00:00.000Z',
status: 'success'
}
]
describe('MessageAnchorLine', () => {
it('keeps the anchor rail scoped inside the message list layer', () => {
const { container } = render(<MessageAnchorLine messages={messages} />)
const anchorRail = container.firstElementChild
expect(anchorRail).toHaveClass('absolute', 'z-20')
expect(anchorRail).not.toHaveClass('fixed', 'z-999')
})
})

View File

@@ -50,6 +50,46 @@ const setRect = (element: Element, rect: Partial<DOMRect>) => {
}))
}
const renderNavigation = (messages: MessageListItem[], visibleMessageIds: string[] = []) => {
const scrollToMessageId = vi.fn()
const scrollToTop = vi.fn()
const scrollToBottom = vi.fn()
const { container } = render(
<>
<div id="messages">
<div data-message-virtual-list-scroller>
{messages.map((message) => (
<div key={message.id} id={`message-${message.id}`} />
))}
</div>
</div>
<MessageNavigation
containerId="messages"
messages={messages}
scrollToMessageId={scrollToMessageId}
scrollToTop={scrollToTop}
scrollToBottom={scrollToBottom}
/>
</>
)
setRect(container.querySelector('[data-message-virtual-list-scroller]') as HTMLElement, {
bottom: 500,
height: 500,
top: 0
})
for (const message of messages) {
setRect(document.getElementById(`message-${message.id}`) as HTMLElement, {
bottom: visibleMessageIds.includes(message.id) ? 220 : -100,
height: 100,
top: visibleMessageIds.includes(message.id) ? 120 : -200
})
}
return { scrollToBottom, scrollToMessageId, scrollToTop }
}
describe('MessageNavigation', () => {
it('scrolls to message ids from the full message list, not only rendered DOM nodes', () => {
const scrollToMessageId = vi.fn()
@@ -68,7 +108,13 @@ describe('MessageNavigation', () => {
<div id="message-user-2" />
</div>
</div>
<MessageNavigation containerId="messages" messages={messages} scrollToMessageId={scrollToMessageId} />
<MessageNavigation
containerId="messages"
messages={messages}
scrollToMessageId={scrollToMessageId}
scrollToTop={vi.fn()}
scrollToBottom={vi.fn()}
/>
</>
)
@@ -87,4 +133,78 @@ describe('MessageNavigation', () => {
expect(scrollToMessageId).toHaveBeenCalledWith('user-3')
})
it('delegates the top and bottom buttons to the runtime scroll callbacks', () => {
const scrollToTop = vi.fn()
const scrollToBottom = vi.fn()
render(
<>
<div id="messages">
<div data-message-virtual-list-scroller />
</div>
<MessageNavigation
containerId="messages"
messages={[createMessage('user-1', 'user')]}
scrollToMessageId={vi.fn()}
scrollToTop={scrollToTop}
scrollToBottom={scrollToBottom}
/>
</>
)
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.top' }))
expect(scrollToTop).toHaveBeenCalledTimes(1)
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.bottom' }))
expect(scrollToBottom).toHaveBeenCalledTimes(1)
})
it.each([
{
name: 'when there are no messages',
messages: []
},
{
name: 'when no message is visible',
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')]
},
{
name: 'when the first user message is already visible',
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')],
visibleMessageIds: ['user-1']
}
])('delegates next-message fallback to runtime scrollToBottom $name', ({ messages, visibleMessageIds }) => {
const { scrollToBottom, scrollToMessageId, scrollToTop } = renderNavigation(messages, visibleMessageIds)
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.next' }))
expect(scrollToBottom).toHaveBeenCalledTimes(1)
expect(scrollToTop).not.toHaveBeenCalled()
expect(scrollToMessageId).not.toHaveBeenCalled()
})
it.each([
{
name: 'when there are no messages',
messages: []
},
{
name: 'when no message is visible',
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')]
},
{
name: 'when the last user message is already visible',
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')],
visibleMessageIds: ['user-2']
}
])('delegates prev-message fallback to runtime scrollToTop $name', ({ messages, visibleMessageIds }) => {
const { scrollToBottom, scrollToMessageId, scrollToTop } = renderNavigation(messages, visibleMessageIds)
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.prev' }))
expect(scrollToTop).toHaveBeenCalledTimes(1)
expect(scrollToBottom).not.toHaveBeenCalled()
expect(scrollToMessageId).not.toHaveBeenCalled()
})
})

View File

@@ -443,6 +443,202 @@ describe('useChatVirtualizerRuntime', () => {
expect(runtime!.isScrollToBottomButtonVisible).toBe(false)
})
it('scrolls to top instantly and releases the top anchor', () => {
const callbacks: ResizeObserverCallback[] = []
const restoreResizeObserver = installResizeObserverMock(callbacks)
const raf = installQueuedAnimationFrame()
try {
let runtime: ChatVirtualizerRuntime<string> | undefined
let handle: MessageVirtualListHandle | null = null
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
handle = nextHandle
}
let scrollTop = 0
const view = render(
<RuntimeDomProbe
items={['message-a']}
handleRef={handleRef}
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
/>
)
const scroller = runtime!.scrollerRef.current!
Object.defineProperty(scroller, 'scrollTop', {
configurable: true,
get: () => scrollTop,
set: (value) => {
scrollTop = value
}
})
Object.defineProperty(scroller, 'scrollHeight', { configurable: true, get: () => 700 })
Object.defineProperty(scroller, 'clientHeight', { configurable: true, get: () => 400 })
runtime!.vlistHandleRef.current = createHandle({ getItemOffset: vi.fn(() => 300) })
view.rerender(
<RuntimeDomProbe
items={['message-a']}
handleRef={handleRef}
preserveScrollAnchor
scrollToTopKey="message-a"
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
/>
)
raf.tick()
expect(runtime!.wrappedItems.some((item) => item.kind === 'spacer')).toBe(true)
act(() => {
scrollTop = 300
handle!.scrollToTop('instant')
})
expect(scrollTop).toBe(0)
act(() => callbacks[0]?.([], {} as ResizeObserver))
expect(scrollTop).toBe(0)
} finally {
restoreResizeObserver()
raf.restore()
}
})
it('scrolls to top smoothly with the RAF-driven runtime scroller', () => {
const raf = installQueuedAnimationFrame()
try {
let runtime: ChatVirtualizerRuntime<string> | undefined
let handle: MessageVirtualListHandle | null = null
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
handle = nextHandle
}
let scrollTop = 500
render(
<RuntimeDomProbe
items={['message-a']}
handleRef={handleRef}
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
/>
)
const scroller = runtime!.scrollerRef.current!
Object.defineProperty(scroller, 'scrollTop', {
configurable: true,
get: () => scrollTop,
set: (value) => {
scrollTop = value
}
})
act(() => {
handle!.scrollToTop('smooth')
})
expect(scrollTop).toBe(500)
raf.tick()
expect(scrollTop).toBeGreaterThan(0)
expect(scrollTop).toBeLessThan(500)
raf.tick(50)
expect(scrollTop).toBe(0)
} finally {
raf.restore()
}
})
it('replaces an in-flight smooth scroll when scrolling to top', () => {
const raf = installQueuedAnimationFrame()
try {
let runtime: ChatVirtualizerRuntime<string> | undefined
let handle: MessageVirtualListHandle | null = null
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
handle = nextHandle
}
let scrollTop = 0
render(
<RuntimeDomProbe
items={['message-a']}
handleRef={handleRef}
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
/>
)
const scroller = runtime!.scrollerRef.current!
Object.defineProperty(scroller, 'scrollTop', {
configurable: true,
get: () => scrollTop,
set: (value) => {
scrollTop = value
}
})
setElementMetric(scroller, 'scrollHeight', () => 1200)
setElementMetric(scroller, 'clientHeight', () => 400)
act(() => {
handle!.scrollToBottom('smooth')
})
raf.tick()
const bottomScrollProgress = scrollTop
expect(bottomScrollProgress).toBeGreaterThan(0)
act(() => {
handle!.scrollToTop('smooth')
})
raf.tick()
expect(scrollTop).toBeLessThan(bottomScrollProgress)
raf.tick(50)
expect(scrollTop).toBe(0)
} finally {
raf.restore()
}
})
it('replaces an in-flight smooth scroll when scrolling to bottom', () => {
const raf = installQueuedAnimationFrame()
try {
let runtime: ChatVirtualizerRuntime<string> | undefined
let handle: MessageVirtualListHandle | null = null
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
handle = nextHandle
}
let scrollTop = 800
render(
<RuntimeDomProbe
items={['message-a']}
handleRef={handleRef}
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
/>
)
const scroller = runtime!.scrollerRef.current!
Object.defineProperty(scroller, 'scrollTop', {
configurable: true,
get: () => scrollTop,
set: (value) => {
scrollTop = value
}
})
setElementMetric(scroller, 'scrollHeight', () => 1200)
setElementMetric(scroller, 'clientHeight', () => 400)
act(() => {
handle!.scrollToTop('smooth')
})
raf.tick()
const topScrollProgress = scrollTop
expect(topScrollProgress).toBeLessThan(800)
act(() => {
handle!.scrollToBottom('smooth')
})
raf.tick()
expect(scrollTop).toBeGreaterThan(topScrollProgress)
raf.tick(50)
expect(scrollTop).toBe(800)
} finally {
raf.restore()
}
})
it('resets bottom-follow state when pinning a message to the viewport top', () => {
let runtime: ChatVirtualizerRuntime<string> | undefined
let handle: MessageVirtualListHandle | null = null
@@ -479,6 +675,46 @@ describe('useChatVirtualizerRuntime', () => {
expect(handle!.isAtBottom()).toBe(false)
})
it('does not pin a follow-up steered into a still-streaming turn', () => {
const callbacks: ResizeObserverCallback[] = []
const restoreResizeObserver = installResizeObserverMock(callbacks)
const raf = installQueuedAnimationFrame()
try {
let runtime: ChatVirtualizerRuntime<string> | undefined
// Render 1: a turn is already streaming (preserveScrollAnchor is true) and no
// new user-message key has arrived yet.
const view = render(
<RuntimeDomProbe items={['user-a']} preserveScrollAnchor onRuntime={(nextRuntime) => (runtime = nextRuntime)} />
)
const getSpacerHeight = () => runtime!.wrappedItems.find((item) => item.kind === 'spacer')?.height ?? 0
const scroller = runtime!.scrollerRef.current!
Object.defineProperty(scroller, 'scrollTop', { configurable: true, get: () => 0 })
Object.defineProperty(scroller, 'scrollHeight', { configurable: true, get: () => 900 + getSpacerHeight() })
Object.defineProperty(scroller, 'clientHeight', { configurable: true, get: () => 400 })
runtime!.vlistHandleRef.current = createHandle({ getItemOffset: vi.fn(() => 300) })
// Render 2: a queued follow-up is steered into the live turn — a new user
// message (`user-b`) arrives while streaming continues. Because a turn was
// already streaming just before it, the message must NOT pin to the top, so
// no anchor spacer is created (the pin path is what created the instability).
view.rerender(
<RuntimeDomProbe
items={['user-a', 'user-b']}
preserveScrollAnchor
scrollToTopKey="user-b"
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
/>
)
raf.tick()
expect(getSpacerHeight()).toBe(0)
} finally {
restoreResizeObserver()
raf.restore()
}
})
it('keeps bottom-follow suppressed while the user is still pinned to the top', () => {
let runtime: ChatVirtualizerRuntime<string> | undefined
let handle: MessageVirtualListHandle | null = null

View File

@@ -36,6 +36,7 @@ import { useSmoothScrollAnimation } from './useSmoothScrollAnimation'
export interface MessageVirtualListHandle {
scrollToBottom(behavior?: ScrollBehavior): void
scrollToTop(behavior?: ScrollBehavior): void
scrollToKey(key: string, align?: 'start' | 'center' | 'end'): void
isAtBottom(): boolean
getScrollElement(): HTMLElement | null
@@ -402,6 +403,11 @@ export function useChatVirtualizerRuntime<T>({
const lastScrollToTopKeyRef = useRef<string | undefined>(undefined)
const didMountForScrollKeyRef = useRef(false)
// The committed `preserveScrollAnchor` from the previous render — i.e. whether a
// turn was already streaming just before the current commit. Lets the pin effect
// tell a fresh idle→new-turn send from a mid-stream insertion. A trailing effect
// (below) keeps it in sync AFTER the pin effect has read the prior value.
const wasStreamingBeforeUserMessageRef = useRef(preserveScrollAnchor)
useEffect(() => {
const previous = lastScrollToTopKeyRef.current
@@ -411,6 +417,12 @@ export function useChatVirtualizerRuntime<T>({
return
}
if (!scrollToTopKey || scrollToTopKey === previous) return
// A new user message appeared. Only pin it to the top when it STARTS a fresh
// turn (the topic was idle just before it). If a turn was already streaming —
// a queued follow-up steered into the live turn — pinning the new message to
// the top would yank the view and fight the previous assistant's still-growing
// response (the instability we're fixing). Leave scroll to bottom-follow.
if (wasStreamingBeforeUserMessageRef.current) return
const idx = findDataIndexByKey(scrollToTopKey)
if (idx < 0) return
anchor.pinTo(idx)
@@ -420,6 +432,13 @@ export function useChatVirtualizerRuntime<T>({
userTookControlRef.current = false
}, [anchor, atBottom, findDataIndexByKey, scrollToTopKey])
// Sync the "was a turn already streaming" marker AFTER the pin effect above has
// read the previous render's value. Runs every commit so the next new-user-
// message commit sees whether streaming was in progress when it arrived.
useEffect(() => {
wasStreamingBeforeUserMessageRef.current = preserveScrollAnchor
})
// Initial scroll on mount is owned by `useScrollPositionMemory` above: it
// restores the saved anchor for this topic, or scrolls to the newest message
// when there is nothing to restore.
@@ -566,12 +585,10 @@ export function useChatVirtualizerRuntime<T>({
if (!el) return
const target = getRealBottom(el, anchor.spacerHeight)
if (behavior === 'smooth') {
if (!smoothScroll.isAnimating()) {
smoothScroll.scrollTo(() => {
const current = scrollerRef.current
return current ? getRealBottom(current, bottomFollowInsetRef.current) : 0
})
}
smoothScroll.scrollTo(() => {
const current = scrollerRef.current
return current ? getRealBottom(current, bottomFollowInsetRef.current) : 0
})
} else {
smoothScroll.cancel()
el.scrollTop = target
@@ -582,10 +599,31 @@ export function useChatVirtualizerRuntime<T>({
[anchor, atBottom, hideScrollToBottomButton, smoothScroll]
)
const scrollToTop = useCallback(
(behavior: ScrollBehavior = 'instant') => {
// Explicit scroll-to-top releases any anchor pin — the caller wants the
// absolute top of the loaded content, not the pinned user-message position.
anchor.release()
const el = scrollerRef.current
if (!el) return
if (behavior === 'smooth') {
// Drive the scroll frame-by-frame (RAF) rather than native
// `behavior: 'smooth'`: virtua remeasures items entering the viewport
// and compensates scrollTop, which cancels a native animation mid-flight.
smoothScroll.scrollTo(() => 0)
} else {
smoothScroll.cancel()
el.scrollTop = 0
}
},
[anchor, smoothScroll]
)
useImperativeHandle(
handleRef,
(): MessageVirtualListHandle => ({
scrollToBottom,
scrollToTop,
scrollToKey: (key, align = 'start') => {
const handle = vlistHandleRef.current
const idx = findDataIndexByKey(key)
@@ -596,7 +634,7 @@ export function useChatVirtualizerRuntime<T>({
isAtBottom: atBottom.isAtBottom,
getScrollElement: () => scrollerRef.current
}),
[anchor, atBottom.isAtBottom, findDataIndexByKey, scrollToBottom]
[anchor, atBottom.isAtBottom, findDataIndexByKey, scrollToBottom, scrollToTop]
)
return {

View File

@@ -43,9 +43,3 @@ export function isInlineFilePath(value: string): boolean {
WORKSPACE_RELATIVE_FILE_PATH_PATTERN.test(normalizedPath)
)
}
export function containsInlineFilePath(value: string | undefined): boolean {
if (!value) return false
return value.split(/\s+/).some((token) => isInlineFilePath(token))
}

View File

@@ -1,3 +1,4 @@
import { usePreference } from '@data/hooks/usePreference'
import { loggerService } from '@logger'
import type { ResolvedAction } from '@renderer/components/chat/actions/actionTypes'
import {
@@ -17,7 +18,7 @@ import { usePins } from '@renderer/hooks/usePins'
import { mapApiTopicToRendererTopic, useTopicMutations } from '@renderer/hooks/useTopic'
import type { Topic } from '@renderer/types/topic'
import { formatErrorMessageWithPrefix } from '@renderer/utils/error'
import { Bot, Edit3, PinIcon, PinOffIcon, Plus, Trash2 } from 'lucide-react'
import { Bot, Edit3, PinIcon, PinOffIcon, Plus, Tags, Trash2 } from 'lucide-react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -27,6 +28,7 @@ const logger = loggerService.withContext('AssistantResourceList')
const ASSISTANT_ENTITY_EDIT_ACTION_ID = 'assistant-entity.edit'
const ASSISTANT_ENTITY_TOGGLE_PIN_ACTION_ID = 'assistant-entity.toggle-pin'
const ASSISTANT_ENTITY_TOGGLE_TAG_GROUPING_ACTION_ID = 'assistant-entity.toggle-tag-grouping'
const ASSISTANT_ENTITY_DELETE_ACTION_ID = 'assistant-entity.delete'
type AssistantResourceListProps = {
@@ -52,6 +54,8 @@ export function AssistantResourceList({
onActiveAssistantDeleted
}: AssistantResourceListProps) {
const { t } = useTranslation()
const [assistantSortType, setAssistantSortType] = usePreference('assistant.tab.sort_type')
const isTagGrouping = assistantSortType === 'tags'
const {
assistants,
isLoading: isAssistantsLoading,
@@ -95,6 +99,7 @@ export function AssistantResourceList({
name: assistant.name,
orderKey: assistant.orderKey,
pinned: assistantPinnedIdSet.has(assistant.id),
tag: assistant.tags?.[0]?.name,
icon: assistant.emoji ? (
<EmojiIcon emoji={assistant.emoji} size={24} fontSize={14} className="mr-0" />
) : (
@@ -227,6 +232,15 @@ export function AssistantResourceList({
availability: { visible: true, enabled: !isAssistantPinActionDisabled },
children: []
},
{
id: ASSISTANT_ENTITY_TOGGLE_TAG_GROUPING_ACTION_ID,
label: isTagGrouping ? t('assistants.tags.ungroup') : t('assistants.tags.group_by'),
icon: <Tags size={14} />,
order: 25,
danger: false,
availability: { visible: true, enabled: true },
children: []
},
{
id: ASSISTANT_ENTITY_DELETE_ACTION_ID,
label: t('assistants.delete.title'),
@@ -239,7 +253,7 @@ export function AssistantResourceList({
}
]
},
[assistantPinnedIdSet, deletingAssistantId, isAssistantPinActionDisabled, t]
[assistantPinnedIdSet, deletingAssistantId, isAssistantPinActionDisabled, isTagGrouping, t]
)
const handleContextMenuAction = useCallback(
@@ -252,11 +266,15 @@ export function AssistantResourceList({
void handleToggleAssistantPin(item.id)
return
}
if (action.id === ASSISTANT_ENTITY_TOGGLE_TAG_GROUPING_ACTION_ID) {
void setAssistantSortType(isTagGrouping ? 'list' : 'tags')
return
}
if (action.id === ASSISTANT_ENTITY_DELETE_ACTION_ID) {
void handleDeleteAssistant(item.id)
}
},
[handleDeleteAssistant, handleToggleAssistantPin, openAssistantEditor]
[handleDeleteAssistant, handleToggleAssistantPin, isTagGrouping, openAssistantEditor, setAssistantSortType]
)
return (
@@ -268,6 +286,7 @@ export function AssistantResourceList({
status={listStatus}
ariaLabel={t('assistants.abbr')}
defaultGroupLabel={t('assistants.abbr')}
groupByTag={isTagGrouping}
addIcon={<Plus />}
addLabel={t('chat.add.assistant.title')}
onAdd={onAddAssistant ?? (() => onStartDraftAssistant(null))}

View File

@@ -26,6 +26,8 @@ export type ResourceEntityRailItem = {
* It does not affect visibility — an entity with no resources stays hidden whether pinned or not.
*/
pinned?: boolean
/** Single user tag name. Only consulted when the rail runs with `groupByTag`; undefined → "未分组". */
tag?: string
}
// Pinned entities float into a "已固定" section at the top; the rest sit under the "助手" / "智能体"
@@ -36,6 +38,27 @@ const ENTITY_RAIL_PINNED_SECTION_ID = 'resource-entity-rail:section:pinned'
const ENTITY_RAIL_DEFAULT_SECTION_ID = 'resource-entity-rail:section:default'
const ENTITY_RAIL_PINNED_GROUP_ID = 'resource-entity-rail:group:pinned'
const ENTITY_RAIL_DEFAULT_GROUP_ID = 'resource-entity-rail:group:default'
// When `groupByTag` is on, each tag name becomes its own collapsible section below the pinned one;
// untagged entities collapse together under a distinct internal bucket.
const ENTITY_RAIL_TAG_SECTION_PREFIX = 'resource-entity-rail:section:'
const ENTITY_RAIL_TAG_GROUP_PREFIX = 'resource-entity-rail:group:'
const ENTITY_RAIL_UNTAGGED_KEY = JSON.stringify(['untagged'])
function getEntityRailTagBucketKey(tag: string | undefined) {
return tag ? JSON.stringify(['tag', tag]) : ENTITY_RAIL_UNTAGGED_KEY
}
function getEntityRailTagGroupingRank(item: ResourceEntityRailItem) {
if (item.pinned) return 0
return item.tag ? 2 : 1
}
function sortEntityRailItemsForTagGrouping<T extends ResourceEntityRailItem>(items: readonly T[]): T[] {
return items
.map((item, index) => ({ item, index, rank: getEntityRailTagGroupingRank(item) }))
.sort((a, b) => a.rank - b.rank || a.index - b.index)
.map(({ item }) => item)
}
export type ResourceEntityRailProps<T extends ResourceEntityRailItem, TActionContext = unknown> = {
addIcon?: ReactNode
@@ -43,6 +66,12 @@ export type ResourceEntityRailProps<T extends ResourceEntityRailItem, TActionCon
ariaLabel: string
/** Header for the non-pinned group ("助手" for assistants, "智能体" for agents). */
defaultGroupLabel?: string
/**
* Group the non-pinned entities by their `tag` into collapsible sections (the pinned section stays
* on top). Drag-reorder is disabled while on, since `orderKey` is a single flat order. Off → the
* flat "助手"/"智能体" section.
*/
groupByTag?: boolean
emptyFallback?: ReactNode
getContextMenuActions?: (item: T) => readonly ResolvedAction<TActionContext>[]
listRef?: RefObject<HTMLDivElement | null>
@@ -82,6 +111,7 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
addLabel,
ariaLabel,
defaultGroupLabel,
groupByTag = false,
emptyFallback,
getContextMenuActions,
listRef,
@@ -96,6 +126,9 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
items
}: ResourceEntityRailProps<T, TActionContext>) {
const { t } = useTranslation()
// Tag grouping splits the flat order across sections, so dragging an item between tags would have
// no meaningful `orderKey` target — disable reorder entirely while grouping by tag.
const reorderEnabled = !!onReorder && !groupByTag
const fallbackListRef = useRef<HTMLDivElement>(null)
const effectiveListRef = listRef ?? fallbackListRef
const runContextMenuAction = useCallback(
@@ -173,23 +206,39 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
[getContextMenuActions, onContextMenuAction, onSelect, runContextMenuAction, t]
)
const empty = useMemo(() => emptyFallback ?? <div className="min-h-0 flex-1" />, [emptyFallback])
const providerItems = useMemo(
() => (groupByTag ? sortEntityRailItemsForTagGrouping(items) : items),
[groupByTag, items]
)
// Collapsible sections matching the modern layout's left assistant/agent layout (minus the nested
// topics/sessions): pinned entities float into "已固定" at the top, the rest sit under the
// "助手" / "智能体" section below. Section headers stay flush-left; the entity rows keep their
// avatar and read as indented beneath. The single-section case (nothing pinned) renders the flat
// list with no header, exactly like the modern layout.
const sectionBy = useMemo<(item: T) => ResourceListSection>(
() => (item) =>
item.pinned
? { id: ENTITY_RAIL_PINNED_SECTION_ID, label: t('selector.common.pinned_title') }
: { id: ENTITY_RAIL_DEFAULT_SECTION_ID, label: defaultGroupLabel ?? '' },
[defaultGroupLabel, t]
() => (item) => {
if (item.pinned) return { id: ENTITY_RAIL_PINNED_SECTION_ID, label: t('selector.common.pinned_title') }
if (groupByTag) {
const tagBucketKey = getEntityRailTagBucketKey(item.tag)
return item.tag
? { id: `${ENTITY_RAIL_TAG_SECTION_PREFIX}${tagBucketKey}`, label: item.tag }
: { id: `${ENTITY_RAIL_TAG_SECTION_PREFIX}${tagBucketKey}`, label: t('assistants.tags.untagged') }
}
return { id: ENTITY_RAIL_DEFAULT_SECTION_ID, label: defaultGroupLabel ?? '' }
},
[defaultGroupLabel, groupByTag, t]
)
// Header-less groups (one per section, distinct ids) keep entity avatars visible and stop
// drag-reorder from crossing the pinned/non-pinned boundary.
// drag-reorder from crossing the pinned/non-pinned (or per-tag) boundary.
const groupBy = useMemo<(item: T) => ResourceListGroup>(
() => (item) => ({ id: item.pinned ? ENTITY_RAIL_PINNED_GROUP_ID : ENTITY_RAIL_DEFAULT_GROUP_ID, label: '' }),
[]
() => (item) => {
if (item.pinned) return { id: ENTITY_RAIL_PINNED_GROUP_ID, label: '' }
if (groupByTag) {
return { id: `${ENTITY_RAIL_TAG_GROUP_PREFIX}${getEntityRailTagBucketKey(item.tag)}`, label: '' }
}
return { id: ENTITY_RAIL_DEFAULT_GROUP_ID, label: '' }
},
[groupByTag]
)
// Alias the compound provider to a local before rendering — same pattern as TopicResourceList/SessionResourceList.
@@ -200,7 +249,7 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
return (
<Provider
variant={variant}
items={items}
items={providerItems}
selectedId={selectedId}
status={status}
groupBy={groupBy}
@@ -208,15 +257,15 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
defaultGroupVisibleCount={Number.POSITIVE_INFINITY}
dragCapabilities={{
groups: false,
items: !!onReorder,
itemSameGroup: !!onReorder,
items: reorderEnabled,
itemSameGroup: reorderEnabled,
itemCrossGroup: false
}}
canDragItem={({ item }) => !!onReorder && !item.pinned}
canDragItem={({ item }) => reorderEnabled && !item.pinned}
canDropItem={({ activeItem, targetGroupId }) =>
!!onReorder && !activeItem.pinned && targetGroupId !== ENTITY_RAIL_PINNED_GROUP_ID
reorderEnabled && !activeItem.pinned && targetGroupId !== ENTITY_RAIL_PINNED_GROUP_ID
}
onReorder={onReorder}>
onReorder={reorderEnabled ? onReorder : undefined}>
<ResourceList.Frame className="h-full min-h-0" data-testid={`${variant}-entity-rail`}>
<ResourceList.Header className="gap-1">
<ResourceList.HeaderItem
@@ -241,7 +290,7 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
</ResourceList.Header>
<ResourceList.Body<T>
listRef={effectiveListRef}
draggable={!!onReorder}
draggable={reorderEnabled}
ariaLabel={ariaLabel}
virtualClassName="pt-1 pb-3"
errorFallback={<ResourceList.ErrorState message={t('error.boundary.default.message')} />}

View File

@@ -17,12 +17,21 @@ const agentDataMocks = vi.hoisted(() => ({
refetchAgents: vi.fn()
}))
const preferenceMocks = vi.hoisted(() => ({
sortType: 'list' as 'list' | 'tags',
setSortType: vi.fn()
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key
})
}))
vi.mock('@data/hooks/usePreference', () => ({
usePreference: () => [preferenceMocks.sortType, preferenceMocks.setSortType]
}))
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
@@ -192,6 +201,8 @@ vi.mock('@renderer/utils/error', () => ({
describe('classic layout entity resource list actions', () => {
beforeEach(() => {
preferenceMocks.sortType = 'list'
preferenceMocks.setSortType.mockClear()
assistantDataMocks.deleteAssistant.mockResolvedValue(undefined)
assistantDataMocks.refreshTopics.mockResolvedValue(undefined)
assistantDataMocks.refetchAssistants.mockResolvedValue(undefined)
@@ -239,6 +250,33 @@ describe('classic layout entity resource list actions', () => {
expect(onStartDraftAssistant).not.toHaveBeenCalled()
})
it('toggles assistant tag grouping from the context menu (list → tags)', () => {
render(
<AssistantResourceList activeAssistantId="assistant-1" onSelectTopic={vi.fn()} onStartDraftAssistant={vi.fn()} />
)
// sort_type === 'list' → the menu offers "group by tag".
const menu = screen.getByTestId('assistant-1-context-menu')
expect(menu).toHaveTextContent('assistants.tags.group_by')
expect(menu).not.toHaveTextContent('assistants.tags.ungroup')
fireEvent.click(screen.getAllByRole('button', { name: 'assistants.tags.group_by' })[0])
expect(preferenceMocks.setSortType).toHaveBeenCalledWith('tags')
})
it('offers turning tag grouping off when already grouping (tags → list)', () => {
preferenceMocks.sortType = 'tags'
render(
<AssistantResourceList activeAssistantId="assistant-1" onSelectTopic={vi.fn()} onStartDraftAssistant={vi.fn()} />
)
expect(screen.getByTestId('assistant-1-context-menu')).toHaveTextContent('assistants.tags.ungroup')
fireEvent.click(screen.getAllByRole('button', { name: 'assistants.tags.ungroup' })[0])
expect(preferenceMocks.setSortType).toHaveBeenCalledWith('list')
})
it('uses delete-agent actions for the classic layout agent context and more menus', async () => {
const onStartMissingAgentDraft = vi.fn()
const onActiveAgentDeleted = vi.fn()

View File

@@ -65,7 +65,8 @@ vi.mock('@renderer/components/VirtualList', () => {
return rows
}
const GroupedVirtualList = ({
const GroupedVirtualListContent = ({
dragEnabled,
ref,
className,
groups,
@@ -91,6 +92,7 @@ vi.mock('@renderer/components/VirtualList', () => {
}}
role={role}
className={className}
data-draggable={dragEnabled ? 'true' : 'false'}
{...scrollerProps}>
{rows.map((row, index) => {
if (row.type === 'group-header') {
@@ -114,8 +116,8 @@ vi.mock('@renderer/components/VirtualList', () => {
return {
buildGroupedVirtualRows,
DynamicVirtualList: () => null,
GroupedSortableVirtualList: GroupedVirtualList,
GroupedVirtualList
GroupedSortableVirtualList: (props) => <GroupedVirtualListContent {...props} dragEnabled />,
GroupedVirtualList: (props) => <GroupedVirtualListContent {...props} dragEnabled={false} />
}
})
@@ -261,6 +263,88 @@ describe('ResourceEntityRail', () => {
expect(screen.getByTestId('assistant-a-icon')).toBeInTheDocument()
})
it('groups non-pinned entities into per-tag sections while keeping pinned on top', () => {
render(
<ResourceEntityRail
addLabel="New"
ariaLabel="Assistants list"
defaultGroupLabel="Assistants"
groupByTag
items={[
{ id: 'pinned-tagged', name: 'Pinned Tagged', icon: <span />, pinned: true, tag: 'work' },
{ id: 'work-a', name: 'Work A', icon: <span data-testid="work-a-icon" />, tag: 'work' },
{ id: 'home-a', name: 'Home A', icon: <span />, tag: 'home' },
{ id: 'loose', name: 'Loose', icon: <span />, tag: undefined }
]}
variant="assistant"
onAdd={vi.fn()}
onReorder={vi.fn()}
onSelect={vi.fn()}
/>
)
// Pinned section stays on top; non-pinned entities split into tag sections + an untagged section.
expect(screen.getByText('selector.common.pinned_title')).toBeInTheDocument()
expect(screen.getByText('work')).toBeInTheDocument()
expect(screen.getByText('home')).toBeInTheDocument()
expect(screen.getByText('assistants.tags.untagged')).toBeInTheDocument()
expect(
Array.from(
screen.getByRole('listbox', { name: 'Assistants list' }).querySelectorAll('button[aria-expanded]')
).map((header) => header.textContent)
).toEqual(['selector.common.pinned_title', 'assistants.tags.untagged', 'work', 'home'])
// A pinned entity stays under the pinned section even though it carries a tag — its tag must not
// spawn a second "work" header.
expect(screen.getAllByText('work')).toHaveLength(1)
// The flat default "Assistants" header never appears while grouping by tag.
expect(screen.queryByText('Assistants')).not.toBeInTheDocument()
expect(screen.getByTestId('work-a-icon')).toBeInTheDocument()
expect(screen.getByRole('listbox', { name: 'Assistants list' })).toHaveAttribute('data-draggable', 'false')
})
it('keeps a real tag named like the untagged sentinel separate from untagged entities', () => {
render(
<ResourceEntityRail
addLabel="New"
ariaLabel="Assistants list"
defaultGroupLabel="Assistants"
groupByTag
items={[
{ id: 'sentinel-tagged', name: 'Sentinel Tagged', icon: <span />, tag: '__untagged__' },
{ id: 'loose', name: 'Loose', icon: <span />, tag: undefined }
]}
variant="assistant"
onAdd={vi.fn()}
onSelect={vi.fn()}
/>
)
expect(screen.getByText('__untagged__')).toBeInTheDocument()
expect(screen.getByText('assistants.tags.untagged')).toBeInTheDocument()
})
it('ignores entity tags when groupByTag is off', () => {
render(
<ResourceEntityRail
addLabel="New"
ariaLabel="Assistants list"
defaultGroupLabel="Assistants"
items={[
{ id: 'work-a', name: 'Work A', icon: <span />, tag: 'work' },
{ id: 'home-a', name: 'Home A', icon: <span />, tag: 'home' }
]}
variant="assistant"
onAdd={vi.fn()}
onReorder={vi.fn()}
onSelect={vi.fn()}
/>
)
expect(screen.queryByText('work')).not.toBeInTheDocument()
expect(screen.queryByText('home')).not.toBeInTheDocument()
expect(screen.queryByText('assistants.tags.untagged')).not.toBeInTheDocument()
})
it('renders a flat list with no section header when nothing is pinned', () => {
render(
<ResourceEntityRail

View File

@@ -75,7 +75,7 @@ import {
ComposerToolMenuControls
} from './shared/ComposerControlScaffolding'
import { type AddNewTopicPayload, emptyActions, type ProviderActionHandlers } from './shared/composerProviderActions'
import { buildComposerQueuedPayload } from './shared/composerQueuedPayload'
import { buildComposerQueuedPayload, hasUnsyncedComposerAttachments } from './shared/composerQueuedPayload'
import { useComposerQuoteInsertion } from './shared/composerQuote'
import { useComposerFileCapabilities } from './shared/useComposerFileCapabilities'
import { useLatest } from './shared/useLatest'
@@ -798,6 +798,8 @@ const ChatComposerInner = ({
async (draft: ComposerSerializedDraft) => {
const tokenIds = getComposerTokenIds(draft.tokens)
const payloadFiles = files.filter((file) => tokenIds.has(chatComposerTokenId.file(file)))
if (hasUnsyncedComposerAttachments(files, payloadFiles)) return null
const originalFilePartsByTokenId = editingOriginalFilePartsByTokenIdRef.current
const newFiles = payloadFiles.filter((file) => !originalFilePartsByTokenId.has(chatComposerTokenId.file(file)))
@@ -840,6 +842,8 @@ const ChatComposerInner = ({
}
const editedParts = await buildEditedMessageParts(draft)
if (!editedParts) return
try {
await chatWrite.forkAndResend(editingMessageForCurrentTopic.message.id, editedParts)
restoreSavedDraft()

View File

@@ -1256,6 +1256,46 @@ describe('AgentComposer', () => {
expect(mocks.setFiles).toHaveBeenLastCalledWith([])
})
it('does not send while only some attached file tokens are reflected in the editor', async () => {
const secondFile = {
id: 'file-2',
fileTokenSourceId: 'source-file-2',
name: 'summary.md',
origin_name: 'summary.md',
path: '/tmp/summary.md'
} as FileMetadata
mocks.files = [file, secondFile]
mocks.draftTokens = [
{
id: `file:${file.fileTokenSourceId}`,
kind: 'file',
label: file.name,
payload: file,
index: 0,
textOffset: mocks.draftText.length
} as ComposerSerializedToken
]
render(
<AgentComposer
agentId="agent-1"
sessionId="session-1"
sendMessage={mocks.sendMessage}
stop={mocks.stop}
isStreaming={false}
/>
)
await act(async () => {
fireEvent.click(screen.getByText('send'))
await new Promise((resolve) => setTimeout(resolve, 0))
})
expect(mocks.sendMessage).not.toHaveBeenCalled()
expect(mocks.setFiles).not.toHaveBeenCalledWith([])
expect(mocks.files).toEqual([file, secondFile])
})
it('blocks sends while the parent session is switching', () => {
render(
<AgentComposer
@@ -1340,13 +1380,20 @@ describe('AgentComposer', () => {
it('restores the current draft, files, and skill tokens when sending a new agent message fails', async () => {
mocks.availableSkills = [pdfSkill]
mocks.draftText = 'draft message'
mocks.draftTokens = [
{
...pdfSkillToken,
index: 0,
textOffset: 0
}
]
const skillToken = {
...pdfSkillToken,
index: 0,
textOffset: 0
}
const fileToken = {
id: `file:${file.fileTokenSourceId}`,
kind: 'file',
label: file.name,
payload: file,
index: 1,
textOffset: mocks.draftText.length
} as ComposerSerializedToken
mocks.draftTokens = [skillToken, fileToken]
mocks.files = [file]
mocks.sendMessage.mockRejectedValueOnce(new Error('send failed'))
@@ -1365,7 +1412,7 @@ describe('AgentComposer', () => {
})
await waitFor(() => {
expect(mocks.surfaceProps?.draftTokens).toEqual(mocks.draftTokens)
expect(mocks.surfaceProps?.draftTokens).toEqual([skillToken])
})
fireEvent.click(screen.getByText('send'))
@@ -1378,24 +1425,12 @@ describe('AgentComposer', () => {
expect(mocks.setFiles).toHaveBeenCalledWith([])
expect(mocks.setFiles).toHaveBeenLastCalledWith([file])
expect(mocks.surfaceProps?.text).toBe('draft message')
expect(mocks.surfaceProps?.draftTokens).toEqual([
{
...pdfSkillToken,
index: 0,
textOffset: 0
}
])
expect(mocks.surfaceProps?.draftTokens).toEqual([skillToken])
expect(cacheService.setCasual).toHaveBeenLastCalledWith(
'agent-session-draft-agent-1',
{
text: 'draft message',
tokens: [
{
...pdfSkillToken,
index: 0,
textOffset: 0
}
]
tokens: [skillToken]
},
86400000
)

View File

@@ -1097,6 +1097,64 @@ describe('ChatComposer', () => {
expect(mocks.surfaceProps?.sendDisabled).toBe(false)
})
it('does not submit a file-only draft before the file token is reflected in the editor', async () => {
mocks.files = [{ fileTokenSourceId: 'src-1', name: 'doc.pdf', path: '/tmp/doc.pdf' } as any]
const onSend = vi.fn().mockResolvedValue(undefined)
render(<ChatComposer topic={topic} onSend={onSend} />)
await act(async () => {
await mocks.surfaceProps?.onSendDraft({ text: '', tokens: [] })
})
expect(onSend).not.toHaveBeenCalled()
expect(mocks.toastError).not.toHaveBeenCalledWith('chat.input.send_failed')
})
it('does not submit a text draft before a newly attached file token is reflected in the editor', async () => {
mocks.files = [{ fileTokenSourceId: 'src-1', name: 'doc.pdf', path: '/tmp/doc.pdf' } as any]
const onSend = vi.fn().mockResolvedValue(undefined)
render(<ChatComposer topic={topic} onSend={onSend} />)
await act(async () => {
await mocks.surfaceProps?.onSendDraft({ text: 'summarize this', tokens: [] })
})
expect(onSend).not.toHaveBeenCalled()
expect(mocks.files).toHaveLength(1)
expect(mocks.toastError).not.toHaveBeenCalledWith('chat.input.send_failed')
})
it('does not submit a text draft while only some attached file tokens are reflected in the editor', async () => {
const syncedFile = { fileTokenSourceId: 'src-1', name: 'first.pdf', path: '/tmp/first.pdf' } as any
const unsyncedFile = { fileTokenSourceId: 'src-2', name: 'second.pdf', path: '/tmp/second.pdf' } as any
mocks.files = [syncedFile, unsyncedFile]
const onSend = vi.fn().mockResolvedValue(undefined)
render(<ChatComposer topic={topic} onSend={onSend} />)
await act(async () => {
await mocks.surfaceProps?.onSendDraft({
text: 'summarize these',
tokens: [
{
id: 'file:src-1',
kind: 'file',
label: 'first.pdf',
payload: syncedFile,
index: 0,
textOffset: 0
} as ComposerSerializedToken
]
})
})
expect(onSend).not.toHaveBeenCalled()
expect(mocks.files).toEqual([syncedFile, unsyncedFile])
expect(mocks.toastError).not.toHaveBeenCalledWith('chat.input.send_failed')
})
it('keeps a steered follow-up in the dock and toasts when its manual send fails', async () => {
mocks.topicPending = true
const onSend = vi.fn().mockResolvedValue(undefined)
@@ -2293,6 +2351,98 @@ describe('ChatComposer', () => {
await waitFor(() => expect(mocks.surfaceProps?.editingState).toBeUndefined())
})
it('does not fork and resend an edited file-only draft before the file token is reflected in the editor', async () => {
const editMessage = vi.fn().mockResolvedValue(undefined)
const resend = vi.fn().mockResolvedValue(undefined)
const forkAndResend = vi.fn().mockResolvedValue(undefined)
mocks.chatWrite = { pause: vi.fn(), editMessage, resend, forkAndResend }
const message = {
id: 'message-1',
role: 'user',
topicId: topic.id,
createdAt: '2026-01-01T00:00:00.000Z',
status: 'success'
} as const
render(
<MessageEditingProvider>
<StartEditingOnMount message={message as any} parts={[{ type: 'text', text: 'old' }] as any} />
<ChatComposer topic={topic} onSend={vi.fn()} />
</MessageEditingProvider>
)
await waitFor(() => expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1'))
act(() => {
mocks.files = [{ fileTokenSourceId: 'src-1', name: 'doc.pdf', path: '/tmp/doc.pdf' } as any]
mocks.surfaceProps?.onTextChange('')
})
await waitFor(() => expect(mocks.surfaceProps?.text).toBe(''))
await act(async () => {
await mocks.surfaceProps?.onSendDraft({ text: '', tokens: [] })
})
expect(forkAndResend).not.toHaveBeenCalled()
expect(editMessage).not.toHaveBeenCalled()
expect(resend).not.toHaveBeenCalled()
expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1')
expect(mocks.toastError).not.toHaveBeenCalledWith('message.error.operation_unavailable')
})
it('does not fork and resend an edited draft while only some attached file tokens are reflected in the editor', async () => {
const editMessage = vi.fn().mockResolvedValue(undefined)
const resend = vi.fn().mockResolvedValue(undefined)
const forkAndResend = vi.fn().mockResolvedValue(undefined)
mocks.chatWrite = { pause: vi.fn(), editMessage, resend, forkAndResend }
const message = {
id: 'message-1',
role: 'user',
topicId: topic.id,
createdAt: '2026-01-01T00:00:00.000Z',
status: 'success'
} as const
const syncedFile = { fileTokenSourceId: 'src-1', name: 'first.pdf', path: '/tmp/first.pdf' } as any
const unsyncedFile = { fileTokenSourceId: 'src-2', name: 'second.pdf', path: '/tmp/second.pdf' } as any
render(
<MessageEditingProvider>
<StartEditingOnMount message={message as any} parts={[{ type: 'text', text: 'old' }] as any} />
<ChatComposer topic={topic} onSend={vi.fn()} />
</MessageEditingProvider>
)
await waitFor(() => expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1'))
act(() => {
mocks.files = [syncedFile, unsyncedFile]
mocks.surfaceProps?.onTextChange('')
})
await waitFor(() => expect(mocks.surfaceProps?.text).toBe(''))
await act(async () => {
await mocks.surfaceProps?.onSendDraft({
text: '',
tokens: [
{
id: 'file:src-1',
kind: 'file',
label: 'first.pdf',
payload: syncedFile,
index: 0,
textOffset: 0
} as ComposerSerializedToken
]
})
})
expect(forkAndResend).not.toHaveBeenCalled()
expect(editMessage).not.toHaveBeenCalled()
expect(resend).not.toHaveBeenCalled()
expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1')
expect(mocks.toastError).not.toHaveBeenCalledWith('message.error.operation_unavailable')
})
it('keeps editing when the edited message fork and resend fails', async () => {
const editMessage = vi.fn().mockResolvedValue(undefined)
const resend = vi.fn().mockResolvedValue(undefined)

View File

@@ -39,17 +39,42 @@ describe('buildComposerQueuedPayload', () => {
expect(result?.userMessageParts).toEqual([{ type: 'text', text: '' }])
})
it('attaches only files still present as draft tokens', () => {
const kept = file('a')
const removed = file('b')
it('returns null for a file-only draft whose file token has not reached the editor draft yet', () => {
const result = buildComposerQueuedPayload(draft('', []), { files: [file('a')], fileTokenId })
expect(result).toBeNull()
})
it('returns null for a text draft whose file token has not reached the editor draft yet', () => {
const result = buildComposerQueuedPayload(draft('summarize this', []), { files: [file('a')], fileTokenId })
expect(result).toBeNull()
})
it('returns null when only some current file tokens have reached the editor draft', () => {
const synced = file('a')
const unsynced = file('b')
const result = buildComposerQueuedPayload(draft('hi', ['file:a']), {
files: [kept, removed],
files: [synced, unsynced],
fileTokenId,
requireText: true
})
expect(result?.attachments).toEqual([kept])
expect(result).toBeNull()
})
it('attaches files when every current file is present as a draft token', () => {
const first = file('a')
const second = file('b')
const result = buildComposerQueuedPayload(draft('hi', ['file:a', 'file:b']), {
files: [first, second],
fileTokenId,
requireText: true
})
expect(result?.attachments).toEqual([first, second])
expect(result?.userMessageParts).toEqual([{ type: 'text', text: 'hi' }])
})

View File

@@ -31,10 +31,11 @@ export function buildComposerQueuedPayload(
{ files, fileTokenId, requireText = false, extra }: BuildComposerQueuedPayloadOptions
): ComposerQueuedMessagePayload | null {
const text = draft.text.trim()
if (requireText ? !text : !text && files.length === 0) return null
const tokenIds = getComposerTokenIds(draft.tokens)
const attachedFiles = files.filter((file) => tokenIds.has(fileTokenId(file)))
if (hasUnsyncedComposerAttachments(files, attachedFiles)) return null
if (requireText ? !text : !text && attachedFiles.length === 0) return null
const userMessageParts = createComposerUserMessageParts(draft)
return {
@@ -44,3 +45,7 @@ export function buildComposerQueuedPayload(
...extra?.(tokenIds, attachedFiles)
}
}
export function hasUnsyncedComposerAttachments(files: ComposerAttachment[], attachedFiles: ComposerAttachment[]) {
return attachedFiles.length !== files.length
}

View File

@@ -27,8 +27,8 @@ const AssistantEditDialog = lazy(() =>
/**
* Row shape the selector operates on — derived from the Assistant DTO. `selectionType: 'item'`
* returns values of this shape (not the raw Assistant) so the selector never leaks DB columns the
* caller didn't ask about. User tag names may be present so the selector can filter by assistant
* tags.
* caller didn't ask about. A user tag name may be present so the selector can filter by assistant
* tag.
*/
export type AssistantSelectorItem = ResourceSelectorShellItem
@@ -118,7 +118,7 @@ export function AssistantSelector(props: AssistantSelectorProps) {
name: a.name,
emoji: a.emoji,
description: a.description,
tags: (a.tags ?? []).map((tag) => tag.name)
tag: a.tags?.[0]?.name
})),
[data]
)
@@ -126,7 +126,8 @@ export function AssistantSelector(props: AssistantSelectorProps) {
const tags = useMemo<ResourceSelectorShellTag[]>(() => {
const byName = new Map<string, string | undefined>()
for (const assistant of data?.items ?? []) {
for (const tag of assistant.tags ?? []) {
const tag = assistant.tags?.[0]
if (tag) {
if (!byName.has(tag.name)) {
byName.set(tag.name, tag.color ?? undefined)
}
@@ -200,7 +201,7 @@ export function AssistantSelector(props: AssistantSelectorProps) {
name: created.name,
emoji: created.emoji,
description: created.description,
tags: (created.tags ?? []).map((tag) => tag.name)
tag: created.tags?.[0]?.name
})
} else {
props.onChange(created.id)

View File

@@ -30,7 +30,7 @@ export type ResourceSelectorShellItem = {
name: string
emoji?: string
description?: string
tags?: string[]
tag?: string
disabled?: boolean
}
@@ -273,7 +273,7 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
)
const [searchValue, setSearchValue] = useState('')
const [selectedTagIds, setSelectedTagIds] = useState<string[]>([])
const [selectedTagName, setSelectedTagName] = useState<string | null>(null)
const listboxId = useId()
const listRef = useRef<HTMLDivElement>(null)
const searchInputRef = useRef<HTMLInputElement>(null)
@@ -320,9 +320,8 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
const { pinnedItems, unpinnedItems } = useMemo(() => {
let filtered = items
if (selectedTagIds.length > 0) {
const wanted = new Set(selectedTagIds)
filtered = filtered.filter((item) => item.tags?.some((tag) => wanted.has(tag)))
if (selectedTagName) {
filtered = filtered.filter((item) => item.tag === selectedTagName)
}
const query = searchValue.trim().toLowerCase()
@@ -338,7 +337,7 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
const unpinned = filtered.filter((item) => !pinnedSet.has(item.id))
const pinnedOrdered = pinnedIds.map((id) => pinned.find((item) => item.id === id)).filter(Boolean) as T[]
return { pinnedItems: pinnedOrdered, unpinnedItems: unpinned }
}, [items, pinnedIds, pinnedSet, searchValue, selectedTagIds])
}, [items, pinnedIds, pinnedSet, searchValue, selectedTagName])
const sections = useMemo<ResourceSelectorSection<T>[]>(() => {
const nextSections: ResourceSelectorSection<T>[] = []
@@ -569,18 +568,14 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
<>
{labels.tagFilter ? <span className="mr-1 text-[10px] text-muted-foreground">{labels.tagFilter}</span> : null}
{tagOptions.map((tag) => {
const active = selectedTagIds.includes(tag.name)
const active = selectedTagName === tag.name
return (
<ResourceTagChip
key={tag.name}
tag={tag.name}
color={tag.color}
active={active}
onClick={() =>
setSelectedTagIds((prev) =>
prev.includes(tag.name) ? prev.filter((value) => value !== tag.name) : [...prev, tag.name]
)
}
onClick={() => setSelectedTagName((prev) => (prev === tag.name ? null : tag.name))}
/>
)
})}
@@ -607,16 +602,13 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
<span className="flex size-5 shrink-0 items-center justify-center">{fallbackIcon}</span>
) : null
const trailing =
item.tags && item.tags.length > 0 ? (
<div
className="ml-2 flex h-4 max-w-[48%] shrink-0 items-center justify-end gap-1 overflow-hidden"
data-resource-selector-tags={item.id}>
{item.tags.map((tag) => (
<ResourceTagChip key={`${item.id}-${tag}`} tag={tag} color={tagColorByName.get(tag)} />
))}
</div>
) : null
const trailing = item.tag ? (
<div
className="ml-2 flex h-4 max-w-[48%] shrink-0 items-center justify-end gap-1 overflow-hidden"
data-resource-selector-tags={item.id}>
<ResourceTagChip tag={item.tag} color={tagColorByName.get(item.tag)} />
</div>
) : null
return (
<div key={item.id} className="py-0.5">

View File

@@ -92,7 +92,7 @@ vi.mock('react-i18next', async (importOriginal) => {
'library.config.basic.model_pick': 'Pick model',
'library.config.basic.model_not_found': 'Model {{id}} is unavailable.',
'library.config.basic.tag_empty': 'No tags',
'library.config.basic.tag_placeholder': 'Select tags',
'library.config.basic.tag_placeholder': 'Select tag',
'library.config.basic.tag_search': 'Search tags',
'library.config.prompt.label': 'Prompt',
'library.config.prompt.placeholder': 'Tell this assistant how to respond',

View File

@@ -636,7 +636,7 @@ describe('ResourceSelectorShell', () => {
describe('edit button', () => {
it('places edit and pin together in the row action area', () => {
const taggedItems: Item[] = [{ ...ITEMS[0], tags: ['Cherry', 'DEV'] }, ...ITEMS.slice(1)]
const taggedItems: Item[] = [{ ...ITEMS[0], tag: 'Cherry' }, ...ITEMS.slice(1)]
render(
<ResourceSelectorShell
@@ -891,5 +891,35 @@ describe('ResourceSelectorShell', () => {
expect(screen.queryByRole('option', { name: /Alpha/ })).toBeInTheDocument()
expect(screen.queryByRole('option', { name: /Beta/ })).not.toBeInTheDocument()
})
it('uses a single active tag filter at a time', () => {
render(
<ResourceSelectorShell
trigger={<button type="button">Open</button>}
items={[
{ ...ITEMS[0], tag: 'Cherry' },
{ ...ITEMS[1], tag: 'DEV' },
{ ...ITEMS[2], tag: 'Cherry' }
]}
tags={['Cherry', 'DEV']}
pinnedIds={[]}
onTogglePin={vi.fn()}
labels={LABELS}
value={null}
onChange={vi.fn()}
/>
)
openPopover()
fireEvent.click(screen.getByRole('button', { name: 'Cherry' }))
expect(screen.queryByRole('option', { name: /Alpha/ })).toBeInTheDocument()
expect(screen.queryByRole('option', { name: /Gamma/ })).toBeInTheDocument()
expect(screen.queryByRole('option', { name: /Beta/ })).not.toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'DEV' }))
expect(screen.queryByRole('option', { name: /Beta/ })).toBeInTheDocument()
expect(screen.queryByRole('option', { name: /Alpha/ })).not.toBeInTheDocument()
expect(screen.queryByRole('option', { name: /Gamma/ })).not.toBeInTheDocument()
})
})
})

View File

@@ -295,8 +295,9 @@ vi.mock('react-i18next', async (importOriginal) => {
'library.config.basic.model_not_found': 'Model {{id}} is unavailable.',
'library.config.basic.precise': 'Precise',
'library.config.basic.stream_output': 'Stream output',
'library.config.basic.tags': 'Tags',
'library.config.basic.tag_empty': 'No tags',
'library.config.basic.tag_placeholder': 'Select tags',
'library.config.basic.tag_placeholder': 'Select tag',
'library.config.basic.tag_search': 'Search tags',
'library.config.basic.mcp_mode': 'MCP Mode',
'library.config.basic.temperature': 'Temperature',
@@ -558,13 +559,10 @@ async function expectVariablesHelpOnHover() {
await waitFor(() => expect(screen.getAllByText('{{date}}').length).toBeGreaterThan(0))
}
function openTagCombobox() {
const removeTagButton = screen.getByRole('button', { name: 'Remove work' })
const combobox = removeTagButton.closest('[role="combobox"]')
if (!combobox) {
throw new Error('Tag combobox trigger not found')
}
fireEvent.click(combobox)
function openTagSelect() {
const select = screen.getByRole('combobox', { name: 'Tags' })
fireEvent.pointerDown(select)
fireEvent.click(select)
}
describe('edit dialogs', () => {
@@ -616,57 +614,45 @@ describe('edit dialogs', () => {
})
it('submits assistant tag changes through ensureTags', async () => {
ensureTagsMock.mockResolvedValueOnce([
{ id: 'tag-work', name: 'work', color: '#8b5cf6' },
{ id: 'tag-personal', name: 'personal', color: '#10b981' }
])
ensureTagsMock.mockResolvedValueOnce([{ id: 'tag-personal', name: 'personal', color: '#10b981' }])
render(<AssistantEditDialog open resource={ASSISTANT} onOpenChange={vi.fn()} onSaved={vi.fn()} />)
openTagCombobox()
fireEvent.click(await screen.findByText('personal'))
openTagSelect()
fireEvent.click(await screen.findByRole('option', { name: 'personal' }))
fireEvent.click(screen.getByRole('button', { name: 'Save' }))
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['work', 'personal']))
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['personal']))
expect(updateAssistantMock).toHaveBeenCalledWith({
body: expect.objectContaining({
tagIds: ['tag-work', 'tag-personal']
tagIds: ['tag-personal']
})
})
})
it('creates and binds a new tag typed in assistant editing', async () => {
ensureTagsMock.mockResolvedValueOnce([
{ id: 'tag-work', name: 'work', color: '#8b5cf6' },
{ id: 'tag-new', name: 'new-tag', color: '#10b981' }
])
it('clears the assistant tag from the single-select tag field', async () => {
ensureTagsMock.mockResolvedValueOnce([])
render(<AssistantEditDialog open resource={ASSISTANT} onOpenChange={vi.fn()} onSaved={vi.fn()} />)
openTagCombobox()
fireEvent.change(screen.getByPlaceholderText('Search tags'), { target: { value: 'new-tag' } })
fireEvent.click(await screen.findByText('new-tag'))
const clearButton = screen.getByRole('button', { name: 'Tags Clear' })
expect(clearButton).toHaveClass('focus-visible:pointer-events-auto', 'focus-visible:opacity-100')
fireEvent.click(clearButton)
fireEvent.click(screen.getByRole('button', { name: 'Save' }))
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['work', 'new-tag']))
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith([]))
expect(updateAssistantMock).toHaveBeenCalledWith({
body: expect.objectContaining({
tagIds: ['tag-work', 'tag-new']
tagIds: []
})
})
})
it('does not expose typed tag names beyond the server-side max length', async () => {
it('limits assistant tag editing to existing tags', async () => {
render(<AssistantEditDialog open resource={ASSISTANT} onOpenChange={vi.fn()} onSaved={vi.fn()} />)
// Open the combobox popover (CommandInput only mounts once the trigger is active)
openTagCombobox()
const atLimit = 'y'.repeat(64) // TagNameSchema.max(64)
const tooLong = 'x'.repeat(65) // one over the limit — server would reject
const searchInput = screen.getByPlaceholderText('Search tags')
fireEvent.change(searchInput, { target: { value: tooLong } })
expect(screen.queryByText(tooLong)).not.toBeInTheDocument()
fireEvent.change(searchInput, { target: { value: atLimit } })
expect(await screen.findByText(atLimit)).toBeInTheDocument()
openTagSelect()
expect(screen.queryByPlaceholderText('Search tags')).not.toBeInTheDocument()
expect(screen.queryByRole('option', { name: 'No tag' })).not.toBeInTheDocument()
expect(screen.queryByText('new-tag')).not.toBeInTheDocument()
})
it('submits agent instructions and model changes as a PATCH', async () => {

View File

@@ -1,53 +1,80 @@
import { Combobox, type ComboboxOption } from '@cherrystudio/ui'
import { TagNameSchema } from '@shared/data/types/tag'
import { Button, Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@cherrystudio/ui'
import { cn } from '@renderer/utils/style'
import { X } from 'lucide-react'
import type { FC } from 'react'
import { useMemo, useState } from 'react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
interface Props {
value: string[]
onChange: (tags: string[]) => void
value: string | null
onChange: (tag: string | null) => void
allTagNames: string[]
disabled?: boolean
portalContainer?: HTMLElement | null
}
const TAG_SELECT_VALUE_PREFIX = 'tag:'
function encodeTagSelectValue(name: string) {
return `${TAG_SELECT_VALUE_PREFIX}${name}`
}
function decodeTagSelectValue(value: string) {
if (!value.startsWith(TAG_SELECT_VALUE_PREFIX)) return null
return value.slice(TAG_SELECT_VALUE_PREFIX.length)
}
export const TagSelector: FC<Props> = ({ value, onChange, allTagNames, disabled, portalContainer }) => {
const { t } = useTranslation()
const [search, setSearch] = useState('')
// `value` may contain names not present in `/tags` yet, for example while a
// caller waits for SWR refresh. Keep selected names visible in the options.
const tagOptions = useMemo<ComboboxOption[]>(() => {
const trimmedSearch = search.trim()
const names = new Set([...allTagNames, ...value])
// Mirror the server-side TagNameSchema (z.string().trim().min(1).max(64))
// so a user cannot select a name the create endpoint would reject.
if (trimmedSearch && TagNameSchema.safeParse(trimmedSearch).success) {
names.add(trimmedSearch)
}
// `value` may be a name not present in `/tags` yet, for example while a
// caller waits for SWR refresh. Keep the selected name visible in the options.
const tagNames = useMemo(() => {
const names = new Set(allTagNames)
if (value) names.add(value)
const sortedNames = Array.from(names)
sortedNames.sort((a, b) => a.localeCompare(b, 'zh'))
return sortedNames.map((name) => ({
value: name,
label: name
}))
}, [allTagNames, search, value])
return sortedNames
}, [allTagNames, value])
return (
<Combobox
multiple
size="sm"
disabled={disabled}
options={tagOptions}
value={value}
onChange={(v) => onChange(Array.isArray(v) ? v : v ? [v] : [])}
onSearch={setSearch}
placeholder={t('library.config.basic.tag_placeholder')}
searchPlaceholder={t('library.config.basic.tag_search')}
emptyText={t('library.config.basic.tag_empty')}
portalContainer={portalContainer ?? undefined}
/>
<div className="group/tag-select relative flex w-full min-w-0 items-center">
<Select
disabled={disabled}
value={value ? encodeTagSelectValue(value) : ''}
onValueChange={(selectedValue) => onChange(decodeTagSelectValue(selectedValue))}>
<SelectTrigger
size="sm"
className={cn(
'w-full',
value &&
'[&_svg]:transition-opacity group-focus-within/tag-select:[&_svg]:opacity-0 group-hover/tag-select:[&_svg]:opacity-0'
)}
aria-label={t('library.config.basic.tags')}>
<SelectValue placeholder={t('library.config.basic.tag_placeholder')} />
</SelectTrigger>
<SelectContent portalContainer={portalContainer ?? undefined}>
{tagNames.map((name) => (
<SelectItem key={name} value={encodeTagSelectValue(name)}>
{name}
</SelectItem>
))}
</SelectContent>
</Select>
{value && !disabled ? (
<Button
type="button"
variant="ghost"
aria-label={`${t('library.config.basic.tags')} ${t('common.clear')}`}
onClick={(event) => {
event.stopPropagation()
onChange(null)
}}
className="-translate-y-1/2 pointer-events-none absolute top-1/2 right-2.5 flex size-5 min-h-0 shrink-0 items-center justify-center rounded-full bg-transparent p-0 text-muted-foreground/70 opacity-0 shadow-none transition-[background-color,color,opacity] hover:bg-muted hover:text-foreground focus-visible:pointer-events-auto focus-visible:opacity-100 active:bg-muted group-focus-within/tag-select:pointer-events-auto group-focus-within/tag-select:opacity-100 group-hover/tag-select:pointer-events-auto group-hover/tag-select:opacity-100">
<X size={12} />
</Button>
) : null}
</div>
)
}

View File

@@ -63,7 +63,7 @@ type AssistantEditFormValues = {
name: string
description: string
modelId: UniqueModelId | null
tags: string[]
tagName: string | null
prompt: string
temperature: number
enableTemperature: boolean
@@ -99,7 +99,7 @@ function defaultValuesForAssistant(resource: AssistantEditDialogResource): Assis
name: form.name,
description: form.description,
modelId: form.modelId ?? null,
tags: form.tags,
tagName: form.tagName,
prompt: form.prompt,
temperature: form.temperature,
enableTemperature: form.enableTemperature,
@@ -132,7 +132,7 @@ function buildAssistantFormState(baseline: AssistantFormState, values: Assistant
name: values.name,
description: values.description,
modelId: values.modelId,
tags: values.tags,
tagName: values.tagName,
prompt: values.prompt,
temperature: values.temperature,
enableTemperature: values.enableTemperature,
@@ -361,7 +361,7 @@ function AssistantBasicFields({
</div>
<FormField
control={form.control}
name="tags"
name="tagName"
render={({ field }) => (
<FormItem className="min-w-0">
<FormLabel>{t('library.config.basic.tags')}</FormLabel>

View File

@@ -69,9 +69,9 @@ describe('initialAssistantFormState', () => {
})
})
it('extracts tag names from embedded tag rows', () => {
it('extracts a single tag name from embedded tag rows', () => {
const assistant = createAssistant({ tags: [tag('t1', 'alpha', '#f00'), tag('t2', 'beta', '#0f0')] })
expect(initialAssistantFormState(assistant).tags).toEqual(['alpha', 'beta'])
expect(initialAssistantFormState(assistant).tagName).toBe('alpha')
})
})
@@ -128,22 +128,24 @@ describe('diffAssistantUpdate', () => {
expect(result?.dto.settings).toMatchObject({ reasoning_effort: 'high' })
})
it('flags tag changes and passes form.tags through as tagNames', () => {
it('flags tag changes and passes one form tag through as tagNames', () => {
const assistant = createAssistant({ tags: [tag('t1', 'alpha', '#f00')] })
const baseline = initialAssistantFormState(assistant)
const form = { ...baseline, tags: ['alpha', 'new'] }
const form = { ...baseline, tagName: 'new' }
const result = diffAssistantUpdate(form, baseline, assistant)
expect(result?.tagsChanged).toBe(true)
expect(result?.tagNames).toEqual(['alpha', 'new'])
expect(result?.tagNames).toEqual(['new'])
})
it('treats tag reorder (same set) as unchanged', () => {
it('flags clearing the assistant tag', () => {
const assistant = createAssistant({ tags: [tag('t1', 'alpha', '#f00'), tag('t2', 'beta', '#0f0')] })
const baseline = initialAssistantFormState(assistant)
const form = { ...baseline, tags: ['beta', 'alpha'] }
const form = { ...baseline, tagName: null }
expect(diffAssistantUpdate(form, baseline, assistant)).toBeNull()
const result = diffAssistantUpdate(form, baseline, assistant)
expect(result?.tagsChanged).toBe(true)
expect(result?.tagNames).toEqual([])
})
it('emits knowledgeBaseIds only when the set changes, ignoring order', () => {
@@ -187,12 +189,12 @@ describe('diffAssistantSaveIntent', () => {
it('wraps update diffs for the edit dialog save handler', () => {
const assistant = createAssistant({ tags: [tag('t1', 'alpha')] })
const baseline = initialAssistantFormState(assistant)
const form = { ...baseline, tags: ['alpha', 'beta'] }
const form = { ...baseline, tagName: 'beta' }
expect(diffAssistantSaveIntent(form, baseline, assistant)).toEqual({
kind: 'update',
payload: {},
tagNames: ['alpha', 'beta'],
tagNames: ['beta'],
tagsChanged: true
})
})

View File

@@ -21,9 +21,9 @@ const UI_DEFAULT_MAX_TOOL_CALLS = 20
* Flat form state for the Assistant edit dialog. Every editable field lives
* here so the dialog commits in a single PATCH.
*
* `tags` stores user-facing names, not ids — tag-id resolution happens
* at save time via `ensureTags` so the user can freely type new tags without
* paying a network round-trip per keystroke.
* `tagName` stores one user-facing name, not an id — tag-id resolution happens
* at save time via `ensureTags`, keeping the form state independent from
* backend tag ids.
*/
export interface AssistantFormState {
// columns
@@ -46,11 +46,15 @@ export interface AssistantFormState {
customParameters: CustomParameter[]
mcpMode: AssistantSettings['mcpMode']
// relations
tags: string[]
tagName: string | null
knowledgeBaseIds: string[]
mcpServerIds: string[]
}
function normalizeAssistantTagName(tags: readonly string[]): string | null {
return tags[0] ?? null
}
function buildAssistantSettingsFromForm(
form: AssistantFormState,
baseSettings: AssistantSettings = DEFAULT_ASSISTANT_SETTINGS
@@ -90,7 +94,7 @@ export function initialAssistantFormState(assistant: Assistant): AssistantFormSt
enableMaxToolCalls: settings.enableMaxToolCalls ?? true,
customParameters: settings.customParameters ?? [],
mcpMode: settings.mcpMode ?? 'auto',
tags: (assistant.tags ?? []).map((t) => t.name),
tagName: normalizeAssistantTagName((assistant.tags ?? []).map((t) => t.name)),
knowledgeBaseIds: assistant.knowledgeBaseIds ?? [],
mcpServerIds: assistant.mcpServerIds ?? []
}
@@ -160,7 +164,7 @@ export function diffAssistantUpdate(
baseline.mcpMode !== form.mcpMode ||
customParametersChanged
const tagsChanged = !sameStringSet(baseline.tags, form.tags)
const tagsChanged = baseline.tagName !== form.tagName
const knowledgeBaseIdsChanged = !sameIdSet(baseline.knowledgeBaseIds, form.knowledgeBaseIds)
const mcpServerIdsChanged = !sameIdSet(baseline.mcpServerIds, form.mcpServerIds)
@@ -183,7 +187,7 @@ export function diffAssistantUpdate(
...(mcpServerIdsChanged ? { mcpServerIds: form.mcpServerIds } : {})
}
return { dto, tagsChanged, tagNames: form.tags }
return { dto, tagsChanged, tagNames: form.tagName ? [form.tagName] : [] }
}
export function diffAssistantSaveIntent(
@@ -208,9 +212,3 @@ function sameIdSet(a: readonly string[], b: readonly string[]): boolean {
const set = new Set(a)
return b.every((id) => set.has(id))
}
function sameStringSet(a: readonly string[], b: readonly string[]): boolean {
if (a.length !== b.length) return false
const set = new Set(a)
return b.every((v) => set.has(v))
}

View File

@@ -1333,12 +1333,14 @@
"add": "Add Tag",
"delete": "Delete Tag",
"deleteConfirm": "Are you sure to delete this tag?",
"group_by": "Group by tag",
"manage": "Tag Management",
"modify": "Modify Tag",
"none": "No tags",
"settings": {
"title": "Tag Settings"
},
"ungroup": "Turn off tag grouping",
"untagged": "Untagged"
},
"title": "Assistants",
@@ -3154,9 +3156,9 @@
"stream_output": "Stream output",
"tag_empty": "No tags available",
"tag_hint": "To add a new tag, use the \"+ Tag\" entry in the library top bar",
"tag_placeholder": "Select tags",
"tag_placeholder": "Select tag",
"tag_search": "Search tags",
"tags": "Tags",
"tags": "Tag",
"temperature": "Temperature",
"title": "Basic settings",
"top_p": "Top-P",

View File

@@ -1333,12 +1333,14 @@
"add": "添加标签",
"delete": "删除标签",
"deleteConfirm": "确定要删除这个标签吗?",
"group_by": "按照标签分组",
"manage": "标签管理",
"modify": "修改标签",
"none": "暂无标签",
"settings": {
"title": "标签设置"
},
"ungroup": "关闭标签分组",
"untagged": "未分组"
},
"title": "助手",

View File

@@ -1333,12 +1333,14 @@
"add": "新增標籤",
"delete": "刪除標籤",
"deleteConfirm": "確定要刪除這個標籤嗎?",
"group_by": "依標籤分組",
"manage": "標籤管理",
"modify": "修改標籤",
"none": "暫無標籤",
"settings": {
"title": "標籤設定"
},
"ungroup": "關閉標籤分組",
"untagged": "未分組"
},
"title": "助手",

View File

@@ -65,8 +65,10 @@ function buildTags(resources: ResourceItem[], backendTags: Tag[], filterType?: R
for (const tag of r.raw.tags ?? []) {
if (!backendTagByName.has(tag.name)) backendTagByName.set(tag.name, tag)
}
if (r.tag) {
tagMap.set(r.tag, (tagMap.get(r.tag) || 0) + 1)
}
}
r.tags.forEach((t) => tagMap.set(t, (tagMap.get(t) || 0) + 1))
})
return Array.from(tagMap.entries())
.sort((a, b) => b[1] - a[1])
@@ -155,7 +157,6 @@ export default function LibraryPage() {
[tagList.tags]
)
const noop = useCallback(() => {}, [])
const handleClosePromptDialog = useCallback(() => {
setPromptDialog(null)
}, [])
@@ -474,7 +475,6 @@ export default function LibraryPage() {
// rows; binding stays inside card/dialog tag hooks.
await ensureTags([tagName])
}}
onUpdateResourceTags={noop /* binding is executed inside FixedCardMenu via the tag hooks */}
allTagNames={allTagNames}
allTags={tagList.tags}
assistantCatalog={assistantCatalogProp}

View File

@@ -516,7 +516,6 @@ describe('LibraryPage create flow', () => {
name: 'Assistant to duplicate',
description: '',
avatar: '💬',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'assistant-to-duplicate', name: 'Assistant to duplicate', tags: [] }
@@ -630,7 +629,6 @@ describe('LibraryPage create flow', () => {
name: 'Selector Agent',
description: '',
avatar: '',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'agent-from-selector' }
@@ -652,7 +650,6 @@ describe('LibraryPage create flow', () => {
name: 'Selector Agent',
description: '',
avatar: '',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'agent-from-selector' }
@@ -675,7 +672,6 @@ describe('LibraryPage create flow', () => {
name: 'Selector Assistant',
description: '',
avatar: '💬',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'assistant-from-selector' }
@@ -697,7 +693,6 @@ describe('LibraryPage create flow', () => {
name: 'Stale Assistant',
description: '',
avatar: '💬',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'assistant-stale-tags', name: 'Stale Assistant' }
@@ -717,7 +712,6 @@ describe('LibraryPage create flow', () => {
name: 'Grid Prompt',
description: '',
avatar: 'Aa',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'prompt-from-grid' }
@@ -746,7 +740,6 @@ describe('LibraryPage create flow', () => {
name: 'Grid Skill',
description: '',
avatar: 'S',
tags: [],
createdAt: '2024-01-01T00:00:00.000Z',
updatedAt: '2024-01-01T00:00:00.000Z',
raw: { id: 'skill-from-grid', name: 'Grid Skill' }

View File

@@ -78,7 +78,7 @@ describe('useAssistantMutations', () => {
})
})
it('forwards tag ids to the create endpoint when duplicating an assistant', async () => {
it('forwards only one tag id to the create endpoint when duplicating an assistant', async () => {
const created = createAssistant({ id: 'ast-copy', tags: [] })
createTriggerMock.mockResolvedValue(created)
@@ -106,7 +106,7 @@ describe('useAssistantMutations', () => {
settings: source.settings,
mcpServerIds: ['mcp-1'],
knowledgeBaseIds: ['kb-1'],
tagIds: ['tag-1', 'tag-2']
tagIds: ['tag-1']
}
})
})

View File

@@ -66,6 +66,7 @@ export function useAssistantMutations() {
const duplicateAssistant = useCallback(
async (source: Assistant): Promise<Assistant> => {
const duplicateName = t('library.duplicate_name', { name: source.name })
const tagId = source.tags[0]?.id
return createTrigger({
body: {
@@ -77,7 +78,7 @@ export function useAssistantMutations() {
settings: source.settings,
mcpServerIds: source.mcpServerIds,
knowledgeBaseIds: source.knowledgeBaseIds,
tagIds: source.tags.map((tag) => tag.id)
tagIds: tagId ? [tagId] : []
}
})
},

View File

@@ -3,7 +3,7 @@ import type { ResourceType } from '../types'
export interface ResourceListQuery {
/** Free-text match against name OR description (passed through to the API). */
search?: string
/** Union (OR) tag filter — kept if the resource is bound to ANY of these tag ids. */
/** Backend tag-id filter transport shape; current assistant UI passes at most one id. */
tagIds?: string[]
limit?: number
offset?: number

View File

@@ -1,6 +1,5 @@
import {
Button,
Checkbox,
Input,
MenuDivider,
MenuItem,
@@ -12,8 +11,8 @@ import {
} from '@cherrystudio/ui'
import { loggerService } from '@logger'
import { useEnsureTags, useTagList } from '@renderer/hooks/useTags'
import { ChevronDown, Copy, Download, Plus, Tag, Trash2 } from 'lucide-react'
import { useCallback, useRef, useState } from 'react'
import { Check, ChevronDown, Copy, Download, Plus, Tag, Trash2 } from 'lucide-react'
import { type KeyboardEvent, useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useAssistantMutationsById } from '../adapters/assistantAdapter'
@@ -32,7 +31,6 @@ interface ResourceCardMenuProps {
onDuplicate: (r: ResourceItem) => void
onDelete: (r: ResourceItem) => void
onExport: (r: ResourceItem) => void
onUpdateResourceTags: (resourceId: string, tags: string[]) => void
allTagNames: string[]
}
@@ -42,16 +40,19 @@ export function ResourceCardMenu({
onDuplicate,
onDelete,
onExport,
onUpdateResourceTags,
allTagNames
}: ResourceCardMenuProps) {
const { t } = useTranslation()
const [showTagPicker, setShowTagPicker] = useState(false)
const [localTags, setLocalTags] = useState<string[]>(resource.tags)
const [localTag, setLocalTag] = useState<string | null>(() =>
resource.type === 'assistant' ? (resource.tag ?? null) : null
)
const [tagInput, setTagInput] = useState('')
const [bindingError, setBindingError] = useState<string | null>(null)
const [bindingPending, setBindingPending] = useState(false)
const bindingPendingRef = useRef(false)
const tagOptionRefs = useRef<Array<HTMLDivElement | null>>([])
const [activeTagIndex, setActiveTagIndex] = useState(0)
const { ensureTags } = useEnsureTags({ getDefaultColor: getRandomTagColor })
const { updateAssistant } = useAssistantMutationsById(resource.id)
@@ -65,22 +66,28 @@ export function ResourceCardMenu({
const tagList = useTagList()
const colorFor = (name: string): string => tagList.tags.find((tag) => tag.name === name)?.color ?? DEFAULT_TAG_COLOR
const persistTags = useCallback(
async (nextNames: string[], previousNames: string[]) => {
useEffect(() => {
if (!showTagPicker) return
const selectedIndex = localTag ? allTagNames.indexOf(localTag) : -1
setActiveTagIndex(selectedIndex >= 0 ? selectedIndex : 0)
}, [allTagNames, localTag, showTagPicker])
const persistTag = useCallback(
async (nextName: string | null, previousName: string | null) => {
if (!canBindTags) return
if (bindingPendingRef.current) return
bindingPendingRef.current = true
setBindingPending(true)
try {
const nextNames = nextName ? [nextName] : []
const tags = await ensureTags(nextNames)
const tagIds = tags.map((tag) => tag.id)
if (resource.type === 'assistant') {
await updateAssistant({ tagIds })
}
onUpdateResourceTags(resource.id, nextNames)
} catch (e) {
// Roll back optimistic state on failure.
setLocalTags(previousNames)
setLocalTag(previousName)
const message = e instanceof Error ? e.message : t('library.tag_sync_failed')
setBindingError(message)
// The inline error text only renders while the popup is open. Toast +
@@ -96,31 +103,60 @@ export function ResourceCardMenu({
setBindingPending(false)
}
},
[canBindTags, ensureTags, updateAssistant, onUpdateResourceTags, resource.id, resource.type, t]
[canBindTags, ensureTags, updateAssistant, resource.id, resource.type, t]
)
const toggleTag = (tag: string) => {
if (bindingPendingRef.current) return
const prev = localTags
const next = prev.includes(tag) ? prev.filter((item) => item !== tag) : [...prev, tag]
setLocalTags(next)
const prev = localTag
const next = prev === tag ? null : tag
setLocalTag(next)
setBindingError(null)
void persistTags(next, prev)
void persistTag(next, prev)
}
const focusTagOption = (index: number) => {
setActiveTagIndex(index)
tagOptionRefs.current[index]?.focus()
}
const handleTagOptionKeyDown = (e: KeyboardEvent<HTMLDivElement>, index: number, tag: string) => {
if (e.key === 'Enter' || e.key === ' ') {
e.preventDefault()
if (!bindingPending) toggleTag(tag)
return
}
if (allTagNames.length === 0) return
if (e.key === 'ArrowDown') {
e.preventDefault()
focusTagOption((index + 1) % allTagNames.length)
} else if (e.key === 'ArrowUp') {
e.preventDefault()
focusTagOption((index - 1 + allTagNames.length) % allTagNames.length)
} else if (e.key === 'Home') {
e.preventDefault()
focusTagOption(0)
} else if (e.key === 'End') {
e.preventDefault()
focusTagOption(allTagNames.length - 1)
}
}
const addNewTag = () => {
if (bindingPendingRef.current) return
const tag = tagInput.trim()
if (!tag || localTags.includes(tag)) {
if (!tag || localTag === tag) {
setTagInput('')
return
}
const prev = localTags
const next = [...prev, tag]
setLocalTags(next)
const prev = localTag
const next = tag
setLocalTag(next)
setTagInput('')
setBindingError(null)
void persistTags(next, prev)
void persistTag(next, prev)
}
return (
@@ -137,9 +173,7 @@ export function ResourceCardMenu({
label={t('library.action.manage_tags')}
suffix={
<>
{localTags.length > 0 && (
<span className="text-foreground-muted text-xs tabular-nums">{localTags.length}</span>
)}
{localTag && <span className="text-foreground-muted text-xs tabular-nums">1</span>}
<ChevronDown size={8} className={`transition-transform ${showTagPicker ? 'rotate-180' : ''}`} />
</>
}
@@ -174,43 +208,38 @@ export function ResourceCardMenu({
)}
</div>
<Separator className="mx-1 mb-0.5 bg-border-subtle" />
<div className="flex-1 overflow-y-auto [&::-webkit-scrollbar-thumb]:bg-border-muted [&::-webkit-scrollbar]:w-0.5">
<div
role="menu"
aria-label={t('library.config.basic.tags')}
className="flex-1 overflow-y-auto [&::-webkit-scrollbar-thumb]:bg-border-muted [&::-webkit-scrollbar]:w-0.5">
{allTagNames.length === 0 && !tagInput.trim() && (
<p className="px-2.5 py-2 text-center text-foreground-muted text-xs">
{t('library.tag_picker.no_tags')}
</p>
)}
{allTagNames.map((tag) => {
const checked = localTags.includes(tag)
{allTagNames.map((tag, index) => {
const checked = localTag === tag
return (
<div
key={tag}
role="button"
tabIndex={bindingPending ? -1 : 0}
ref={(node) => {
tagOptionRefs.current[index] = node
}}
role="menuitemradio"
aria-checked={checked}
tabIndex={!bindingPending && index === activeTagIndex ? 0 : -1}
aria-disabled={bindingPending || undefined}
onClick={() => toggleTag(tag)}
onKeyDown={(e) => {
if (bindingPending) return
if (e.key === 'Enter' || e.key === ' ') {
e.preventDefault()
toggleTag(tag)
}
}}
onFocus={() => setActiveTagIndex(index)}
onKeyDown={(e) => handleTagOptionKeyDown(e, index, tag)}
className={`flex w-full items-center gap-2 rounded-md px-2.5 py-1 text-foreground-secondary text-xs transition-colors ${
bindingPending
? 'cursor-not-allowed opacity-60'
: 'cursor-pointer hover:bg-accent hover:text-foreground'
}`}>
<span onClick={(e) => e.stopPropagation()}>
<Checkbox
size="sm"
checked={checked}
disabled={bindingPending}
onCheckedChange={() => toggleTag(tag)}
/>
</span>
<span className="h-1.5 w-1.5 shrink-0 rounded-full" style={{ backgroundColor: colorFor(tag) }} />
<span className="flex-1 truncate text-left">{tag}</span>
{checked && <Check size={12} className="shrink-0 text-success" />}
</div>
)
})}

View File

@@ -214,22 +214,13 @@ interface ResourceCardProps {
onDuplicate: (resource: ResourceItem) => void
onEdit: (resource: ResourceItem) => void
onExport: (resource: ResourceItem) => void
onUpdateResourceTags: (resourceId: string, tags: string[]) => void
}
function hasOverflowActions(resource: ResourceItem) {
return resource.type === 'assistant'
}
export function ResourceCard({
resource: r,
allTagNames,
onDelete,
onDuplicate,
onEdit,
onExport,
onUpdateResourceTags
}: ResourceCardProps) {
export function ResourceCard({ resource: r, allTagNames, onDelete, onDuplicate, onEdit, onExport }: ResourceCardProps) {
const { t } = useTranslation()
const [menuOpen, setMenuOpen] = useState(false)
const cfg = RESOURCE_TYPE_META[r.type]
@@ -237,8 +228,7 @@ export function ResourceCard({
// other resources keep their own avatar on the neutral accent block.
const useTypedAvatarBg = r.type === 'skill'
const showOverflowMenu = hasOverflowActions(r)
const visibleTags = r.type === 'assistant' ? r.tags.slice(0, 2) : []
const extraTagCount = r.type === 'assistant' ? r.tags.length - visibleTags.length : 0
const visibleTag = r.type === 'assistant' ? r.tag : undefined
return (
<div
@@ -259,17 +249,13 @@ export function ResourceCard({
<div className="min-w-0 flex-1">
<h4 className="truncate font-medium text-foreground text-sm leading-5">{r.name}</h4>
<p className="mt-0.5 truncate text-foreground-secondary text-xs leading-4">{r.description}</p>
{visibleTags.length > 0 && (
{visibleTag && (
<div className="mt-1.5 flex min-w-0 items-center gap-1">
{visibleTags.map((tag, i) => (
<Badge
key={`${tag}-${i}`}
variant="secondary"
className="max-w-24 truncate border-0 bg-secondary px-1.5 py-px text-foreground-secondary text-xs">
{tag}
</Badge>
))}
{extraTagCount > 0 && <span className="shrink-0 text-foreground-muted text-xs">+{extraTagCount}</span>}
<Badge
variant="secondary"
className="max-w-24 truncate border-0 bg-secondary px-1.5 py-px text-foreground-secondary text-xs">
{visibleTag}
</Badge>
</div>
)}
</div>
@@ -299,7 +285,6 @@ export function ResourceCard({
onDuplicate={onDuplicate}
onDelete={onDelete}
onExport={onExport}
onUpdateResourceTags={onUpdateResourceTags}
allTagNames={allTagNames}
/>
</PopoverContent>

View File

@@ -69,8 +69,6 @@ interface Props {
onTagFilter: (tagName: string | null) => void
/** Create a new tag (POST /tags). Does not bind the tag to any resource. */
onAddTag: (tagName: string) => Promise<void> | void
/** Replace the tag-name set for a single resource. Caller handles ensure-tag + bind. */
onUpdateResourceTags: (resourceId: string, tags: string[]) => Promise<void> | void
allTagNames: string[]
/** Full backend tag records (id + name + color). Distinct from `allTagNames` (names only). */
allTags: BackendTag[]
@@ -128,7 +126,6 @@ export const ResourceGrid: FC<Props> = ({
activeTag,
onTagFilter,
onAddTag,
onUpdateResourceTags,
allTagNames,
allTags,
assistantCatalog
@@ -494,7 +491,6 @@ export const ResourceGrid: FC<Props> = ({
onDuplicate={onDuplicate}
onEdit={onEdit}
onExport={onExport}
onUpdateResourceTags={onUpdateResourceTags}
/>
)}
</div>
@@ -534,7 +530,6 @@ interface VirtualizedResourceGridProps {
onDuplicate: (r: ResourceItem) => void
onEdit: (r: ResourceItem) => void
onExport: (r: ResourceItem) => void
onUpdateResourceTags: (resourceId: string, tags: string[]) => void
}
function VirtualizedResourceGrid({
@@ -545,8 +540,7 @@ function VirtualizedResourceGrid({
onDelete,
onDuplicate,
onEdit,
onExport,
onUpdateResourceTags
onExport
}: VirtualizedResourceGridProps) {
const rows = useMemo(() => {
const nextRows: ResourceItem[][] = []
@@ -590,7 +584,6 @@ function VirtualizedResourceGrid({
onDuplicate={onDuplicate}
onEdit={onEdit}
onExport={onExport}
onUpdateResourceTags={onUpdateResourceTags}
/>
))}
</div>

View File

@@ -75,27 +75,6 @@ vi.mock('@cherrystudio/ui', async () => {
</button>
)
},
Checkbox: ({
checked = false,
onCheckedChange,
size,
...props
}: Omit<ComponentProps<'button'>, 'onChange'> & {
checked?: boolean
onCheckedChange?: (checked: boolean) => void
size?: string
}) => {
void size
return (
<button
type="button"
role="checkbox"
aria-checked={checked}
onClick={() => onCheckedChange?.(!checked)}
{...props}
/>
)
},
ConfirmDialog: ({
cancelText,
confirmText,
@@ -276,7 +255,6 @@ function createAssistantResource(overrides: Partial<Extract<ResourceItem, { type
name: 'Assistant',
description: '',
avatar: 'A',
tags: [],
createdAt: '2026-05-06T00:00:00.000Z',
updatedAt: '2026-05-06T00:00:00.000Z',
raw: {} as Extract<ResourceItem, { type: 'assistant' }>['raw'],
@@ -291,7 +269,6 @@ function createAgentResource(): ResourceItem {
name: 'Agent',
description: '',
avatar: 'A',
tags: [],
createdAt: '2026-05-06T00:00:00.000Z',
updatedAt: '2026-05-06T00:00:00.000Z',
raw: {} as Extract<ResourceItem, { type: 'agent' }>['raw']
@@ -305,7 +282,6 @@ function createSkillResource(): ResourceItem {
name: 'Skill',
description: '',
avatar: 'S',
tags: [],
createdAt: '2026-05-06T00:00:00.000Z',
updatedAt: '2026-05-06T00:00:00.000Z',
raw: {} as Extract<ResourceItem, { type: 'skill' }>['raw']
@@ -319,7 +295,6 @@ function createPromptResource(): ResourceItem {
name: 'Prompt',
description: '',
avatar: 'Aa',
tags: [],
createdAt: '2026-05-06T00:00:00.000Z',
updatedAt: '2026-05-06T00:00:00.000Z',
raw: {} as Extract<ResourceItem, { type: 'prompt' }>['raw']
@@ -344,7 +319,6 @@ function renderResourceGrid(props: Partial<ComponentProps<typeof ResourceGrid>>
activeTag={null}
onTagFilter={vi.fn()}
onAddTag={vi.fn()}
onUpdateResourceTags={vi.fn()}
allTagNames={[]}
allTags={[]}
{...props}
@@ -359,7 +333,6 @@ function getResourceCardProps(overrides: Partial<ComponentProps<typeof ResourceC
onDuplicate: vi.fn(),
onEdit: vi.fn(),
onExport: vi.fn(),
onUpdateResourceTags: vi.fn(),
...overrides
}
}
@@ -505,17 +478,12 @@ describe('ResourceGrid card actions', () => {
expect(onDelete).toHaveBeenCalledWith(resource)
})
it('keeps assistant tags visible in the compact card layout', () => {
render(
<ResourceCard
resource={createAssistantResource({ tags: ['alpha', 'beta', 'gamma'] })}
{...getResourceCardProps()}
/>
)
it('shows only one assistant tag in the compact card layout', () => {
render(<ResourceCard resource={createAssistantResource({ tag: 'alpha' })} {...getResourceCardProps()} />)
expect(screen.getByText('alpha')).toBeInTheDocument()
expect(screen.getByText('beta')).toBeInTheDocument()
expect(screen.getByText('+1')).toBeInTheDocument()
expect(screen.queryByText('beta')).not.toBeInTheDocument()
expect(screen.queryByText('+2')).not.toBeInTheDocument()
})
})
@@ -624,7 +592,6 @@ describe('ResourceCardMenu tag binding', () => {
const pendingTags = createDeferred<Array<{ id: string; name: string }>>()
ensureTagsMock.mockReturnValueOnce(pendingTags.promise)
updateAssistantMock.mockResolvedValue({})
const onUpdateResourceTags = vi.fn()
render(
<ResourceCardMenu
@@ -633,17 +600,20 @@ describe('ResourceCardMenu tag binding', () => {
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
onUpdateResourceTags={onUpdateResourceTags}
allTagNames={['alpha', 'beta']}
/>
)
await user.click(screen.getByRole('button', { name: /library.action.manage_tags/ }))
const checkboxes = screen.getAllByRole('checkbox')
await user.click(checkboxes[0])
expect(screen.getByRole('menu', { name: 'library.config.basic.tags' })).toContainElement(
screen.getByRole('menuitemradio', { name: 'alpha' })
)
await user.click(screen.getByRole('menuitemradio', { name: 'alpha' }))
await waitFor(() => expect(checkboxes[1]).toBeDisabled())
await user.click(checkboxes[1])
await waitFor(() =>
expect(screen.getByRole('menuitemradio', { name: 'beta' })).toHaveAttribute('aria-disabled', 'true')
)
await user.click(screen.getByRole('menuitemradio', { name: 'beta' }))
expect(ensureTagsMock).toHaveBeenCalledTimes(1)
pendingTags.resolve([{ id: 'tag-alpha', name: 'alpha' }])
@@ -651,10 +621,73 @@ describe('ResourceCardMenu tag binding', () => {
await waitFor(() => {
expect(updateAssistantMock).toHaveBeenCalledWith({ tagIds: ['tag-alpha'] })
})
expect(onUpdateResourceTags).toHaveBeenCalledWith('assistant-1', ['alpha'])
expect(ensureTagsMock).toHaveBeenCalledTimes(1)
})
it('uses roving focus for tag picker radio items', async () => {
const user = userEvent.setup()
render(
<ResourceCardMenu
resource={createAssistantResource()}
onClose={vi.fn()}
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
allTagNames={['alpha', 'beta', 'gamma']}
/>
)
await user.click(screen.getByRole('button', { name: /library.action.manage_tags/ }))
const alpha = screen.getByRole('menuitemradio', { name: 'alpha' })
const beta = screen.getByRole('menuitemradio', { name: 'beta' })
const gamma = screen.getByRole('menuitemradio', { name: 'gamma' })
expect(alpha).toHaveAttribute('tabindex', '0')
expect(beta).toHaveAttribute('tabindex', '-1')
alpha.focus()
expect(alpha).toHaveFocus()
await user.keyboard('{ArrowDown}')
expect(beta).toHaveFocus()
await waitFor(() => {
expect(alpha).toHaveAttribute('tabindex', '-1')
expect(beta).toHaveAttribute('tabindex', '0')
})
await user.keyboard('{End}')
expect(gamma).toHaveFocus()
await user.keyboard('{Home}')
expect(alpha).toHaveFocus()
})
it('replaces the current assistant tag when a different tag is selected', async () => {
const user = userEvent.setup()
ensureTagsMock.mockResolvedValueOnce([{ id: 'tag-beta', name: 'beta' }])
updateAssistantMock.mockResolvedValue({})
render(
<ResourceCardMenu
resource={createAssistantResource({ tag: 'alpha' })}
onClose={vi.fn()}
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
allTagNames={['alpha', 'beta']}
/>
)
await user.click(screen.getByRole('button', { name: /library.action.manage_tags/ }))
await user.click(screen.getByRole('menuitemradio', { name: 'beta' }))
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['beta']))
expect(updateAssistantMock).toHaveBeenCalledWith({ tagIds: ['tag-beta'] })
})
it('does not expose tag management for agent, skill, or prompt resources', () => {
const { rerender } = render(
<ResourceCardMenu
@@ -663,7 +696,6 @@ describe('ResourceCardMenu tag binding', () => {
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
onUpdateResourceTags={vi.fn()}
allTagNames={['alpha', 'beta']}
/>
)
@@ -677,7 +709,6 @@ describe('ResourceCardMenu tag binding', () => {
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
onUpdateResourceTags={vi.fn()}
allTagNames={['alpha', 'beta']}
/>
)
@@ -691,7 +722,6 @@ describe('ResourceCardMenu tag binding', () => {
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
onUpdateResourceTags={vi.fn()}
allTagNames={['alpha', 'beta']}
/>
)
@@ -707,7 +737,6 @@ describe('ResourceCardMenu tag binding', () => {
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
onUpdateResourceTags={vi.fn()}
allTagNames={[]}
/>
)
@@ -724,7 +753,6 @@ describe('ResourceCardMenu tag binding', () => {
onDuplicate={vi.fn()}
onDelete={vi.fn()}
onExport={vi.fn()}
onUpdateResourceTags={vi.fn()}
allTagNames={[]}
/>
)

View File

@@ -180,7 +180,7 @@ describe('useResourceLibrary model display names', () => {
})
const skill = result.current.allResources.find((resource) => resource.type === 'skill')
expect(skill?.tags).toEqual([])
expect(skill?.tag).toBeUndefined()
})
it('passes skill search to the backend and ignores activeTag', () => {
@@ -277,8 +277,7 @@ describe('useResourceLibrary model display names', () => {
type: 'prompt',
name: '日报模板',
description: '今日完成 ${task}',
avatar: 'Aa',
tags: []
avatar: 'Aa'
}
])
})

View File

@@ -84,13 +84,9 @@ export function useResourceLibrary({
// the param when nothing resolves rather than sending a 400.
const tagIds = useMemo(() => {
if (!assistantTagsActive) return undefined
const names = [activeTag].filter((x): x is string => Boolean(x))
if (names.length === 0) return undefined
const ids = names.flatMap((name) => {
const id = tagIdByName.get(name)
return id ? [id] : []
})
return ids.length > 0 ? ids : undefined
if (!activeTag) return undefined
const id = tagIdByName.get(activeTag)
return id ? [id] : undefined
}, [activeTag, assistantTagsActive, tagIdByName])
// Defensive guard for the rare race where the user has a chip selected but
@@ -112,10 +108,10 @@ export function useResourceLibrary({
const filteredPrompts = promptAdapter.useList(promptsVisible ? { search: trimmedSearch } : undefined)
const buildAssistantItem = useCallback((a: Assistant): ResourceItem => {
// Defensive `?? []`: schema declares tags as required, but stale DataApi
// cache or a row from a code path that bypasses the embed helper can
// still hand us undefined here. `.map` would throw.
const tags = a.tags ?? []
// Defensive optional access: schema declares tags as required, but stale DataApi
// cache or a row from a code path that bypasses the embed helper can still hand
// us undefined here.
const tag = a.tags?.[0]
return {
id: a.id,
type: 'assistant',
@@ -125,7 +121,7 @@ export function useResourceLibrary({
// Embedded by AssistantService.list via JOIN on user_model; null when the
// bound model row was removed.
model: a.modelName ?? undefined,
tags: tags.map((t) => t.name),
tag: tag?.name,
createdAt: a.createdAt,
updatedAt: a.updatedAt,
raw: a
@@ -140,7 +136,6 @@ export function useResourceLibrary({
description: a.description ?? '',
avatar: getAgentAvatarFromConfiguration(a.configuration),
model: a.modelName ?? undefined,
tags: [],
createdAt: a.createdAt,
updatedAt: a.updatedAt,
raw: a
@@ -157,7 +152,6 @@ export function useResourceLibrary({
avatar: '⚡',
// Skill metadata tags from SKILL.md live on `sourceTags`; the outer
// resource-library user tag concept is assistant-only.
tags: [],
createdAt: s.createdAt,
updatedAt: s.updatedAt,
raw: s
@@ -171,7 +165,6 @@ export function useResourceLibrary({
name: p.title,
description: p.content.replace(/\s+/g, ' ').trim(),
avatar: 'Aa',
tags: [],
createdAt: p.createdAt,
updatedAt: p.updatedAt,
raw: p

View File

@@ -19,17 +19,16 @@ interface ResourceItemBase<TType extends ResourceType, TRaw> {
description: string
avatar: string
model?: string
tags: string[]
createdAt: string
updatedAt: string
raw: TRaw
}
export type ResourceItem =
| ResourceItemBase<'assistant', Assistant>
| ResourceItemBase<'agent', AgentDetail>
| ResourceItemBase<'skill', InstalledSkill>
| ResourceItemBase<'prompt', Prompt>
| (ResourceItemBase<'assistant', Assistant> & { tag?: string })
| (ResourceItemBase<'agent', AgentDetail> & { tag?: never })
| (ResourceItemBase<'skill', InstalledSkill> & { tag?: never })
| (ResourceItemBase<'prompt', Prompt> & { tag?: never })
export interface TagItem {
id: string

View File

@@ -60,7 +60,7 @@ describe('assistantTransfer', () => {
{
name: '写作助手',
emoji: '✍️',
group: ['写作', '生产力'],
group: ['写作'],
prompt: 'You are helpful',
description: '擅长写作润色',
regularPhrases: [],
@@ -89,10 +89,7 @@ describe('assistantTransfer', () => {
// modelId is intentionally not part of the DTO — the backend fills it from
// the `chat.default_model_id` preference during create.
expect(draft.dto.modelId).toBeUndefined()
expect(draft.tags).toEqual([
{ name: '写作', color: null },
{ name: '生产力', color: null }
])
expect(draft.tags).toEqual([{ name: '写作', color: null }])
})
it('ignores v2-only fields from imported content and still uses legacy defaults', () => {

View File

@@ -55,6 +55,8 @@ function normalizeRecord(record: unknown): ImportedAssistantDraft {
throw new AssistantTransferError('invalid_format')
}
const tagName = readStringArray(record.group)[0]
// `modelId` is intentionally omitted — backend fills it from
// `chat.default_model_id` preference. See AssistantService.resolveCreateModelId.
return {
@@ -65,18 +67,17 @@ function normalizeRecord(record: unknown): ImportedAssistantDraft {
description: readString(record.description),
settings: DEFAULT_ASSISTANT_SETTINGS
},
tags: readStringArray(record.group).map((tagName) => ({
name: tagName,
color: null
}))
tags: tagName ? [{ name: tagName, color: null }] : []
}
}
function buildExportRecord(assistant: Assistant): AssistantExportRecord {
const tagName = assistant.tags[0]?.name
return {
name: assistant.name,
emoji: assistant.emoji,
group: assistant.tags.map((tag) => tag.name),
group: tagName ? [tagName] : [],
prompt: assistant.prompt,
description: assistant.description,
regularPhrases: [],

View File

@@ -122,16 +122,92 @@ describe('useAutoPullOnApiKeyChange', () => {
expect(onTrigger).not.toHaveBeenCalled()
})
it('does not fire when no models exist locally yet', () => {
it('fires on first render + key change for API-key providers with no models (auto-sync is disabled)', () => {
const onTrigger = vi.fn()
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
useModelsMock.mockReturnValue({ models: [] })
const { rerender } = renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
// First render with keys + no models: opens pull reconcile.
expect(onTrigger).toHaveBeenCalledTimes(1)
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-two'))
rerender()
// Key change with no models: opens pull reconcile again.
expect(onTrigger).toHaveBeenCalledTimes(2)
})
it('does not fire for non-key providers when no models exist (auto-sync handles bootstrap)', () => {
const onTrigger = vi.fn()
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
useModelsMock.mockReturnValue({ models: [] })
useProviderMock.mockReturnValue(providerWithHost('http://localhost:11434', 'ollama'))
const { rerender } = renderHook(() => useAutoPullOnApiKeyChange('ollama', onTrigger))
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-two'))
rerender()
expect(onTrigger).not.toHaveBeenCalled()
})
it('fires on first render for API-key providers with no models and enabled keys', () => {
const onTrigger = vi.fn()
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
useModelsMock.mockReturnValue({ models: [] })
renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
expect(onTrigger).toHaveBeenCalledTimes(1)
})
it('does not fire on first render for API-key providers that already have models', () => {
const onTrigger = vi.fn()
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
useModelsMock.mockReturnValue({ models: [{ id: 'openai::gpt-4o' }] })
renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
expect(onTrigger).not.toHaveBeenCalled()
})
it('does not fire on first render for API-key providers without enabled keys', () => {
const onTrigger = vi.fn()
useProviderApiKeysMock.mockReturnValue(emptyApiKeys())
useModelsMock.mockReturnValue({ models: [] })
renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
expect(onTrigger).not.toHaveBeenCalled()
})
it('does not fire on first render for non-key providers with no models (auto-sync handles it)', () => {
const onTrigger = vi.fn()
useProviderApiKeysMock.mockReturnValue(emptyApiKeys())
useModelsMock.mockReturnValue({ models: [] })
useProviderMock.mockReturnValue(providerWithHost('http://localhost:11434', 'ollama'))
renderHook(() => useAutoPullOnApiKeyChange('ollama', onTrigger))
expect(onTrigger).not.toHaveBeenCalled()
})
it('waits for models to finish loading before deciding first-render pull reconcile', () => {
const onTrigger = vi.fn()
// api-keys resolve first, models are still loading.
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
useModelsMock.mockReturnValue({ models: [], isLoading: true })
const { rerender } = renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
expect(onTrigger).not.toHaveBeenCalled()
// models resolve later — the provider already has local models.
useModelsMock.mockReturnValue({ models: [{ id: 'openai::gpt-4o' }], isLoading: false })
rerender()
expect(onTrigger).not.toHaveBeenCalled()
})

View File

@@ -7,16 +7,16 @@ import { providerNeedsApiKeyForModelSync } from './providerModelSyncRequirements
/**
* Fires `onTrigger` once whenever the provider's enabled API-key fingerprint OR
* its host (endpoint/baseUrl/authType) changes — but only after the first render
* and only when local models already exist (first-time bootstrap is owned by
* `useProviderAutoModelSync`). A pull still requires at least one enabled key
* for providers whose model sync needs API-key auth, so disabling the only key
* never fires for those providers.
* its host (endpoint/baseUrl/authType) changes. For API-key providers this also
* fires on first render when no local models exist (first-time bootstrap uses
* the pull-reconcile sidebar instead of direct auto-sync). A pull still requires
* at least one enabled key for providers whose model sync needs API-key auth,
* so disabling the only key never fires for those providers.
*/
export function useAutoPullOnApiKeyChange(providerId: string, onTrigger: () => void | Promise<void>) {
const { provider } = useProvider(providerId)
const { data: apiKeysData } = useProviderApiKeys(providerId)
const { models } = useModels({ providerId })
const { models, isLoading } = useModels({ providerId })
const enabledKeySignature = useMemo(
() =>
@@ -47,12 +47,12 @@ export function useAutoPullOnApiKeyChange(providerId: string, onTrigger: () => v
}, [onTrigger])
useEffect(() => {
// Until provider/api-keys resolve the signature is a cold-cache placeholder;
// recording that as the baseline would make the later undefined→loaded
// transition look like a user-initiated change and auto-fire the pull.
if (!provider || apiKeysData === undefined) return
if (!provider || apiKeysData === undefined || isLoading) return
if (lastSignatureRef.current === null) {
lastSignatureRef.current = changeSignature
if (models.length === 0 && requiresApiKeyForModelSync && enabledKeySignature) {
void onTriggerRef.current()
}
return
}
if (lastSignatureRef.current === changeSignature) {
@@ -61,7 +61,15 @@ export function useAutoPullOnApiKeyChange(providerId: string, onTrigger: () => v
lastSignatureRef.current = changeSignature
// Key-required providers still need an enabled key; disabling the only key must not fire.
if (requiresApiKeyForModelSync && !enabledKeySignature) return
if (models.length === 0) return
if (models.length === 0 && !requiresApiKeyForModelSync) return
void onTriggerRef.current()
}, [apiKeysData, changeSignature, enabledKeySignature, models.length, provider, requiresApiKeyForModelSync])
}, [
apiKeysData,
changeSignature,
enabledKeySignature,
isLoading,
models.length,
provider,
requiresApiKeyForModelSync
])
}

View File

@@ -80,47 +80,58 @@ describe('useProviderAutoModelSync', () => {
expect(useProviderModelSyncMock).toHaveBeenCalledWith('openai', { existingModels: [] })
})
it('enables a disabled provider when auto sync returns at least one model', async () => {
syncProviderModelsMock.mockResolvedValueOnce([{ id: 'openai::gpt-4o' }])
it('skips auto sync for API-key providers (uses pull reconcile instead)', async () => {
renderHook(() => useProviderAutoModelSync('openai'))
await waitFor(() => expect(updateProviderMock).toHaveBeenCalledWith({ isEnabled: true }))
await waitFor(() =>
expect(loggerInfoMock).toHaveBeenCalledWith('Skipping provider auto model sync', {
providerId: 'openai',
reason: 'uses_pull_reconcile'
})
)
expect(syncProviderModelsMock).not.toHaveBeenCalled()
})
it('keeps a disabled provider disabled when auto sync returns zero models', async () => {
syncProviderModelsMock.mockResolvedValueOnce([])
renderHook(() => useProviderAutoModelSync('openai'))
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(1))
await Promise.resolve()
expect(updateProviderMock).not.toHaveBeenCalled()
})
it('does not patch an already enabled provider after successful auto sync', async () => {
it('auto syncs for non-key providers (e.g. Ollama) when models are missing', async () => {
syncProviderModelsMock.mockResolvedValueOnce([{ id: 'ollama::llama3.2' }])
useProviderMock.mockReturnValue({
provider: {
id: 'openai',
isEnabled: true,
defaultChatEndpoint: 'openai_chat_completions',
id: 'ollama',
isEnabled: false,
defaultChatEndpoint: 'ollama_chat',
endpointConfigs: {
openai_chat_completions: { baseUrl: 'https://api.openai.com/v1' }
ollama_chat: { baseUrl: 'http://localhost:11434' }
}
},
updateProvider: updateProviderMock
})
syncProviderModelsMock.mockResolvedValueOnce([{ id: 'openai::gpt-4o' }])
useProviderApiKeysMock.mockReturnValue({
data: { keys: [] }
})
renderHook(() => useProviderAutoModelSync('openai'))
renderHook(() => useProviderAutoModelSync('ollama'))
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(1))
await Promise.resolve()
expect(updateProviderMock).not.toHaveBeenCalled()
await waitFor(() => expect(updateProviderMock).toHaveBeenCalledWith({ isEnabled: true }))
})
it('syncs only once for the same initial eligible configuration', async () => {
const { rerender } = renderHook(() => useProviderAutoModelSync('openai'))
it('syncs only once for the same initial eligible configuration (non-key provider)', async () => {
useProviderMock.mockReturnValue({
provider: {
id: 'ollama',
isEnabled: false,
defaultChatEndpoint: 'ollama_chat',
endpointConfigs: {
ollama_chat: { baseUrl: 'http://localhost:11434' }
}
},
updateProvider: updateProviderMock
})
useProviderApiKeysMock.mockReturnValue({
data: { keys: [] }
})
syncProviderModelsMock.mockResolvedValue([])
const { rerender } = renderHook(() => useProviderAutoModelSync('ollama'))
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(1))
@@ -154,19 +165,18 @@ describe('useProviderAutoModelSync', () => {
expect(syncProviderModelsMock).not.toHaveBeenCalled()
})
it('logs auto sync failures and allows retrying when the same signature becomes eligible again', async () => {
const syncError = new Error('sync down')
syncProviderModelsMock.mockRejectedValueOnce(syncError).mockResolvedValueOnce([])
it('skips auto sync for API-key provider even after key rotation (pull reconcile handles it)', async () => {
const { rerender } = renderHook(() => useProviderAutoModelSync('openai'))
// First render with keys → uses_pull_reconcile
await waitFor(() =>
expect(loggerErrorMock).toHaveBeenCalledWith('Provider auto model sync failed', {
expect(loggerInfoMock).toHaveBeenCalledWith('Skipping provider auto model sync', {
providerId: 'openai',
error: syncError
reason: 'uses_pull_reconcile'
})
)
// Keys removed
useProviderApiKeysMock.mockReturnValue({
data: { keys: [] }
})
@@ -179,12 +189,18 @@ describe('useProviderAutoModelSync', () => {
})
)
// Keys restored — still uses pull reconcile, no direct sync
useProviderApiKeysMock.mockReturnValue({
data: { keys: [{ id: 'key-1', key: 'sk-test', isEnabled: true }] }
})
rerender()
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(2))
expect(updateProviderMock).not.toHaveBeenCalled()
await waitFor(() =>
expect(loggerInfoMock).toHaveBeenCalledWith('Skipping provider auto model sync', {
providerId: 'openai',
reason: 'uses_pull_reconcile'
})
)
expect(syncProviderModelsMock).not.toHaveBeenCalled()
})
})

View File

@@ -69,6 +69,13 @@ export function useProviderAutoModelSync(providerId: string) {
} as const
}
if (requiresApiKeyForModelSync) {
return {
shouldSync: false,
reason: 'uses_pull_reconcile'
} as const
}
if (!initialModelSyncSignature) {
return {
shouldSync: false,

View File

@@ -30,12 +30,15 @@ import {
} from '@cherrystudio/ui'
import { loggerService } from '@logger'
import ListItem from '@renderer/components/ListItem'
import { WorkspaceSelector } from '@renderer/components/resource'
import Scrollbar from '@renderer/components/Scrollbar'
import { dataApiService } from '@renderer/data/DataApiService'
import { useQuery } from '@renderer/data/hooks/useDataApi'
import { useChannels } from '@renderer/hooks/agent/useChannels'
import { useCreateTask, useDeleteTask, useRunTask, useTaskLogs, useUpdateTask } from '@renderer/hooks/agent/useTasks'
import { useConversationNavigation } from '@renderer/hooks/useConversationNavigation'
import { useTheme } from '@renderer/hooks/useTheme'
import { AGENT_WORKSPACE_TYPE } from '@shared/data/api/schemas/agentWorkspaces'
import type { Trigger } from '@shared/data/api/schemas/jobs'
import type {
AgentEntity,
@@ -47,8 +50,11 @@ import type {
import {
AlertTriangle,
CalendarClock,
ChevronDown,
CircleSlash,
Clock,
ExternalLink,
Folder,
History,
Maximize2,
MoreHorizontal,
@@ -324,6 +330,15 @@ const TaskDetail: FC<{
timeoutMinutes: task.timeoutMinutes?.toString() ?? ''
})
const [channelIds, setChannelIds] = useState<string[]>(task.channelIds ?? [])
const [workspaceId, setWorkspaceId] = useState<string | null>(
task.workspace.type === AGENT_WORKSPACE_TYPE.USER ? task.workspace.workspaceId : null
)
const { data: workspaces } = useQuery('/agent-workspaces')
const isSystemWorkspace = workspaceId === null
const workspaceLabel = isSystemWorkspace
? t('agent.session.workspace_selector.no_project')
: (workspaces?.find((w) => w.id === workspaceId)?.name ?? workspaceId)
const toggleStatusLabel =
task.status === 'active' ? t('agent.cherryClaw.tasks.pause') : t('agent.cherryClaw.tasks.resume')
@@ -337,6 +352,7 @@ const TaskDetail: FC<{
const next = triggerToFormState(task.trigger)
setSchedule({ ...next, timeoutMinutes: task.timeoutMinutes?.toString() ?? '' })
setChannelIds(task.channelIds ?? [])
setWorkspaceId(task.workspace.type === AGENT_WORKSPACE_TYPE.USER ? task.workspace.workspaceId : null)
}, [task])
const saveField = useCallback(
@@ -528,6 +544,36 @@ const TaskDetail: FC<{
}}
disabled={isCompleted}
/>
{/* Workspace is a secondary detail — scheduled tasks default to "No work directory". */}
<div className="flex items-center gap-1.5 text-foreground-muted text-xs">
<span>{t('agent.session.display.workdir')}</span>
<WorkspaceSelector
value={workspaceId}
onChange={(nextWorkspaceId) => {
setWorkspaceId(nextWorkspaceId)
saveField({
workspace:
nextWorkspaceId === null
? { type: AGENT_WORKSPACE_TYPE.SYSTEM }
: { type: AGENT_WORKSPACE_TYPE.USER, workspaceId: nextWorkspaceId }
})
}}
disabled={isCompleted}
align="start"
trigger={
<Button
variant="ghost"
size="sm"
className="h-6 gap-1 px-1.5 text-foreground-muted"
disabled={isCompleted}>
{isSystemWorkspace ? <CircleSlash className="size-3.5" /> : <Folder className="size-3.5" />}
<span className="max-w-40 truncate">{workspaceLabel}</span>
<ChevronDown className="size-3.5" />
</Button>
}
/>
</div>
</div>
</SettingGroup>
@@ -783,14 +829,20 @@ const CreateForm: FC<{
const [promptModalOpen, setPromptModalOpen] = useState(false)
const [schedule, setSchedule] = useState<ScheduleFormState>({ kind: 'interval', value: '', timeoutMinutes: '' })
const [channelIds, setChannelIds] = useState<string[]>([])
// TODO(agent-workspace-picker): wire the workspace picker before re-enabling task creation.
const [workspaceSource] = useState<CreateTaskRequest['workspace'] | null>(null)
// `null` = "No work directory" (system workspace); a string binds the task to that user workspace.
const [workspaceId, setWorkspaceId] = useState<string | null>(null)
const { data: workspaces } = useQuery('/agent-workspaces')
const [saving, setSaving] = useState(false)
const isValid = agentId && name.trim() && prompt.trim() && schedule.value.trim() && workspaceSource
const isSystemWorkspace = workspaceId === null
const workspaceLabel = isSystemWorkspace
? t('agent.session.workspace_selector.no_project')
: (workspaces?.find((w) => w.id === workspaceId)?.name ?? workspaceId)
const isValid = agentId && name.trim() && prompt.trim() && schedule.value.trim()
const handleCreate = useCallback(async () => {
if (!agentId || !name.trim() || !prompt.trim() || !schedule.value.trim() || !workspaceSource) return
if (!agentId || !name.trim() || !prompt.trim() || !schedule.value.trim()) return
const trigger = formStateToTrigger(schedule.kind, schedule.value.trim())
if (!trigger) return
setSaving(true)
@@ -800,14 +852,17 @@ const CreateForm: FC<{
name: name.trim(),
prompt: prompt.trim(),
trigger,
workspace: workspaceSource,
workspace:
workspaceId === null
? { type: AGENT_WORKSPACE_TYPE.SYSTEM }
: { type: AGENT_WORKSPACE_TYPE.USER, workspaceId },
timeoutMinutes: timeout && timeout > 0 ? timeout : undefined,
channelIds: channelIds.length > 0 ? channelIds : undefined
})
} finally {
setSaving(false)
}
}, [agentId, name, prompt, schedule, workspaceSource, channelIds, onCreate])
}, [agentId, name, prompt, schedule, workspaceId, channelIds, onCreate])
return (
<SettingsContentColumn theme={theme}>
@@ -880,6 +935,23 @@ const CreateForm: FC<{
<TaskScheduleControls value={schedule} onChange={setSchedule} />
<TaskChannelSelector channels={channels} channelIds={channelIds} onChange={setChannelIds} />
{/* Workspace is a secondary detail — scheduled tasks default to "No work directory". */}
<div className="flex items-center gap-1.5 text-foreground-muted text-xs">
<span>{t('agent.session.display.workdir')}</span>
<WorkspaceSelector
value={workspaceId}
onChange={setWorkspaceId}
align="start"
trigger={
<Button variant="ghost" size="sm" className="h-6 gap-1 px-1.5 text-foreground-muted">
{isSystemWorkspace ? <CircleSlash className="size-3.5" /> : <Folder className="size-3.5" />}
<span className="max-w-40 truncate">{workspaceLabel}</span>
<ChevronDown className="size-3.5" />
</Button>
}
/>
</div>
<div className="flex gap-2">
<Button variant="outline" size="sm" onClick={onCancel}>
{t('agent.cherryClaw.tasks.cancel')}

View File

@@ -71,6 +71,14 @@ vi.mock('@renderer/hooks/agent/useChannels', () => ({
useChannels: () => ({ channels: [] })
}))
vi.mock('@renderer/data/hooks/useDataApi', () => ({
useQuery: () => ({ data: [] })
}))
vi.mock('@renderer/components/resource', () => ({
WorkspaceSelector: ({ trigger }: { trigger: React.ReactNode }) => <>{trigger}</>
}))
vi.mock('@renderer/hooks/agent/useTasks', () => ({
useCreateTask: () => ({ createTask: taskMutationMocks.createTask }),
useDeleteTask: () => ({ deleteTask: taskMutationMocks.deleteTask }),

View File

@@ -6,6 +6,8 @@ import {
KB_SEARCH_TOOL_NAME,
kbListInputSchema,
kbListStrictInputSchema,
kbManageInputSchema,
kbManageStrictInputSchema,
kbSearchInputSchema,
REPORT_ARTIFACTS_DESCRIPTION,
REPORT_ARTIFACTS_TOOL_NAME,
@@ -62,6 +64,41 @@ describe('builtin tool contracts', () => {
expect(kbListInputSchema.safeParse({ query: 'recipes' }).success).toBe(true)
})
it('keeps kb_manage strict-path fields in `required` so strict providers accept the schema', () => {
// Same regression as kb_list above: the AI-SDK path (KnowledgeManageTool) runs strict:true, so
// an all-optional object would serialize `required` away to nothing and a strict OpenAI-compatible
// provider would reject the whole request. The strict variant makes every optional field
// `.nullable()` (null = unused for this action/type) so they all stay in `required`.
const json = z.toJSONSchema(kbManageStrictInputSchema) as { required?: unknown }
expect(Array.isArray(json.required)).toBe(true)
expect(json.required).toEqual(
expect.arrayContaining(['baseId', 'action', 'type', 'path', 'url', 'content', 'title', 'conceptIds'])
)
// null is the "unused" signal for every optional field; an explicit all-null payload must still parse.
expect(
kbManageStrictInputSchema.safeParse({
baseId: 'kb-1',
action: 'delete',
type: null,
path: null,
url: null,
content: null,
title: null,
conceptIds: null
}).success
).toBe(true)
})
it('lets the MCP kb_manage path omit unused fields', () => {
// The Claude Code bridge parses raw args with kbManageInputSchema; an agent may omit every
// field but `baseId`/`action`, so the optional shape must accept that without erroring.
expect(kbManageInputSchema.safeParse({ baseId: 'kb-1', action: 'delete' }).success).toBe(true)
expect(kbManageInputSchema.safeParse({ baseId: 'kb-1', action: 'add', type: 'note', content: 'hi' }).success).toBe(
true
)
})
it('validates final report artifacts', () => {
const result = reportArtifactsInputSchema.parse({
artifacts: [{ path: 'dist/report.pdf', description: 'Final report' }],

View File

@@ -279,6 +279,11 @@ export const KB_MANAGE_ADD_TYPES = ['file', 'url', 'note'] as const
// One flat object, not a discriminated union: which fields apply depends on `action`
// (and, for add, on `type`). The core validates the combination and returns a steer
// string on a missing field, so the model gets a clear error rather than a schema reject.
//
// kb_manage is consumed by two paths with conflicting schema needs, same split as kb_list.
//
// MCP / Claude Code bridge (cherryBuiltinTools): the agent parses raw args with this schema and may
// omit any field, so they are `.optional()`.
export const kbManageInputSchema = z.object({
baseId: z.string().trim().min(1).describe('ID of the knowledge base to modify — a base id from kb_list.'),
action: z
@@ -319,6 +324,59 @@ export const kbManageInputSchema = z.object({
)
})
// AI-SDK path (KnowledgeManageTool) runs with `strict: true` — same `.nullable()` treatment as
// `kbListStrictInputSchema` and for the same reason (an all-optional shape serializes `required`
// away to nothing, which a strict OpenAI-compatible provider rejects). `manageKnowledge` treats
// null (like undefined) as "field not set for this action/type".
export const kbManageStrictInputSchema = z.object({
baseId: z.string().trim().min(1).describe('ID of the knowledge base to modify — a base id from kb_list.'),
action: z
.enum(KB_MANAGE_ACTIONS)
.describe(
'add: import a new source (set `type` + its field). delete: remove documents by `conceptIds`. ' +
'refresh: re-index documents by `conceptIds`. All actions modify the base and require user approval.'
),
type: z
.enum(KB_MANAGE_ADD_TYPES)
.nullable()
.describe(
'For action="add" only: the source kind — "file" (set `path`), "url" (set `url`), or "note" (set `content`). ' +
'Pass null otherwise.'
),
path: z
.string()
.trim()
.min(1)
.nullable()
.describe('For action="add", type="file": absolute local filesystem path of the file to import. Else null.'),
url: z
.string()
.trim()
.min(1)
.nullable()
.describe('For action="add", type="url": the URL to fetch and index. Else null.'),
content: z
.string()
.min(1)
.nullable()
.describe('For action="add", type="note": the plain-text note content to index. Else null.'),
title: z
.string()
.trim()
.min(1)
.nullable()
.describe(
'For action="add", type="note": optional display title (defaults to the note\'s first line). Pass null to omit.'
),
conceptIds: z
.array(z.string().trim().min(1))
.nullable()
.describe(
'For action="delete"/"refresh": Concept IDs (the `conceptId` field of a kb_search hit or a kb_list result) ' +
'to operate on. Else null.'
)
})
export const kbManageOutputSchema = z.object({
action: z.enum(KB_MANAGE_ACTIONS),
// add: the source identifiers that were imported (one per add call).
@@ -330,7 +388,6 @@ export const kbManageOutputSchema = z.object({
notFound: z.array(z.string()).optional()
})
export type KbManageInput = z.infer<typeof kbManageInputSchema>
export type KbManageOutput = z.infer<typeof kbManageOutputSchema>
// ── web_search ───────────────────────────────────────────────────