mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-07-03 12:27:41 +08:00
Merge branch 'main' into fix/image-generation-empty-panel
This commit is contained in:
@@ -305,6 +305,28 @@ describe('runAgentTask', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// Regular tasks carry the workspace bound at creation time (system by
|
||||
// default, since the picker defaults there) straight through to the session.
|
||||
it('binds a non-heartbeat task to the workspace bound on the task', async () => {
|
||||
vi.mocked(jobService.getById).mockReturnValueOnce(makeJobSnapshot())
|
||||
vi.mocked(jobScheduleService.getById).mockReturnValueOnce(makeSchedule('daily-summary'))
|
||||
vi.mocked(agentService.getAgent).mockReturnValueOnce(makeAgent())
|
||||
vi.mocked(agentSessionService.create).mockReturnValueOnce(makeSession('/ws/a'))
|
||||
|
||||
const promise = runAgentTask(
|
||||
makeCtx({ input: { agentId: 'a1', prompt: 'hi', timeoutMinutes: 0, workspace: { type: 'system' } } })
|
||||
)
|
||||
await vi.waitFor(() => expect(mockStartRun).toHaveBeenCalled())
|
||||
captured.listeners[0].onDone({ status: 'completed' })
|
||||
await promise
|
||||
|
||||
expect(agentSessionService.create).toHaveBeenCalledWith({
|
||||
agentId: 'a1',
|
||||
name: 'daily-summary',
|
||||
workspace: { type: 'system' }
|
||||
})
|
||||
})
|
||||
|
||||
// C1 (agents-jobs-3): a `text-delta` chunk's payload is on `.delta`, not `.text`.
|
||||
// The previous `as { text }` cast silently accumulated nothing, so every run
|
||||
// persisted the `'Completed'` fallback instead of the model's reply.
|
||||
|
||||
@@ -138,6 +138,9 @@ export async function runAgentTask(ctx: JobContext<AgentTaskInput>): Promise<Age
|
||||
// Always create a fresh session per fire. Scheduled tasks are discrete
|
||||
// invocations; cross-fire session reuse would only carry stale model
|
||||
// context. Persistent state lives in workspace files (heartbeat.md, etc.).
|
||||
// The session inherits the workspace bound on the task at creation time —
|
||||
// system for regular tasks (the picker defaults there), the validated user
|
||||
// workspace for heartbeats.
|
||||
const session = agentSessionService.create({
|
||||
agentId,
|
||||
name: taskName ?? 'Scheduled task',
|
||||
|
||||
@@ -9,9 +9,14 @@
|
||||
* mutation itself lives in the shared `knowledgeLookup` core so the Claude Code
|
||||
* MCP bridge runs identical logic (gated there by Claude Code's own permission
|
||||
* prompt); this file is just the AI-SDK `tool()` wrapper.
|
||||
*
|
||||
* `defer: 'never'` (kept inline, never behind `tool_search`/`tool_invoke`): the same rule
|
||||
* `mcp/mcpTools.ts` applies to force-prompt MCP tools. Deferring an approval-gated tool would strip
|
||||
* it from the SDK's tool-set, so the SDK's native `needsApproval` gate never fires and `tool_invoke`
|
||||
* refuses it too (it never runs an approval-gated tool blind) — an unreachable tool either way.
|
||||
*/
|
||||
|
||||
import { KB_MANAGE_TOOL_NAME, kbManageInputSchema, kbManageOutputSchema } from '@shared/ai/builtinTools'
|
||||
import { KB_MANAGE_TOOL_NAME, kbManageOutputSchema, kbManageStrictInputSchema } from '@shared/ai/builtinTools'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
@@ -31,7 +36,7 @@ const knowledgeManageResultSchema = z.union([kbManageOutputSchema, knowledgeLook
|
||||
|
||||
const kbManageTool = tool({
|
||||
description: KNOWLEDGE_MANAGE_DESCRIPTION,
|
||||
inputSchema: kbManageInputSchema,
|
||||
inputSchema: kbManageStrictInputSchema,
|
||||
outputSchema: knowledgeManageResultSchema,
|
||||
strict: true,
|
||||
// Every action (add / delete / refresh) modifies the base; gate on explicit user approval.
|
||||
@@ -48,7 +53,7 @@ export function createKbManageToolEntry(): ToolEntry {
|
||||
name: KB_MANAGE_TOOL_NAME,
|
||||
namespace: 'kb',
|
||||
description: 'Add, delete, or re-index documents in a knowledge base (requires approval)',
|
||||
defer: 'always',
|
||||
defer: 'never',
|
||||
tool: kbManageTool,
|
||||
applies: (scope) => scope.hasAnyKnowledgeBase === true && (scope.assistant?.knowledgeBaseIds?.length ?? 0) > 0
|
||||
}
|
||||
|
||||
@@ -69,7 +69,10 @@ describe('kb_manage', () => {
|
||||
it('builds an entry with the agreed namespace + defer policy and is approval-gated', () => {
|
||||
expect(entry.name).toBe(KB_MANAGE_TOOL_NAME)
|
||||
expect(entry.namespace).toBe('kb')
|
||||
expect(entry.defer).toBe('always')
|
||||
// Approval-gated tools must stay inline (never deferred) — see applyDeferExposition/toolInvoke:
|
||||
// a deferred approval-gated tool is unreachable (stripped from the inline set, and tool_invoke
|
||||
// refuses to run an approval-gated tool blind).
|
||||
expect(entry.defer).toBe('never')
|
||||
// Every action mutates the base, so the tool must require user approval.
|
||||
expect(entry.tool.needsApproval).toBe(true)
|
||||
})
|
||||
|
||||
@@ -34,4 +34,20 @@ describe('registerBuiltinTools', () => {
|
||||
expect(readFile?.applies?.({ mcpToolIds: new Set(), hasFileAttachments: false })).toBe(false)
|
||||
expect(readFile?.applies?.({ mcpToolIds: new Set(), hasFileAttachments: true })).toBe(true)
|
||||
})
|
||||
|
||||
it(
|
||||
'never defers an approval-gated entry (would strip it from the inline set with no way back — ' +
|
||||
'see mcp/mcpTools.ts and toolInvoke.ts for the same rule on MCP force-prompt tools)',
|
||||
() => {
|
||||
const reg = new ToolRegistry()
|
||||
registerBuiltinTools(reg)
|
||||
for (const entry of reg.getAll()) {
|
||||
if (entry.tool.needsApproval) {
|
||||
expect(entry.defer).toBe('never')
|
||||
}
|
||||
}
|
||||
// Sanity: this loop is only meaningful while at least one builtin entry is approval-gated.
|
||||
expect(reg.getAll().some((e) => e.tool.needsApproval)).toBe(true)
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
@@ -25,7 +25,6 @@ import type {
|
||||
KbGrepOutput,
|
||||
KbListOutput,
|
||||
KbListOutputItem,
|
||||
KbManageInput,
|
||||
KbManageOutput,
|
||||
KbReadInput,
|
||||
KbReadOutput,
|
||||
@@ -415,6 +414,18 @@ export async function listOrOutlineKnowledge(
|
||||
/** Longest a derived note title (its first line) may be before it is truncated. */
|
||||
const NOTE_TITLE_MAX_CHARS = 80
|
||||
|
||||
/** kb_manage input shape shared by both callers: MCP omits an unused field, AI-SDK strict passes null. */
|
||||
type ManageKnowledgeInput = {
|
||||
baseId: string
|
||||
action: 'add' | 'delete' | 'refresh'
|
||||
type?: 'file' | 'url' | 'note' | null
|
||||
path?: string | null
|
||||
url?: string | null
|
||||
content?: string | null
|
||||
title?: string | null
|
||||
conceptIds?: string[] | null
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply a destructive knowledge-base change (add / delete / refresh). Like the
|
||||
* read cores it never throws: an out-of-scope base, a missing required field, an
|
||||
@@ -425,7 +436,7 @@ const NOTE_TITLE_MAX_CHARS = 80
|
||||
* executes the mutation unconditionally once invoked.
|
||||
*/
|
||||
export async function manageKnowledge(
|
||||
input: KbManageInput,
|
||||
input: ManageKnowledgeInput,
|
||||
allowedIds: string[]
|
||||
): Promise<KnowledgeManageResultOrError> {
|
||||
if (allowedIds.length > 0 && !allowedIds.includes(input.baseId)) {
|
||||
@@ -493,7 +504,7 @@ type AddInputResult = { ok: true; input: KnowledgeAddItemInput; source: string }
|
||||
* invalid value (e.g. a non-absolute file path) is rejected before it reaches the
|
||||
* filesystem boundary. `source` is the identifier reported back as `added`.
|
||||
*/
|
||||
function buildAddInput(input: KbManageInput): AddInputResult {
|
||||
function buildAddInput(input: ManageKnowledgeInput): AddInputResult {
|
||||
switch (input.type) {
|
||||
case 'file': {
|
||||
if (!input.path) {
|
||||
@@ -540,7 +551,7 @@ function firstNonEmptyLine(content: string): string | undefined {
|
||||
}
|
||||
|
||||
/** A note's display source: the caller-supplied title, else its first non-empty line (truncated), else a placeholder. */
|
||||
function deriveNoteSource(content: string, title?: string): string {
|
||||
function deriveNoteSource(content: string, title?: string | null): string {
|
||||
const explicit = title?.trim()
|
||||
if (explicit) return explicit
|
||||
// Truncation here differs by role from deriveSampleSource's note branch (a stored id, plain-clipped;
|
||||
|
||||
@@ -982,10 +982,10 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
|
||||
// fresh uuid dir, the runtime never opens a store mid-migration, and the catch below wipes a
|
||||
// partial on a caught failure. A crash-orphaned dir is never referenced by a knowledge_base row
|
||||
// so it is never mounted (it is dead disk, the same as the rename path produced).
|
||||
const driver = await openBetterSqlite3IndexDriver(plan.targetDbPath)
|
||||
const driver = openBetterSqlite3IndexDriver(plan.targetDbPath)
|
||||
try {
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
await ensureIndexMeta(driver, { baseId: plan.baseId })
|
||||
createKnowledgeIndexSchema(driver)
|
||||
ensureIndexMeta(driver, { baseId: plan.baseId })
|
||||
const store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
|
||||
|
||||
for (const material of plan.materials) {
|
||||
@@ -998,11 +998,11 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
|
||||
// Fold the WAL back into the main db file so the committed pages are durable in index.sqlite
|
||||
// itself (the WAL is not guaranteed to be checkpointed on close); the runtime then opens a
|
||||
// self-contained store.
|
||||
await driver.execute('PRAGMA wal_checkpoint(TRUNCATE)')
|
||||
driver.execute('PRAGMA wal_checkpoint(TRUNCATE)')
|
||||
} finally {
|
||||
// Close so the file handle is released (a leaked handle would block a re-run's
|
||||
// removeIndexStoreFiles and the later base-dir deletion on Windows).
|
||||
await driver.close()
|
||||
driver.close()
|
||||
}
|
||||
// Build + close succeeded: the complete store sits at its runtime path. A later failure
|
||||
// (snapshot-pin) leaves it present and searchable, so it must NOT be wiped or marked failed.
|
||||
@@ -1166,11 +1166,11 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
|
||||
// busy_timeout=5000, so this re-read of the just-built store waits out a transient Windows lock
|
||||
// (Defender / indexer scanning the freshly-written file) instead of throwing SQLITE_BUSY /
|
||||
// EACCES and failing validation for an already-correct store.
|
||||
const driver = await openBetterSqlite3IndexDriver(plan.targetDbPath)
|
||||
const driver = openBetterSqlite3IndexDriver(plan.targetDbPath)
|
||||
try {
|
||||
const materialCount = await this.tableCount(driver, 'material')
|
||||
const unitCount = await this.tableCount(driver, 'search_unit')
|
||||
const embeddingCount = await this.tableCount(driver, 'embedding')
|
||||
const materialCount = this.tableCount(driver, 'material')
|
||||
const unitCount = this.tableCount(driver, 'search_unit')
|
||||
const embeddingCount = this.tableCount(driver, 'embedding')
|
||||
targetCount += unitCount
|
||||
|
||||
this.pushCountMismatch(errors, plan.baseId, 'material', plan.materials.length, materialCount)
|
||||
@@ -1180,7 +1180,7 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
|
||||
// Every unit's body search_text must resolve to a stored embedding, or that
|
||||
// unit is silently absent from vector search. This is the migration-time
|
||||
// form of the rebuild self-heal invariant (knowledge-technical-design.md §10).
|
||||
const uncovered = await driver.execute(
|
||||
const uncovered = driver.execute(
|
||||
`SELECT count(*) AS count FROM search_text st
|
||||
LEFT JOIN embedding e ON e.embedding_text_hash = st.embedding_text_hash
|
||||
WHERE e.embedding_text_hash IS NULL`
|
||||
@@ -1195,7 +1195,7 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
await driver.close()
|
||||
driver.close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1234,8 +1234,8 @@ export class KnowledgeVectorMigrator extends BaseMigrator {
|
||||
}
|
||||
}
|
||||
|
||||
private async tableCount(driver: BetterSqlite3Driver, table: string): Promise<number> {
|
||||
const result = await driver.execute(`SELECT count(*) AS count FROM ${table}`)
|
||||
private tableCount(driver: BetterSqlite3Driver, table: string): number {
|
||||
const result = driver.execute(`SELECT count(*) AS count FROM ${table}`)
|
||||
return Number(result.rows[0]?.count ?? 0)
|
||||
}
|
||||
|
||||
|
||||
@@ -59,10 +59,11 @@ export function createIndexDocumentsJobHandler(
|
||||
const { base, item } = input
|
||||
|
||||
// Mark reading before file/network IO so the UI reflects the current long-running phase.
|
||||
// No base mutation lock: this only writes the main app DB (knowledge_item), not the
|
||||
// per-base index.sqlite the lock protects, and updateStatus's own 'deleting' guard
|
||||
// (KnowledgeItemService.updateStatus) already covers the race the lock would.
|
||||
reportKnowledgeProgress(ctx, 0, { stage: 'reading', currentFile: 0, totalFiles: 1 })
|
||||
await knowledgeLockManager.withBaseMutationLock(ctx.input.baseId, () => {
|
||||
knowledgeItemService.updateStatus(ctx.input.itemId, 'reading')
|
||||
})
|
||||
knowledgeItemService.updateStatus(ctx.input.itemId, 'reading')
|
||||
|
||||
// Capture a url's or note's snapshot on first index (a url fetches outside
|
||||
// the lock, a note writes its in-hand content; both persist a relativePath
|
||||
@@ -83,10 +84,9 @@ export function createIndexDocumentsJobHandler(
|
||||
}
|
||||
|
||||
// Mark embedding separately so the UI reflects the current long-running phase.
|
||||
// No base mutation lock here either — same reasoning as the 'reading' status above.
|
||||
reportKnowledgeProgress(ctx, 40, { stage: 'embedding', currentFile: 0, totalFiles: 1 })
|
||||
await knowledgeLockManager.withBaseMutationLock(ctx.input.baseId, () =>
|
||||
knowledgeItemService.updateStatus(ctx.input.itemId, 'embedding')
|
||||
)
|
||||
knowledgeItemService.updateStatus(ctx.input.itemId, 'embedding')
|
||||
|
||||
// Use readableItem, not item: for a freshly captured url it carries the snapshot
|
||||
// relativePath, so the material's relative_path is the real `raw/` snapshot path
|
||||
|
||||
@@ -16,12 +16,14 @@ export async function deleteKnowledgeItemVectors(base: KnowledgeBase, itemIds: s
|
||||
return
|
||||
}
|
||||
|
||||
// Delete every id in ONE batched transaction with a single collectIndexGarbage pass.
|
||||
// Delete every id with a single collectIndexGarbage pass at the end (each material's
|
||||
// row delete is its own short transaction; see KnowledgeIndexStore.deleteMaterials).
|
||||
// The old per-id Promise.allSettled loop ran the two full-table GC scans once per item,
|
||||
// so deleting a folder of N files scanned the whole embedding/content table N times —
|
||||
// the multi-second main-process freeze on large (PDF-heavy) folders. deleteMaterials
|
||||
// rolls the whole batch back on failure (throwing the root cause), so a retry
|
||||
// re-discovers every affected id; no per-item failure aggregation is needed.
|
||||
// the multi-second main-process freeze on large (PDF-heavy) folders. A failure partway
|
||||
// leaves whatever was already deleted committed; that is safe because this call always
|
||||
// precedes the knowledge_item DB row deletion (see subtreePurge.ts), so a retry
|
||||
// re-discovers exactly the materials still left.
|
||||
await store.deleteMaterials(uniqueItemIds)
|
||||
}
|
||||
|
||||
|
||||
@@ -43,12 +43,12 @@ export function getKnowledgeBaseMetaDir(baseId: string): FilePath {
|
||||
return path.join(getKnowledgeBaseDir(baseId), CHERRY_META_DIR) as FilePath
|
||||
}
|
||||
|
||||
export async function getKnowledgeVectorStoreFilePath(baseId: string): Promise<FilePath> {
|
||||
const metaDir = getKnowledgeBaseMetaDir(baseId)
|
||||
await ensureDir(metaDir)
|
||||
return getKnowledgeVectorStoreFilePathSync(baseId)
|
||||
}
|
||||
|
||||
/**
|
||||
* The base's index.sqlite path. No `ensureDir` here — `openBetterSqlite3IndexDriver`
|
||||
* (BetterSqlite3Driver.ts) already `mkdirSync`s the parent `.cherry/` dir itself before
|
||||
* opening, so a caller that only needs the path (not a guaranteed-existing directory)
|
||||
* does not need to duplicate that work.
|
||||
*/
|
||||
export function getKnowledgeVectorStoreFilePathSync(baseId: string): FilePath {
|
||||
const metaDir = getKnowledgeBaseMetaDir(baseId)
|
||||
return path.join(metaDir, VECTOR_STORE_FILE) as FilePath
|
||||
|
||||
@@ -8,11 +8,7 @@ import type { CompletedKnowledgeBase, KnowledgeBase } from '@shared/data/types/k
|
||||
import { isCompletedKnowledgeBase } from '@shared/data/types/knowledge'
|
||||
|
||||
import { isIndexableKnowledgeItem } from '../utils/items'
|
||||
import {
|
||||
deleteKnowledgeBaseDir,
|
||||
getKnowledgeVectorStoreFilePath,
|
||||
getKnowledgeVectorStoreFilePathSync
|
||||
} from '../utils/storage/pathStorage'
|
||||
import { deleteKnowledgeBaseDir, getKnowledgeVectorStoreFilePathSync } from '../utils/storage/pathStorage'
|
||||
import { openBetterSqlite3IndexDriver } from './indexStore/BetterSqlite3Driver'
|
||||
import { betterSqlite3VectorIndex } from './indexStore/BetterSqlite3VectorIndex'
|
||||
import { ensureIndexMeta, hasAnyMaterial, readIndexSchemaVersion } from './indexStore/indexMeta'
|
||||
@@ -47,11 +43,11 @@ function assertVectorStoreReadyBase(base: KnowledgeBase): asserts base is Comple
|
||||
@Injectable('KnowledgeVectorStoreService')
|
||||
@ServicePhase(Phase.WhenReady)
|
||||
export class KnowledgeVectorStoreService extends BaseService {
|
||||
// Caches the in-flight open promise, not the resolved store, so concurrent
|
||||
// getIndexStore calls for the same base share one open (one better-sqlite3
|
||||
// connection) instead of racing — the loser of a "resolve then set" race would
|
||||
// otherwise leak an orphaned store that no one ever closes.
|
||||
private instanceCache = new Map<string, Promise<KnowledgeIndexStore>>()
|
||||
// Opening a store (better-sqlite3 connect + schema + meta) is fully synchronous
|
||||
// (see openIndexStore), so it runs to completion in one JS turn — no concurrent
|
||||
// getIndexStore call for the same base can ever observe an in-flight open, and a
|
||||
// failed open never gets cached (the throw happens before .set() below runs).
|
||||
private instanceCache = new Map<string, KnowledgeIndexStore>()
|
||||
|
||||
/** Open (or reuse) the base's index store, ensuring its schema exists. */
|
||||
async getIndexStore(base: KnowledgeBase): Promise<KnowledgeIndexStore> {
|
||||
@@ -63,20 +59,10 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
return cached
|
||||
}
|
||||
|
||||
const opening = this.openIndexStore(base)
|
||||
this.instanceCache.set(base.id, opening)
|
||||
try {
|
||||
const store = await opening
|
||||
logger.info('Opened knowledge index store', { baseId: base.id, cacheSize: this.instanceCache.size })
|
||||
return store
|
||||
} catch (error) {
|
||||
// Evict the rejected promise so a later call retries the open instead of
|
||||
// forever re-awaiting the same failure (only if it is still the cached one).
|
||||
if (this.instanceCache.get(base.id) === opening) {
|
||||
this.instanceCache.delete(base.id)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
const store = this.openIndexStore(base)
|
||||
this.instanceCache.set(base.id, store)
|
||||
logger.info('Opened knowledge index store', { baseId: base.id, cacheSize: this.instanceCache.size })
|
||||
return store
|
||||
}
|
||||
|
||||
/** Reuse or open the store only if its file already exists on disk; used by cleanup paths that must not create one. */
|
||||
@@ -105,12 +91,12 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
* and `index.sqlite` alike. Only safe when deleting the whole base.
|
||||
*/
|
||||
async deleteStore(baseId: string): Promise<void> {
|
||||
const opening = this.instanceCache.get(baseId)
|
||||
const store = this.instanceCache.get(baseId)
|
||||
|
||||
try {
|
||||
await this.closeStoreInstance(opening)
|
||||
await this.closeStoreInstance(store)
|
||||
await deleteKnowledgeBaseDir(baseId)
|
||||
logger.info('Deleted knowledge index store', { baseId, hadCachedStore: Boolean(opening) })
|
||||
logger.info('Deleted knowledge index store', { baseId, hadCachedStore: Boolean(store) })
|
||||
} finally {
|
||||
this.instanceCache.delete(baseId)
|
||||
}
|
||||
@@ -121,9 +107,9 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
logger.info('Stopping knowledge index stores', { storeCount })
|
||||
|
||||
try {
|
||||
for (const [baseId, opening] of this.instanceCache.entries()) {
|
||||
for (const [baseId, store] of this.instanceCache.entries()) {
|
||||
try {
|
||||
await this.closeStoreInstance(opening)
|
||||
await this.closeStoreInstance(store)
|
||||
} catch (error) {
|
||||
logger.error('Failed to close knowledge index store', error as Error, { baseId })
|
||||
}
|
||||
@@ -134,9 +120,9 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
}
|
||||
}
|
||||
|
||||
private async openIndexStore(base: CompletedKnowledgeBase): Promise<KnowledgeIndexStore> {
|
||||
const dbPath = await getKnowledgeVectorStoreFilePath(base.id)
|
||||
const driver = await openBetterSqlite3IndexDriver(dbPath)
|
||||
private openIndexStore(base: CompletedKnowledgeBase): KnowledgeIndexStore {
|
||||
const dbPath = getKnowledgeVectorStoreFilePathSync(base.id)
|
||||
const driver = openBetterSqlite3IndexDriver(dbPath)
|
||||
try {
|
||||
// An index.sqlite from an older schema layout cannot be migrated in place —
|
||||
// `CREATE ... IF NOT EXISTS` never retrofits a new column/trigger onto an
|
||||
@@ -147,28 +133,28 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
// (A stale-version file swapped in from another base is rebuilt here rather than
|
||||
// refused by the base_id check below — but the reset drops its rows, so no other
|
||||
// base's data is ever served; only the explicit refusal diagnostic is skipped.)
|
||||
const storedVersion = await readIndexSchemaVersion(driver)
|
||||
const storedVersion = readIndexSchemaVersion(driver)
|
||||
if (storedVersion !== null && storedVersion !== KNOWLEDGE_INDEX_SCHEMA_VERSION) {
|
||||
logger.warn('Knowledge index schema version mismatch — rebuilding the derived index', {
|
||||
baseId: base.id,
|
||||
storedVersion,
|
||||
expectedVersion: KNOWLEDGE_INDEX_SCHEMA_VERSION
|
||||
})
|
||||
await resetKnowledgeIndexSchema(driver)
|
||||
resetKnowledgeIndexSchema(driver)
|
||||
} else {
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
createKnowledgeIndexSchema(driver)
|
||||
}
|
||||
// Stamp + verify the meta identity row before handing out the store,
|
||||
// so an index.sqlite swapped in from another base is rejected here (§4.1).
|
||||
// That is the only refusal — a blank/recreated file is stamped as fresh and
|
||||
// mounts empty; reportInvisibleIndexContents below makes that state loud.
|
||||
await ensureIndexMeta(driver, { baseId: base.id })
|
||||
await this.reportInvisibleIndexContents(driver, base.id)
|
||||
ensureIndexMeta(driver, { baseId: base.id })
|
||||
this.reportInvisibleIndexContents(driver, base.id)
|
||||
return new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
|
||||
} catch (error) {
|
||||
// Close the driver opened above so a failed open never leaks the index file
|
||||
// handle (which on Windows would later block deleting the base dir).
|
||||
await driver.close()
|
||||
driver.close()
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@@ -186,8 +172,8 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
* NOT_FOUND is what turns that into a loud failure instead of a cached
|
||||
* forever-empty store).
|
||||
*/
|
||||
private async reportInvisibleIndexContents(driver: SqliteDriver, baseId: string): Promise<void> {
|
||||
if (await hasAnyMaterial(driver)) {
|
||||
private reportInvisibleIndexContents(driver: SqliteDriver, baseId: string): void {
|
||||
if (hasAnyMaterial(driver)) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -212,14 +198,7 @@ export class KnowledgeVectorStoreService extends BaseService {
|
||||
}
|
||||
}
|
||||
|
||||
private async closeStoreInstance(opening: Promise<KnowledgeIndexStore> | undefined): Promise<void> {
|
||||
if (!opening) {
|
||||
return
|
||||
}
|
||||
// A store that never opened needs no close (the open path already closed its
|
||||
// driver on failure) — swallow the rejection here instead of re-throwing the
|
||||
// open error into an unrelated delete/shutdown operation.
|
||||
const store = await opening.catch(() => undefined)
|
||||
private async closeStoreInstance(store: KnowledgeIndexStore | undefined): Promise<void> {
|
||||
if (!store) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ const {
|
||||
hasAnyMaterialMock,
|
||||
getItemsByBaseIdMock,
|
||||
indexStoreCtorMock,
|
||||
getPathMock,
|
||||
getPathSyncMock,
|
||||
deleteDirMock,
|
||||
statMock
|
||||
@@ -36,7 +35,6 @@ const {
|
||||
hasAnyMaterialMock: vi.fn(),
|
||||
getItemsByBaseIdMock: vi.fn(),
|
||||
indexStoreCtorMock: vi.fn(),
|
||||
getPathMock: vi.fn(),
|
||||
getPathSyncMock: vi.fn(),
|
||||
deleteDirMock: vi.fn(),
|
||||
statMock: vi.fn()
|
||||
@@ -101,7 +99,6 @@ vi.mock('@data/services/KnowledgeItemService', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('../../utils/storage/pathStorage', () => ({
|
||||
getKnowledgeVectorStoreFilePath: getPathMock,
|
||||
getKnowledgeVectorStoreFilePathSync: getPathSyncMock,
|
||||
deleteKnowledgeBaseDir: deleteDirMock
|
||||
}))
|
||||
@@ -136,22 +133,24 @@ function lastStore() {
|
||||
describe('KnowledgeVectorStoreService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
getPathMock.mockImplementation(async (baseId: string) => `/tmp/${baseId}/index.sqlite`)
|
||||
getPathSyncMock.mockImplementation((baseId: string) => `/tmp/${baseId}/index.sqlite`)
|
||||
// Each open returns a fresh closeable driver so failure paths can assert close().
|
||||
openDriverMock.mockImplementation(async () => ({
|
||||
// The real driver port is synchronous (see BetterSqlite3Driver) — these mocks
|
||||
// must return plain values, not promises, or the store under test would try to
|
||||
// call .execute() on a Promise object.
|
||||
openDriverMock.mockImplementation(() => ({
|
||||
kind: 'driver',
|
||||
close: vi.fn().mockResolvedValue(undefined)
|
||||
close: vi.fn()
|
||||
}))
|
||||
createSchemaMock.mockResolvedValue(undefined)
|
||||
resetSchemaMock.mockResolvedValue(undefined)
|
||||
ensureIndexMetaMock.mockResolvedValue(undefined)
|
||||
createSchemaMock.mockReturnValue(undefined)
|
||||
resetSchemaMock.mockReturnValue(undefined)
|
||||
ensureIndexMetaMock.mockReturnValue(undefined)
|
||||
// Default: a fresh/blank file has no stored version → the open path takes the normal
|
||||
// create branch (no rebuild). Mismatch tests override this per-case.
|
||||
readIndexSchemaVersionMock.mockResolvedValue(null)
|
||||
readIndexSchemaVersionMock.mockReturnValue(null)
|
||||
// A non-empty material probe keeps the invisible-contents diagnostic quiet
|
||||
// unless a test opts in.
|
||||
hasAnyMaterialMock.mockResolvedValue(true)
|
||||
hasAnyMaterialMock.mockReturnValue(true)
|
||||
getItemsByBaseIdMock.mockReturnValue([])
|
||||
deleteDirMock.mockResolvedValue(undefined)
|
||||
indexStoreCtorMock.mockImplementation(() => ({ close: vi.fn().mockResolvedValue(undefined) }))
|
||||
@@ -189,7 +188,9 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
it('evicts a failed open so a later call retries instead of re-awaiting the failure', async () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
openDriverMock.mockRejectedValueOnce(new Error('open failed'))
|
||||
openDriverMock.mockImplementationOnce(() => {
|
||||
throw new Error('open failed')
|
||||
})
|
||||
|
||||
await expect(service.getIndexStore(base)).rejects.toThrow('open failed')
|
||||
|
||||
@@ -214,7 +215,7 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
it('creates the schema normally when the stored version matches (no rebuild)', async () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
readIndexSchemaVersionMock.mockResolvedValueOnce(MOCK_SCHEMA_VERSION)
|
||||
readIndexSchemaVersionMock.mockReturnValueOnce(MOCK_SCHEMA_VERSION)
|
||||
|
||||
await service.getIndexStore(base)
|
||||
|
||||
@@ -225,7 +226,7 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
it('rebuilds the derived index when an existing index.sqlite is at a stale schema version', async () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
readIndexSchemaVersionMock.mockResolvedValueOnce(MOCK_SCHEMA_VERSION - 1)
|
||||
readIndexSchemaVersionMock.mockReturnValueOnce(MOCK_SCHEMA_VERSION - 1)
|
||||
|
||||
await service.getIndexStore(base)
|
||||
|
||||
@@ -252,7 +253,7 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
// future refactor to `<` does not silently start mounting newer files.
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
readIndexSchemaVersionMock.mockResolvedValueOnce(MOCK_SCHEMA_VERSION + 1)
|
||||
readIndexSchemaVersionMock.mockReturnValueOnce(MOCK_SCHEMA_VERSION + 1)
|
||||
|
||||
await service.getIndexStore(base)
|
||||
|
||||
@@ -264,11 +265,13 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
let openedDriver: { close: ReturnType<typeof vi.fn> } | undefined
|
||||
openDriverMock.mockImplementationOnce(async () => {
|
||||
openDriverMock.mockImplementationOnce(() => {
|
||||
openedDriver = { kind: 'driver', close: vi.fn().mockResolvedValue(undefined) } as never
|
||||
return openedDriver
|
||||
})
|
||||
ensureIndexMetaMock.mockRejectedValueOnce(new Error('belongs to a different base'))
|
||||
ensureIndexMetaMock.mockImplementationOnce(() => {
|
||||
throw new Error('belongs to a different base')
|
||||
})
|
||||
|
||||
await expect(service.getIndexStore(base)).rejects.toThrow('belongs to a different base')
|
||||
|
||||
@@ -280,11 +283,13 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
let openedDriver: { close: ReturnType<typeof vi.fn> } | undefined
|
||||
openDriverMock.mockImplementationOnce(async () => {
|
||||
openDriverMock.mockImplementationOnce(() => {
|
||||
openedDriver = { kind: 'driver', close: vi.fn().mockResolvedValue(undefined) } as never
|
||||
return openedDriver
|
||||
})
|
||||
createSchemaMock.mockRejectedValueOnce(new Error('disk full'))
|
||||
createSchemaMock.mockImplementationOnce(() => {
|
||||
throw new Error('disk full')
|
||||
})
|
||||
|
||||
await expect(service.getIndexStore(base)).rejects.toThrow('disk full')
|
||||
|
||||
@@ -345,23 +350,19 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
expect(indexStoreCtorMock).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('deleteStore proceeds past a rejected in-flight open instead of re-throwing it', async () => {
|
||||
it('deleteStore removes the directory even when no store was ever opened for the base', async () => {
|
||||
// Opening a store (see openIndexStore) is fully synchronous — it either completes and
|
||||
// caches a store, or throws before caching anything. There is no longer an in-flight
|
||||
// state deleteStore could observe mid-open, so this covers the remaining "nothing
|
||||
// cached" case: deleteStore must still close-if-present (a no-op here) and remove
|
||||
// the directory.
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
let rejectOpen: (error: Error) => void = () => {}
|
||||
openDriverMock.mockImplementationOnce(() => new Promise((_, reject) => (rejectOpen = reject)))
|
||||
|
||||
// deleteStore grabs the still-pending open; when that open later fails, the
|
||||
// delete must not inherit the open error — a store that never opened needs
|
||||
// no close, and the directory removal has to go ahead.
|
||||
const opening = service.getIndexStore(base)
|
||||
const deleting = service.deleteStore(base.id)
|
||||
await vi.waitFor(() => expect(openDriverMock).toHaveBeenCalled())
|
||||
rejectOpen(new Error('open failed'))
|
||||
await service.deleteStore(base.id)
|
||||
|
||||
await expect(opening).rejects.toThrow('open failed')
|
||||
await expect(deleting).resolves.toBeUndefined()
|
||||
expect(deleteDirMock).toHaveBeenCalledWith(base.id)
|
||||
expect(indexStoreCtorMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('evicts the cached store even when directory removal fails', async () => {
|
||||
@@ -452,7 +453,7 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
it('logs an error when an empty index mounts under a base with completed items', async () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
hasAnyMaterialMock.mockResolvedValueOnce(false)
|
||||
hasAnyMaterialMock.mockReturnValueOnce(false)
|
||||
getItemsByBaseIdMock.mockReturnValueOnce([
|
||||
{ id: 'item-1', type: 'directory', status: 'completed' },
|
||||
{ id: 'item-2', type: 'file', status: 'completed' }
|
||||
@@ -470,7 +471,7 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
it('stays quiet when an empty index mounts under a base with no completed indexable items', async () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
hasAnyMaterialMock.mockResolvedValueOnce(false)
|
||||
hasAnyMaterialMock.mockReturnValueOnce(false)
|
||||
// A completed empty directory is legitimate without materials; in-flight leaves are too.
|
||||
getItemsByBaseIdMock.mockReturnValueOnce([
|
||||
{ id: 'item-1', type: 'directory', status: 'completed' },
|
||||
@@ -486,11 +487,11 @@ describe('KnowledgeVectorStoreService', () => {
|
||||
const service = new KnowledgeVectorStoreService()
|
||||
const base = createBase()
|
||||
let openedDriver: { close: ReturnType<typeof vi.fn> } | undefined
|
||||
openDriverMock.mockImplementationOnce(async () => {
|
||||
openDriverMock.mockImplementationOnce(() => {
|
||||
openedDriver = { kind: 'driver', close: vi.fn().mockResolvedValue(undefined) } as never
|
||||
return openedDriver
|
||||
})
|
||||
hasAnyMaterialMock.mockResolvedValueOnce(false)
|
||||
hasAnyMaterialMock.mockReturnValueOnce(false)
|
||||
getItemsByBaseIdMock.mockImplementationOnce(() => {
|
||||
throw new Error('app database unavailable')
|
||||
})
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import { mkdirSync } from 'node:fs'
|
||||
import { dirname } from 'node:path'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { toAsarUnpackedPath } from '@main/utils/asar'
|
||||
import Database from 'better-sqlite3'
|
||||
import { getLoadablePath } from 'sqlite-vec'
|
||||
|
||||
import type { SqliteDriver, SqliteReclaimOutcome, SqliteTransaction, SqlQueryResult, SqlValue } from './types'
|
||||
|
||||
const logger = loggerService.withContext('BetterSqlite3Driver')
|
||||
|
||||
/**
|
||||
* VACUUM in {@link BetterSqlite3Driver.reclaim} only when the freelist is BOTH a large
|
||||
* fraction of the file AND past an absolute floor. The fraction skips rewriting a
|
||||
@@ -44,11 +41,15 @@ function toBindable(value: SqlValue): Bindable {
|
||||
* No internal write mutex: every writer (rebuildMaterial / deleteMaterials /
|
||||
* reclaimSpace) runs under `KnowledgeLockManager.withBaseMutationLock(baseId)`, a
|
||||
* per-base async-mutex that already serializes all writes to a base's index — so the
|
||||
* driver never self-serializes. better-sqlite3's single connection makes a manual
|
||||
* BEGIN IMMEDIATE transaction inherently atomic; reads run on the same connection and
|
||||
* ride WAL / busy_timeout. (Holding BEGIN across a transaction callback's own
|
||||
* event-loop yield is safe here precisely because the base lock blocks any other
|
||||
* writer from issuing a nested BEGIN.)
|
||||
* driver never self-serializes. better-sqlite3's single connection makes
|
||||
* `transaction` inherently atomic; reads run on the same connection and ride WAL /
|
||||
* busy_timeout. `transaction` uses better-sqlite3's native `db.transaction(fn).immediate`
|
||||
* (BEGIN IMMEDIATE, matching the old libsql 'write' tx) rather than hand-rolled
|
||||
* BEGIN/COMMIT/ROLLBACK: it runs `fn` synchronously to completion in one JS turn, so
|
||||
* BEGIN and COMMIT can never straddle an event-loop yield — no unrelated read on this
|
||||
* connection can ever observe a transaction mid-flight. It also throws `TypeError` if
|
||||
* `fn` returns a Promise, turning an accidental async callback into a loud failure
|
||||
* instead of a silent early commit.
|
||||
*
|
||||
* NOTE (future KB-owner cleanup): the driver trusts callers to hold the base lock and does not
|
||||
* self-check it. A cheap `transactionActive` reentrancy assertion in {@link transaction} could
|
||||
@@ -60,44 +61,30 @@ export class BetterSqlite3Driver implements SqliteDriver {
|
||||
|
||||
constructor(private readonly db: Database.Database) {}
|
||||
|
||||
async execute(sql: string, args: SqlValue[] = []): Promise<SqlQueryResult> {
|
||||
execute(sql: string, args: SqlValue[] = []): SqlQueryResult {
|
||||
this.assertOpen()
|
||||
const stmt = this.db.prepare(sql)
|
||||
const bound = args.map(toBindable)
|
||||
// `reader` is true for statements that yield rows (SELECT, row-returning PRAGMA).
|
||||
// run() throws on those and all() throws on non-row statements, so split on it.
|
||||
if (stmt.reader) {
|
||||
return { rows: stmt.all(...bound) as Array<Record<string, SqlValue>> }
|
||||
return { rows: stmt.all(...bound) as Array<Record<string, SqlValue>>, changes: 0 }
|
||||
}
|
||||
stmt.run(...bound)
|
||||
return { rows: [] }
|
||||
const result = stmt.run(...bound)
|
||||
return { rows: [], changes: result.changes }
|
||||
}
|
||||
|
||||
async transaction<T>(fn: (tx: SqliteTransaction) => Promise<T>): Promise<T> {
|
||||
transaction<T>(fn: (tx: SqliteTransaction) => T): T {
|
||||
this.assertOpen()
|
||||
const handle: SqliteTransaction = {
|
||||
execute: (sql, args) => this.execute(sql, args)
|
||||
}
|
||||
// BEGIN IMMEDIATE acquires the write lock up front (so a read-then-write transaction
|
||||
// never fails mid-way after the read already ran), matching the old libsql 'write' tx.
|
||||
this.db.exec('BEGIN IMMEDIATE')
|
||||
try {
|
||||
const result = await fn(handle)
|
||||
this.db.exec('COMMIT')
|
||||
return result
|
||||
} catch (error) {
|
||||
// Roll back, but never let a rollback failure mask the original error that
|
||||
// triggered it — that original is what callers need to diagnose the write.
|
||||
try {
|
||||
this.db.exec('ROLLBACK')
|
||||
} catch (rollbackError) {
|
||||
logger.warn('Failed to roll back knowledge index store transaction after an error', rollbackError as Error)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
// .immediate acquires the write lock up front (BEGIN IMMEDIATE), so a
|
||||
// read-then-write transaction never fails mid-way after the read already ran.
|
||||
return this.db.transaction((tx: SqliteTransaction) => fn(tx)).immediate(handle)
|
||||
}
|
||||
|
||||
async reclaim(preVacuumStatements: readonly string[] = []): Promise<SqliteReclaimOutcome> {
|
||||
reclaim(preVacuumStatements: readonly string[] = []): SqliteReclaimOutcome {
|
||||
this.assertOpen()
|
||||
// Checkpoint first: frees the WAL (cheap) and folds committed frees from the
|
||||
// delete into the main file so freelist_count reflects them. The base lock
|
||||
@@ -144,7 +131,7 @@ export class BetterSqlite3Driver implements SqliteDriver {
|
||||
}
|
||||
|
||||
/** Idempotent: a second close() (e.g. shutdown after an explicit deleteStore) is a no-op. */
|
||||
async close(): Promise<void> {
|
||||
close(): void {
|
||||
if (this.closed) {
|
||||
return
|
||||
}
|
||||
@@ -169,7 +156,7 @@ export class BetterSqlite3Driver implements SqliteDriver {
|
||||
* DbService PRAGMA setup. better-sqlite3 keeps one connection, so each PRAGMA is set
|
||||
* once and holds for the connection's lifetime — no replay machinery is needed.
|
||||
*/
|
||||
export async function openBetterSqlite3IndexDriver(filePath: string): Promise<BetterSqlite3Driver> {
|
||||
export function openBetterSqlite3IndexDriver(filePath: string): BetterSqlite3Driver {
|
||||
// better-sqlite3 creates the database FILE but not its parent directory (unlike libsql's
|
||||
// file: URL client), so ensure the base's index dir exists before opening.
|
||||
mkdirSync(dirname(filePath), { recursive: true })
|
||||
|
||||
@@ -17,11 +17,11 @@ const RRF_K = 60
|
||||
const EMBEDDING_HASH_QUERY_BATCH = 500
|
||||
|
||||
/**
|
||||
* How long {@link KnowledgeIndexStore.deleteMaterials} may run its per-material
|
||||
* row deletes before handing the main-process event loop back to the OS message
|
||||
* pump (see the method doc for why). Tuned well under the multi-second window
|
||||
* that surfaces the macOS beachball, while large enough that the yields add no
|
||||
* measurable overhead to a small delete.
|
||||
* How long {@link KnowledgeIndexStore.deleteMaterials} may run consecutive
|
||||
* per-material transactions before handing the main-process event loop back to
|
||||
* the OS message pump (see the method doc for why). Tuned well under the
|
||||
* multi-second window that surfaces the macOS beachball, while large enough
|
||||
* that the yields add no measurable overhead to a small delete.
|
||||
*/
|
||||
const DELETE_YIELD_BUDGET_MS = 50
|
||||
|
||||
@@ -70,16 +70,23 @@ export class KnowledgeIndexStore {
|
||||
}
|
||||
})
|
||||
|
||||
await this.driver.transaction(async (tx) => {
|
||||
this.driver.transaction((tx) => {
|
||||
// 0. Capture the material's prior content hash (undefined if it doesn't exist
|
||||
// yet) so step 8 can tell whether this rebuild could possibly have orphaned
|
||||
// anything, without an extra full-table scan.
|
||||
const priorRow = tx.execute(`SELECT current_content_hash FROM material WHERE material_id = ?`, [materialId])
|
||||
.rows[0]
|
||||
const priorContentHash = priorRow === undefined ? undefined : (priorRow.current_content_hash as string | null)
|
||||
|
||||
// 1. Content is immutable by hash — keep the existing row if present.
|
||||
await tx.execute(`INSERT OR IGNORE INTO content (content_hash, text, created_at) VALUES (?, ?, ?)`, [
|
||||
tx.execute(`INSERT OR IGNORE INTO content (content_hash, text, created_at) VALUES (?, ?, ?)`, [
|
||||
contentHash,
|
||||
input.content.text,
|
||||
now
|
||||
])
|
||||
|
||||
// 2. Upsert the material (current_content_hash set in step 7).
|
||||
await tx.execute(
|
||||
tx.execute(
|
||||
`INSERT INTO material (material_id, relative_path, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(material_id) DO UPDATE SET
|
||||
@@ -92,12 +99,12 @@ export class KnowledgeIndexStore {
|
||||
// FK to search_unit (its target_id is polymorphic), so it is deleted
|
||||
// explicitly while search_unit still exists to resolve the targets; the
|
||||
// FTS index is kept in sync by the search_text delete trigger.
|
||||
await this.deleteMaterialSearchText(tx, materialId)
|
||||
await tx.execute(`DELETE FROM search_unit WHERE material_id = ?`, [materialId])
|
||||
this.deleteMaterialSearchText(tx, materialId)
|
||||
const deletedUnits = tx.execute(`DELETE FROM search_unit WHERE material_id = ?`, [materialId]).changes
|
||||
|
||||
// 4 & 5. Insert new units and their body search_text (FTS synced by trigger).
|
||||
for (const unit of units) {
|
||||
await tx.execute(
|
||||
tx.execute(
|
||||
`INSERT INTO search_unit
|
||||
(unit_id, material_id, content_hash, unit_type, unit_index, title, char_start, char_end, locator_json, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
@@ -116,7 +123,7 @@ export class KnowledgeIndexStore {
|
||||
now
|
||||
]
|
||||
)
|
||||
await tx.execute(
|
||||
tx.execute(
|
||||
`INSERT INTO search_text (search_text_id, target_type, target_id, kind, text, embedding_text_hash, created_at)
|
||||
VALUES (?, 'search_unit', ?, 'body', ?, ?, ?)`,
|
||||
[
|
||||
@@ -131,10 +138,11 @@ export class KnowledgeIndexStore {
|
||||
|
||||
// 6. Insert missing embeddings; existing hashes are reused (decision A4).
|
||||
for (const embedding of input.embeddings) {
|
||||
await tx.execute(
|
||||
`INSERT OR IGNORE INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`,
|
||||
[embedding.embeddingTextHash, encodeVectorBlob(embedding.vector), now]
|
||||
)
|
||||
tx.execute(`INSERT OR IGNORE INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`, [
|
||||
embedding.embeddingTextHash,
|
||||
encodeVectorBlob(embedding.vector),
|
||||
now
|
||||
])
|
||||
}
|
||||
|
||||
// 6b. Coverage check: every unit's re-derived embedding hash must resolve to a
|
||||
@@ -147,70 +155,89 @@ export class KnowledgeIndexStore {
|
||||
// drop a hash it reported present before this rebuild writes, and the job
|
||||
// then skips re-embedding it. Failing loud rolls back; the job's retry
|
||||
// re-reads (the hash is now absent), re-embeds it, and converges.
|
||||
await this.assertEmbeddingCoverage(tx, materialId, [...new Set(units.map((unit) => unit.embeddingTextHash))])
|
||||
this.assertEmbeddingCoverage(tx, materialId, [...new Set(units.map((unit) => unit.embeddingTextHash))])
|
||||
|
||||
// 7. Mark the material's current content (failure/lifecycle state is the
|
||||
// authority of knowledge_item, not this derived index).
|
||||
await tx.execute(`UPDATE material SET current_content_hash = ?, updated_at = ? WHERE material_id = ?`, [
|
||||
tx.execute(`UPDATE material SET current_content_hash = ?, updated_at = ? WHERE material_id = ?`, [
|
||||
contentHash,
|
||||
now,
|
||||
materialId
|
||||
])
|
||||
|
||||
// 8. Sweep rows this rebuild orphaned (old units' embeddings, old content the
|
||||
// new revision no longer references). Safe under the base mutation lock.
|
||||
await this.collectIndexGarbage(tx)
|
||||
// new revision no longer references) — but only when it actually could have
|
||||
// orphaned something. A first-time create (no prior row) or a rebuild that
|
||||
// replaced zero old units AND kept the same content hash touches nothing an
|
||||
// earlier revision referenced, so the GC's full-table anti-join scans would
|
||||
// find nothing; skipping them turns a bulk index of K materials from
|
||||
// O(K × table) into O(K). Checking unit deletions alone is not sound — a
|
||||
// material that previously had zero units but a different content hash would
|
||||
// slip through and leave that old content row an orphan — so both conditions
|
||||
// are required.
|
||||
const contentChanged = priorContentHash !== undefined && priorContentHash !== contentHash
|
||||
if (deletedUnits > 0 || contentChanged) {
|
||||
this.collectIndexGarbage(tx)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete many materials in ONE transaction, sweeping orphaned `embedding` /
|
||||
* `content` rows with a SINGLE {@link collectIndexGarbage} pass at the end.
|
||||
* Delete many materials — each in its OWN short transaction — then sweep
|
||||
* orphaned `embedding` / `content` rows with a SINGLE {@link collectIndexGarbage}
|
||||
* pass in a final transaction.
|
||||
*
|
||||
* Removing each material row cascades to its `search_unit`; the units' body
|
||||
* `search_text` is deleted explicitly first (no FK), which also clears the FTS
|
||||
* index via the delete trigger.
|
||||
*
|
||||
* collectIndexGarbage runs two FULL-TABLE anti-join scans, so calling it once
|
||||
* per material (the old per-material delete loop) made a bulk delete
|
||||
* per material (an old per-material delete+GC loop) made a bulk delete
|
||||
* O(materials × table): deleting a folder of N files scanned the whole
|
||||
* `embedding`/`content` table N times. With a large index (e.g. a folder of
|
||||
* PDFs chunked into tens of thousands of rows) that blocked the main-process
|
||||
* event loop for seconds — the folder-delete UI freeze. Deleting the rows up
|
||||
* front and GCing once makes it O(N + table), and the single transaction makes
|
||||
* the bulk delete atomic (a failure rolls the whole batch back so a retry
|
||||
* re-discovers every affected id).
|
||||
* front and GCing once makes it O(N + table).
|
||||
*
|
||||
* Batching the GC removes the super-linear cost, but the per-material row
|
||||
* deletes are still linear in chunks: each `search_text` delete fires the FTS
|
||||
* delete trigger, which the driver runs synchronously on the main process.
|
||||
* Tens of thousands of rows still sum to a multi-second block, and
|
||||
* because Electron drives the window from this same loop that block IS the
|
||||
* macOS beachball (the renderer thread never stalls). So the loop yields to the
|
||||
* OS message pump whenever it has run for {@link DELETE_YIELD_BUDGET_MS}: the
|
||||
* total work is unchanged, but no single uninterrupted block is long enough to
|
||||
* freeze the window. Yielding mid-transaction is safe — the caller holds the
|
||||
* base mutation lock, so no other writer is waiting on this base's index.
|
||||
* Tens of thousands of rows still sum to a multi-second block, and because
|
||||
* Electron drives the window from this same loop that block IS the macOS
|
||||
* beachball (the renderer thread never stalls). A driver transaction must run
|
||||
* fully synchronously (no event-loop yield inside `BEGIN`..`COMMIT` — see
|
||||
* {@link SqliteDriver.transaction}), so each material gets its own transaction
|
||||
* and the loop yields to the OS message pump BETWEEN them whenever it has run
|
||||
* for {@link DELETE_YIELD_BUDGET_MS}: the total work is unchanged, but no single
|
||||
* uninterrupted block is long enough to freeze the window.
|
||||
*
|
||||
* This is no longer one all-or-nothing batch — a failure partway leaves the
|
||||
* materials deleted so far committed. That is safe: every caller (subtreePurge.ts)
|
||||
* deletes vectors before the corresponding `knowledge_item` DB rows, so those rows
|
||||
* still exist after a partial failure and a retry re-discovers exactly the
|
||||
* materials still left (re-deleting an already-gone one is a harmless no-op).
|
||||
*/
|
||||
async deleteMaterials(materialIds: string[]): Promise<void> {
|
||||
const uniqueMaterialIds = [...new Set(materialIds)]
|
||||
if (uniqueMaterialIds.length === 0) {
|
||||
return
|
||||
}
|
||||
await this.driver.transaction(async (tx) => {
|
||||
// performance.now() is monotonic — a wall-clock step (NTP/manual) mid-batch
|
||||
// must not make the delta negative and silently disable the yields for the
|
||||
// rest of a large delete, reintroducing the freeze this loop prevents.
|
||||
let lastYieldAt = performance.now()
|
||||
for (const materialId of uniqueMaterialIds) {
|
||||
await this.deleteMaterialSearchText(tx, materialId)
|
||||
await tx.execute(`DELETE FROM material WHERE material_id = ?`, [materialId])
|
||||
if (performance.now() - lastYieldAt >= DELETE_YIELD_BUDGET_MS) {
|
||||
await new Promise<void>((resolve) => setImmediate(resolve))
|
||||
lastYieldAt = performance.now()
|
||||
}
|
||||
// performance.now() is monotonic — a wall-clock step (NTP/manual) mid-batch
|
||||
// must not make the delta negative and silently disable the yields for the
|
||||
// rest of a large delete, reintroducing the freeze this loop prevents.
|
||||
let lastYieldAt = performance.now()
|
||||
for (const materialId of uniqueMaterialIds) {
|
||||
this.driver.transaction((tx) => {
|
||||
this.deleteMaterialSearchText(tx, materialId)
|
||||
tx.execute(`DELETE FROM material WHERE material_id = ?`, [materialId])
|
||||
})
|
||||
if (performance.now() - lastYieldAt >= DELETE_YIELD_BUDGET_MS) {
|
||||
await new Promise<void>((resolve) => setImmediate(resolve))
|
||||
lastYieldAt = performance.now()
|
||||
}
|
||||
await this.collectIndexGarbage(tx)
|
||||
}
|
||||
this.driver.transaction((tx) => {
|
||||
this.collectIndexGarbage(tx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -224,12 +251,12 @@ export class KnowledgeIndexStore {
|
||||
* `search_unit.content_hash` (FK CASCADE) reference it — both referrers are
|
||||
* excluded, so the delete never violates either constraint.
|
||||
*/
|
||||
private async collectIndexGarbage(tx: SqliteTransaction): Promise<void> {
|
||||
await tx.execute(
|
||||
private collectIndexGarbage(tx: SqliteTransaction): void {
|
||||
tx.execute(
|
||||
`DELETE FROM embedding
|
||||
WHERE NOT EXISTS (SELECT 1 FROM search_text st WHERE st.embedding_text_hash = embedding.embedding_text_hash)`
|
||||
)
|
||||
await tx.execute(
|
||||
tx.execute(
|
||||
`DELETE FROM content
|
||||
WHERE NOT EXISTS (SELECT 1 FROM material m WHERE m.current_content_hash = content.content_hash)
|
||||
AND NOT EXISTS (SELECT 1 FROM search_unit su WHERE su.content_hash = content.content_hash)`
|
||||
@@ -256,7 +283,7 @@ export class KnowledgeIndexStore {
|
||||
for (let i = 0; i < hashes.length; i += EMBEDDING_HASH_QUERY_BATCH) {
|
||||
const batch = hashes.slice(i, i + EMBEDDING_HASH_QUERY_BATCH)
|
||||
const placeholders = batch.map(() => '?').join(', ')
|
||||
const result = await this.driver.execute(
|
||||
const result = this.driver.execute(
|
||||
`SELECT embedding_text_hash FROM embedding WHERE embedding_text_hash IN (${placeholders})`,
|
||||
batch
|
||||
)
|
||||
@@ -269,7 +296,7 @@ export class KnowledgeIndexStore {
|
||||
|
||||
/** Read back a material's units (with body text), ordered by unit index. */
|
||||
async listMaterialUnits(materialId: string): Promise<KnowledgeSearchUnit[]> {
|
||||
const result = await this.driver.execute(
|
||||
const result = this.driver.execute(
|
||||
`SELECT su.unit_id, su.material_id, su.unit_type, su.unit_index, su.title, su.char_start, su.char_end, st.text AS body
|
||||
FROM search_unit su
|
||||
LEFT JOIN search_text st
|
||||
@@ -307,10 +334,9 @@ export class KnowledgeIndexStore {
|
||||
* the resolved material against the visible knowledge_item before reading.
|
||||
*/
|
||||
async getMaterialByRelativePath(relativePath: string): Promise<KnowledgeMaterialRef | null> {
|
||||
const result = await this.driver.execute(
|
||||
`SELECT material_id, relative_path FROM material WHERE relative_path = ?`,
|
||||
[relativePath]
|
||||
)
|
||||
const result = this.driver.execute(`SELECT material_id, relative_path FROM material WHERE relative_path = ?`, [
|
||||
relativePath
|
||||
])
|
||||
const row = result.rows[0]
|
||||
if (!row) {
|
||||
return null
|
||||
@@ -329,7 +355,7 @@ export class KnowledgeIndexStore {
|
||||
* holds (the same invariant rebuildMaterial enforces at write time).
|
||||
*/
|
||||
async readMaterialContent(materialId: string): Promise<string | null> {
|
||||
const result = await this.driver.execute(
|
||||
const result = this.driver.execute(
|
||||
`SELECT c.text AS text
|
||||
FROM material m
|
||||
JOIN content c ON c.content_hash = m.current_content_hash
|
||||
@@ -360,10 +386,10 @@ export class KnowledgeIndexStore {
|
||||
|
||||
const alpha = input.alpha ?? 0.5
|
||||
const prefetch = input.topK * 5
|
||||
const [vector, bm25] = await Promise.all([
|
||||
this.vectorSearch(this.requireQueryEmbedding(input), prefetch),
|
||||
this.bm25Search(input.queryText, prefetch)
|
||||
])
|
||||
// Both lanes are synchronous SQL over the same connection — sequential, not
|
||||
// Promise.all: the driver has no true concurrency to parallelize.
|
||||
const vector = this.vectorSearch(this.requireQueryEmbedding(input), prefetch)
|
||||
const bm25 = this.bm25Search(input.queryText, prefetch)
|
||||
return fuseWithRrf(vector, bm25, alpha, input.topK)
|
||||
}
|
||||
|
||||
@@ -385,7 +411,7 @@ export class KnowledgeIndexStore {
|
||||
}
|
||||
|
||||
async close(): Promise<void> {
|
||||
await this.driver.close()
|
||||
this.driver.close()
|
||||
}
|
||||
|
||||
/** Whether the backing driver has been closed (see {@link SqliteDriver.isClosed}). */
|
||||
@@ -401,14 +427,14 @@ export class KnowledgeIndexStore {
|
||||
}
|
||||
|
||||
/** Brute-force cosine scan over the plain-BLOB embedding column (no ANN index). */
|
||||
private async vectorSearch(queryEmbedding: number[], topK: number): Promise<KnowledgeIndexSearchMatch[]> {
|
||||
private vectorSearch(queryEmbedding: number[], topK: number): KnowledgeIndexSearchMatch[] {
|
||||
// Invariant, not a check: a base's embedding model and dimensions are immutable
|
||||
// (changing them means migrating to a new base), so `queryEmbedding` and every
|
||||
// stored `vector_blob` share one dimension — cosine never compares mismatched lengths.
|
||||
// `WHERE dist IS NOT NULL` drops degenerate (zero-norm) vectors: cosine distance is
|
||||
// undefined for them, and SQLite coerces that NULL/NaN to NULL — which would otherwise
|
||||
// sort first under `ORDER BY dist` and score `1 - Number(null) = 1`, outranking real hits.
|
||||
const result = await this.driver.execute(
|
||||
const result = this.driver.execute(
|
||||
`SELECT su.unit_id, su.material_id, su.unit_index, st.text AS body,
|
||||
${this.vectorIndex.buildDistanceExpression('e.vector_blob')} AS dist
|
||||
FROM embedding e
|
||||
@@ -423,7 +449,7 @@ export class KnowledgeIndexStore {
|
||||
return result.rows.map((row) => toMatch(row, 1 - Number(row.dist)))
|
||||
}
|
||||
|
||||
private async bm25Search(queryText: string, topK: number): Promise<KnowledgeIndexSearchMatch[]> {
|
||||
private bm25Search(queryText: string, topK: number): KnowledgeIndexSearchMatch[] {
|
||||
// Short tokens (notably 1–2 char CJK words) produce no trigram, so MATCH would
|
||||
// silently return nothing — route those queries to the LIKE fallback instead.
|
||||
if (needsLikeFallback(queryText)) {
|
||||
@@ -433,7 +459,7 @@ export class KnowledgeIndexStore {
|
||||
if (!matchQuery) {
|
||||
return []
|
||||
}
|
||||
const result = await this.driver.execute(
|
||||
const result = this.driver.execute(
|
||||
`SELECT su.unit_id, su.material_id, su.unit_index, st.text AS body, bm25(search_text_fts) AS score
|
||||
FROM search_text_fts
|
||||
JOIN search_text st
|
||||
@@ -455,13 +481,13 @@ export class KnowledgeIndexStore {
|
||||
* unit fully about the term) ranks first — and expose it as a higher-is-better
|
||||
* score so it fuses sanely with the vector lane in hybrid mode.
|
||||
*/
|
||||
private async bm25LikeSearch(tokens: string[], topK: number): Promise<KnowledgeIndexSearchMatch[]> {
|
||||
private bm25LikeSearch(tokens: string[], topK: number): KnowledgeIndexSearchMatch[] {
|
||||
if (tokens.length === 0) {
|
||||
return []
|
||||
}
|
||||
const likeClauses = tokens.map(() => `st.text LIKE ? ESCAPE '\\'`).join(' AND ')
|
||||
const args: SqlValue[] = [...tokens.map(toFtsLikePattern), topK]
|
||||
const result = await this.driver.execute(
|
||||
const result = this.driver.execute(
|
||||
`SELECT su.unit_id, su.material_id, su.unit_index, st.text AS body, length(st.text) AS len
|
||||
FROM search_text st
|
||||
JOIN search_unit su ON su.unit_id = st.target_id
|
||||
@@ -475,12 +501,12 @@ export class KnowledgeIndexStore {
|
||||
}
|
||||
|
||||
/** Throw (rolling back the surrounding rebuild) if any unit hash has no embedding row. */
|
||||
private async assertEmbeddingCoverage(tx: SqliteTransaction, materialId: string, hashes: string[]): Promise<void> {
|
||||
private assertEmbeddingCoverage(tx: SqliteTransaction, materialId: string, hashes: string[]): void {
|
||||
const missing = new Set(hashes)
|
||||
for (let i = 0; i < hashes.length; i += EMBEDDING_HASH_QUERY_BATCH) {
|
||||
const batch = hashes.slice(i, i + EMBEDDING_HASH_QUERY_BATCH)
|
||||
const placeholders = batch.map(() => '?').join(', ')
|
||||
const result = await tx.execute(
|
||||
const result = tx.execute(
|
||||
`SELECT embedding_text_hash FROM embedding WHERE embedding_text_hash IN (${placeholders})`,
|
||||
batch
|
||||
)
|
||||
@@ -495,8 +521,8 @@ export class KnowledgeIndexStore {
|
||||
}
|
||||
}
|
||||
|
||||
private async deleteMaterialSearchText(tx: SqliteTransaction, materialId: string): Promise<void> {
|
||||
await tx.execute(
|
||||
private deleteMaterialSearchText(tx: SqliteTransaction, materialId: string): void {
|
||||
tx.execute(
|
||||
`DELETE FROM search_text
|
||||
WHERE target_type = 'search_unit'
|
||||
AND target_id IN (SELECT unit_id FROM search_unit WHERE material_id = ?)`,
|
||||
|
||||
@@ -2,127 +2,114 @@ import { mkdtempSync, rmSync } from 'node:fs'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
|
||||
import type Database from 'better-sqlite3'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import { BetterSqlite3Driver, openBetterSqlite3IndexDriver } from '../BetterSqlite3Driver'
|
||||
|
||||
const loggerWarnMock = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({ warn: loggerWarnMock })
|
||||
}
|
||||
}))
|
||||
import type { BetterSqlite3Driver } from '../BetterSqlite3Driver'
|
||||
import { openBetterSqlite3IndexDriver } from '../BetterSqlite3Driver'
|
||||
|
||||
describe('BetterSqlite3Driver', () => {
|
||||
let tempDir: string
|
||||
let driver: BetterSqlite3Driver
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-driver-'))
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await driver.execute('CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)')
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
driver.execute('CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)')
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
await driver.close()
|
||||
afterEach(() => {
|
||||
driver.close()
|
||||
rmSync(tempDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
it('enables foreign keys on open', async () => {
|
||||
const result = await driver.execute('PRAGMA foreign_keys')
|
||||
it('enables foreign keys on open', () => {
|
||||
const result = driver.execute('PRAGMA foreign_keys')
|
||||
expect(result.rows[0].foreign_keys).toBe(1)
|
||||
})
|
||||
|
||||
it('opens in WAL journal mode with a busy timeout so reads survive a concurrent write', async () => {
|
||||
const journal = await driver.execute('PRAGMA journal_mode')
|
||||
it('opens in WAL journal mode with a busy timeout so reads survive a concurrent write', () => {
|
||||
const journal = driver.execute('PRAGMA journal_mode')
|
||||
expect(String(journal.rows[0].journal_mode).toLowerCase()).toBe('wal')
|
||||
|
||||
const timeout = await driver.execute('PRAGMA busy_timeout')
|
||||
const timeout = driver.execute('PRAGMA busy_timeout')
|
||||
expect(Number(timeout.rows[0].timeout)).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('maps rows to plain objects', async () => {
|
||||
await driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
|
||||
it('maps rows to plain objects', () => {
|
||||
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
|
||||
|
||||
const select = await driver.execute('SELECT id, v FROM t WHERE id = ?', [1])
|
||||
const select = driver.execute('SELECT id, v FROM t WHERE id = ?', [1])
|
||||
expect(select.rows).toEqual([{ id: 1, v: 'a' }])
|
||||
})
|
||||
|
||||
it('commits a successful transaction', async () => {
|
||||
await driver.transaction(async (tx) => {
|
||||
await tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
|
||||
await tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'y'])
|
||||
it('reports rows changed by a write statement', () => {
|
||||
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
|
||||
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'b'])
|
||||
|
||||
const result = driver.execute('DELETE FROM t WHERE id = ?', [1])
|
||||
expect(result.changes).toBe(1)
|
||||
|
||||
const select = driver.execute('SELECT id FROM t')
|
||||
expect(select.changes).toBe(0)
|
||||
})
|
||||
|
||||
it('commits a successful transaction', () => {
|
||||
driver.transaction((tx) => {
|
||||
tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
|
||||
tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'y'])
|
||||
})
|
||||
|
||||
const count = await driver.execute('SELECT COUNT(*) AS n FROM t')
|
||||
const count = driver.execute('SELECT COUNT(*) AS n FROM t')
|
||||
expect(count.rows[0].n).toBe(2)
|
||||
})
|
||||
|
||||
it('rolls back a failed transaction', async () => {
|
||||
await expect(
|
||||
driver.transaction(async (tx) => {
|
||||
await tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
|
||||
it('rolls back a failed transaction', () => {
|
||||
expect(() =>
|
||||
driver.transaction((tx) => {
|
||||
tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x'])
|
||||
throw new Error('boom')
|
||||
})
|
||||
).rejects.toThrow('boom')
|
||||
).toThrow('boom')
|
||||
|
||||
const count = await driver.execute('SELECT COUNT(*) AS n FROM t')
|
||||
const count = driver.execute('SELECT COUNT(*) AS n FROM t')
|
||||
expect(count.rows[0].n).toBe(0)
|
||||
})
|
||||
|
||||
it('rethrows the original error when rollback also fails, instead of masking it', async () => {
|
||||
const originalError = new Error('insert failed')
|
||||
const rollbackError = new Error('rollback failed')
|
||||
// A fake better-sqlite3 connection: the bracket's BEGIN IMMEDIATE succeeds, the
|
||||
// body's statement fails with originalError, then the ROLLBACK that the catch
|
||||
// issues fails with rollbackError. The driver must surface originalError (what the
|
||||
// caller needs to diagnose the write) and only log the rollback failure.
|
||||
const fakeDb = {
|
||||
prepare: () => {
|
||||
throw originalError
|
||||
},
|
||||
exec: (sql: string) => {
|
||||
if (sql === 'ROLLBACK') {
|
||||
throw rollbackError
|
||||
}
|
||||
},
|
||||
pragma: () => undefined,
|
||||
close: () => undefined
|
||||
} as unknown as Database.Database
|
||||
const isolatedDriver = new BetterSqlite3Driver(fakeDb)
|
||||
it('throws if the transaction callback returns a promise, instead of silently committing early', () => {
|
||||
// better-sqlite3's native transaction() rejects an async callback outright (see
|
||||
// BetterSqlite3Driver.transaction doc) — an accidental async fn must fail loud
|
||||
// rather than commit before its awaited work actually ran.
|
||||
expect(() =>
|
||||
driver.transaction((tx) => {
|
||||
return Promise.resolve(tx.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'x']))
|
||||
})
|
||||
).toThrow(/promise/i)
|
||||
|
||||
await expect(isolatedDriver.transaction(async (tx) => tx.execute('INSERT INTO t (id) VALUES (1)'))).rejects.toBe(
|
||||
originalError
|
||||
)
|
||||
expect(loggerWarnMock).toHaveBeenCalledWith(
|
||||
'Failed to roll back knowledge index store transaction after an error',
|
||||
rollbackError
|
||||
)
|
||||
const count = driver.execute('SELECT COUNT(*) AS n FROM t')
|
||||
expect(count.rows[0].n).toBe(0)
|
||||
})
|
||||
|
||||
it('checkpoints but skips the VACUUM when the freed space is below the reclaim threshold', async () => {
|
||||
it('checkpoints but skips the VACUUM when the freed space is below the reclaim threshold', () => {
|
||||
// A small delete leaves a freelist far below the size/ratio thresholds, so reclaim
|
||||
// only truncates the WAL and reports that no whole-file rewrite ran.
|
||||
await driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
|
||||
await driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'b'])
|
||||
await driver.execute('DELETE FROM t')
|
||||
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [1, 'a'])
|
||||
driver.execute('INSERT INTO t (id, v) VALUES (?, ?)', [2, 'b'])
|
||||
driver.execute('DELETE FROM t')
|
||||
|
||||
const outcome = await driver.reclaim()
|
||||
const outcome = driver.reclaim()
|
||||
|
||||
expect(outcome).toEqual({ vacuumed: false, reclaimedBytes: 0 })
|
||||
})
|
||||
|
||||
it('reports closed state and rejects use after close with a deterministic error', async () => {
|
||||
it('reports closed state and rejects use after close with a deterministic error', () => {
|
||||
expect(driver.isClosed()).toBe(false)
|
||||
|
||||
await driver.close()
|
||||
driver.close()
|
||||
|
||||
expect(driver.isClosed()).toBe(true)
|
||||
await expect(driver.execute('SELECT 1')).rejects.toThrow(/closed/)
|
||||
await expect(driver.transaction(async (tx) => tx.execute('SELECT 1'))).rejects.toThrow(/closed/)
|
||||
expect(() => driver.execute('SELECT 1')).toThrow(/closed/)
|
||||
expect(() => driver.transaction((tx) => tx.execute('SELECT 1'))).toThrow(/closed/)
|
||||
// A second close (e.g. app shutdown after an explicit deleteStore) is a no-op.
|
||||
await expect(driver.close()).resolves.toBeUndefined()
|
||||
expect(driver.close()).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,14 +14,14 @@ describe('BetterSqlite3VectorIndex', () => {
|
||||
let driver: BetterSqlite3Driver
|
||||
const vectorIndex = new BetterSqlite3VectorIndex()
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-vindex-'))
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
await driver.close()
|
||||
afterEach(() => {
|
||||
driver.close()
|
||||
rmSync(tempDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
@@ -39,23 +39,23 @@ describe('BetterSqlite3VectorIndex', () => {
|
||||
[vectorIndex.bindQueryVector(query), k]
|
||||
)
|
||||
|
||||
it('brute-force ranks nearest vectors first over a plain-BLOB column', async () => {
|
||||
await insertEmbedding('near', [1, 0, 0])
|
||||
await insertEmbedding('mid', [0.7, 0.7, 0])
|
||||
await insertEmbedding('far', [0, 0, 1])
|
||||
it('brute-force ranks nearest vectors first over a plain-BLOB column', () => {
|
||||
insertEmbedding('near', [1, 0, 0])
|
||||
insertEmbedding('mid', [0.7, 0.7, 0])
|
||||
insertEmbedding('far', [0, 0, 1])
|
||||
|
||||
const result = await topK([1, 0, 0], 3)
|
||||
const result = topK([1, 0, 0], 3)
|
||||
|
||||
expect(result.rows.map((row) => row.h)).toEqual(['near', 'mid', 'far'])
|
||||
expect(result.rows[0].dist as number).toBeLessThan(0.001)
|
||||
expect(result.rows[2].dist as number).toBeGreaterThan(0.9)
|
||||
})
|
||||
|
||||
it('respects the LIMIT k bound', async () => {
|
||||
await insertEmbedding('a', [1, 0, 0])
|
||||
await insertEmbedding('b', [0, 1, 0])
|
||||
it('respects the LIMIT k bound', () => {
|
||||
insertEmbedding('a', [1, 0, 0])
|
||||
insertEmbedding('b', [0, 1, 0])
|
||||
|
||||
const result = await topK([1, 0, 0], 1)
|
||||
const result = topK([1, 0, 0], 1)
|
||||
|
||||
expect(result.rows).toHaveLength(1)
|
||||
expect(result.rows[0].h).toBe('a')
|
||||
|
||||
@@ -24,10 +24,10 @@ describe('KnowledgeIndexStore integration (real better-sqlite3)', () => {
|
||||
let tempDir: string
|
||||
let store: KnowledgeIndexStore
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-integration-'))
|
||||
const driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
const driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
|
||||
})
|
||||
|
||||
|
||||
@@ -33,21 +33,21 @@ const bm25Row = (unitId: string, materialId: string, score: number): Record<stri
|
||||
*/
|
||||
function createFakeDriver(vectorRows: Array<Record<string, SqlValue>>, bm25Rows: Array<Record<string, SqlValue>>) {
|
||||
const limits: { vector?: number; bm25?: number } = {}
|
||||
const execute = vi.fn(async (sql: string, args: SqlValue[] = []): Promise<SqlQueryResult> => {
|
||||
const execute = vi.fn((sql: string, args: SqlValue[] = []): SqlQueryResult => {
|
||||
const limit = Number(args[args.length - 1])
|
||||
if (sql.includes('search_text_fts MATCH')) {
|
||||
limits.bm25 = limit
|
||||
return { rows: bm25Rows }
|
||||
return { rows: bm25Rows, changes: 0 }
|
||||
}
|
||||
limits.vector = limit
|
||||
return { rows: vectorRows }
|
||||
return { rows: vectorRows, changes: 0 }
|
||||
})
|
||||
const driver: SqliteDriver = {
|
||||
execute,
|
||||
transaction: async (fn) => fn({ execute }),
|
||||
reclaim: async () => ({ vacuumed: false, reclaimedBytes: 0 }),
|
||||
transaction: (fn) => fn({ execute }),
|
||||
reclaim: () => ({ vacuumed: false, reclaimedBytes: 0 }),
|
||||
isClosed: () => false,
|
||||
close: async () => undefined
|
||||
close: () => undefined
|
||||
}
|
||||
return { driver, limits }
|
||||
}
|
||||
|
||||
@@ -16,10 +16,10 @@ describe('KnowledgeIndexStore.search', () => {
|
||||
let driver: BetterSqlite3Driver
|
||||
let store: KnowledgeIndexStore
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-search-'))
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
|
||||
})
|
||||
|
||||
|
||||
@@ -38,10 +38,10 @@ describe('KnowledgeIndexStore', () => {
|
||||
let driver: BetterSqlite3Driver
|
||||
let store: KnowledgeIndexStore
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-store-'))
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
store = new KnowledgeIndexStore(driver, betterSqlite3VectorIndex)
|
||||
})
|
||||
|
||||
@@ -50,16 +50,14 @@ describe('KnowledgeIndexStore', () => {
|
||||
rmSync(tempDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
const count = async (table: string) => Number((await driver.execute(`SELECT COUNT(*) AS n FROM ${table}`)).rows[0].n)
|
||||
const count = async (table: string) => Number(driver.execute(`SELECT COUNT(*) AS n FROM ${table}`).rows[0].n)
|
||||
|
||||
const ftsMatchCount = async (term: string) =>
|
||||
(
|
||||
await driver.execute(
|
||||
`SELECT st.search_text_id AS id
|
||||
FROM search_text_fts JOIN search_text st ON st.fts_rowid = search_text_fts.rowid
|
||||
WHERE search_text_fts MATCH ?`,
|
||||
[term]
|
||||
)
|
||||
driver.execute(
|
||||
`SELECT st.search_text_id AS id
|
||||
FROM search_text_fts JOIN search_text st ON st.fts_rowid = search_text_fts.rowid
|
||||
WHERE search_text_fts MATCH ?`,
|
||||
[term]
|
||||
).rows.length
|
||||
|
||||
// The reliable external-content desync detector: the default `integrity-check` does NOT compare
|
||||
@@ -86,7 +84,7 @@ describe('KnowledgeIndexStore', () => {
|
||||
expect(await count('search_unit')).toBe(2)
|
||||
expect(await count('search_text')).toBe(2)
|
||||
|
||||
const material = await driver.execute(`SELECT current_content_hash FROM material WHERE material_id = ?`, ['m1'])
|
||||
const material = driver.execute(`SELECT current_content_hash FROM material WHERE material_id = ?`, ['m1'])
|
||||
expect(material.rows[0].current_content_hash).not.toBeNull()
|
||||
})
|
||||
|
||||
@@ -112,7 +110,7 @@ describe('KnowledgeIndexStore', () => {
|
||||
// Corrupt the store: drop the body row out from under the unit. The same
|
||||
// damage silently excludes the unit from search (INNER JOIN); the list lane
|
||||
// must fail loudly rather than add a third symptom (existing-but-empty chunk).
|
||||
await driver.execute(`DELETE FROM search_text WHERE target_type = 'search_unit' AND kind = 'body'`)
|
||||
driver.execute(`DELETE FROM search_text WHERE target_type = 'search_unit' AND kind = 'body'`)
|
||||
|
||||
await expect(store.listMaterialUnits('m1')).rejects.toThrow('missing the body text for unit')
|
||||
})
|
||||
@@ -321,7 +319,7 @@ describe('KnowledgeIndexStore', () => {
|
||||
// single GC pass — the path a folder delete takes (one deleteMaterials over N files).
|
||||
await store.deleteMaterials(['m1', 'm2', 'm1'])
|
||||
|
||||
expect((await driver.execute(`SELECT material_id FROM material`)).rows.map((r) => r.material_id)).toEqual(['m3'])
|
||||
expect(driver.execute(`SELECT material_id FROM material`).rows.map((r) => r.material_id)).toEqual(['m3'])
|
||||
expect(await store.listMaterialUnits('m1')).toEqual([])
|
||||
expect(await store.listMaterialUnits('m2')).toEqual([])
|
||||
// The single end-of-batch GC must sweep m2's now-orphaned body/embedding/content while
|
||||
@@ -417,7 +415,7 @@ describe('KnowledgeIndexStore', () => {
|
||||
|
||||
expect(outcome.vacuumed).toBe(true)
|
||||
expect(outcome.reclaimedBytes).toBeGreaterThan(0)
|
||||
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
|
||||
expect(ftsIntegrityCheck()).toBeDefined()
|
||||
// The survivor stays both keyword- and vector-reachable after the rewrite.
|
||||
expect(await ftsMatchCount('knowledge')).toBe(1)
|
||||
expect((await store.search({ queryText: 'knowledge', mode: 'bm25', topK: 10 })).map((h) => h.materialId)).toEqual([
|
||||
@@ -466,7 +464,7 @@ describe('KnowledgeIndexStore', () => {
|
||||
// VACUUM renumbers implicit rowids; assert the external-content FTS did NOT desync. Keyed on the
|
||||
// stable fts_rowid it stays aligned by construction — verified with the reliable integrity check
|
||||
// (the default integrity-check would pass even on a desync).
|
||||
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
|
||||
expect(ftsIntegrityCheck()).toBeDefined()
|
||||
})
|
||||
|
||||
it('keeps search_text_fts aligned after a rowid-reshuffling rebuild (fts_rowid keying)', async () => {
|
||||
@@ -485,23 +483,23 @@ describe('KnowledgeIndexStore', () => {
|
||||
// Reshuffle the implicit rowid exactly as a table rebuild / VACUUM does: copy the table (new
|
||||
// contiguous rowids, dropping the hole) while carrying fts_rowid verbatim, then re-assert the
|
||||
// dropped indexes + triggers. The FTS vtable is untouched, so it stays keyed by fts_rowid.
|
||||
await driver.execute('PRAGMA foreign_keys=OFF')
|
||||
await driver.execute('CREATE TABLE __new_search_text AS SELECT * FROM search_text')
|
||||
await driver.execute('DROP TABLE search_text')
|
||||
await driver.execute('ALTER TABLE __new_search_text RENAME TO search_text')
|
||||
await driver.execute('PRAGMA foreign_keys=ON')
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver.execute('PRAGMA foreign_keys=OFF')
|
||||
driver.execute('CREATE TABLE __new_search_text AS SELECT * FROM search_text')
|
||||
driver.execute('DROP TABLE search_text')
|
||||
driver.execute('ALTER TABLE __new_search_text RENAME TO search_text')
|
||||
driver.execute('PRAGMA foreign_keys=ON')
|
||||
createKnowledgeIndexSchema(driver)
|
||||
|
||||
// fts_rowid was copied through the rebuild, so the index stays aligned: the reliable check passes
|
||||
// and the survivors resolve to the right materials (a rowid-keyed index would return wrong/empty).
|
||||
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
|
||||
expect(ftsIntegrityCheck()).toBeDefined()
|
||||
expect((await store.search({ queryText: 'cherry', mode: 'bm25', topK: 10 })).map((h) => h.materialId)).toEqual([
|
||||
'm3'
|
||||
])
|
||||
expect((await store.search({ queryText: 'date', mode: 'bm25', topK: 10 })).map((h) => h.materialId)).toEqual(['m4'])
|
||||
// The reshuffle reassigned implicit rowids while leaving fts_rowid untouched and unique.
|
||||
expect(await count('search_text')).toBe(3)
|
||||
expect(Number((await driver.execute(`SELECT COUNT(DISTINCT fts_rowid) AS n FROM search_text`)).rows[0].n)).toBe(3)
|
||||
expect(Number(driver.execute(`SELECT COUNT(DISTINCT fts_rowid) AS n FROM search_text`).rows[0].n)).toBe(3)
|
||||
})
|
||||
|
||||
it('integrity-check,1 catches a NULL fts_rowid desync (and proves the detector is live)', async () => {
|
||||
@@ -513,8 +511,8 @@ describe('KnowledgeIndexStore', () => {
|
||||
// now references a key the content row no longer carries. This both guards that hazard AND is the
|
||||
// positive control that ftsIntegrityCheck() really throws on a desync — every other use of it in
|
||||
// this suite asserts it resolves, so without this case those assertions could pass vacuously.
|
||||
await driver.execute(`UPDATE search_text SET fts_rowid = NULL`)
|
||||
await expect(ftsIntegrityCheck()).rejects.toThrow()
|
||||
driver.execute(`UPDATE search_text SET fts_rowid = NULL`)
|
||||
expect(() => ftsIntegrityCheck()).toThrow()
|
||||
})
|
||||
|
||||
it('reuses a freed fts_rowid without leaving a stale FTS entry', async () => {
|
||||
@@ -527,14 +525,14 @@ describe('KnowledgeIndexStore', () => {
|
||||
await store.rebuildMaterial('m3', buildInput('charlie cherry body', [[0, 19]], 'c.md'))
|
||||
|
||||
// m3 holds the current MAX fts_rowid. Delete it to free that rowid.
|
||||
const maxBefore = Number((await driver.execute(`SELECT MAX(fts_rowid) AS m FROM search_text`)).rows[0].m)
|
||||
const maxBefore = Number(driver.execute(`SELECT MAX(fts_rowid) AS m FROM search_text`).rows[0].m)
|
||||
await store.deleteMaterials(['m3'])
|
||||
expect(await ftsMatchCount('cherry')).toBe(0)
|
||||
|
||||
// The next insert's MAX(fts_rowid)+1 reuses the value m3 just vacated.
|
||||
await store.rebuildMaterial('m4', buildInput('delta dragon body', [[0, 17]], 'd.md'))
|
||||
const reused = Number(
|
||||
(await driver.execute(`SELECT fts_rowid FROM search_text WHERE text LIKE 'delta%'`)).rows[0].fts_rowid
|
||||
driver.execute(`SELECT fts_rowid FROM search_text WHERE text LIKE 'delta%'`).rows[0].fts_rowid
|
||||
)
|
||||
expect(reused).toBe(maxBefore) // landed on the exact physical rowid m3 had held
|
||||
|
||||
@@ -544,7 +542,7 @@ describe('KnowledgeIndexStore', () => {
|
||||
'm4'
|
||||
])
|
||||
expect(await ftsMatchCount('cherry')).toBe(0)
|
||||
await expect(ftsIntegrityCheck()).resolves.toBeDefined()
|
||||
expect(ftsIntegrityCheck()).toBeDefined()
|
||||
})
|
||||
|
||||
it('keeps a shared embedding reachable for the other material after rebuilding the one that introduced it', async () => {
|
||||
|
||||
@@ -17,21 +17,21 @@ describe('ensureIndexMeta', () => {
|
||||
let tempDir: string
|
||||
let driver: BetterSqlite3Driver
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-meta-'))
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
await driver.close()
|
||||
afterEach(() => {
|
||||
driver.close()
|
||||
rmSync(tempDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
it('writes the single meta identity row with the schema version and base id on first open', async () => {
|
||||
await ensureIndexMeta(driver, META_INPUT)
|
||||
it('writes the single meta identity row with the schema version and base id on first open', () => {
|
||||
ensureIndexMeta(driver, META_INPUT)
|
||||
|
||||
const result = await driver.execute('SELECT * FROM meta')
|
||||
const result = driver.execute('SELECT * FROM meta')
|
||||
expect(result.rows).toHaveLength(1)
|
||||
const row = result.rows[0]
|
||||
expect(row.id).toBe(1)
|
||||
@@ -39,24 +39,22 @@ describe('ensureIndexMeta', () => {
|
||||
expect(row.base_id).toBe('kb-1')
|
||||
})
|
||||
|
||||
it('is idempotent across re-opens: the original row is kept, not duplicated or rewritten', async () => {
|
||||
await ensureIndexMeta(driver, META_INPUT)
|
||||
const first = await driver.execute('SELECT created_at FROM meta WHERE id = 1')
|
||||
it('is idempotent across re-opens: the original row is kept, not duplicated or rewritten', () => {
|
||||
ensureIndexMeta(driver, META_INPUT)
|
||||
const first = driver.execute('SELECT created_at FROM meta WHERE id = 1')
|
||||
|
||||
await ensureIndexMeta(driver, META_INPUT)
|
||||
const second = await driver.execute('SELECT created_at FROM meta WHERE id = 1')
|
||||
ensureIndexMeta(driver, META_INPUT)
|
||||
const second = driver.execute('SELECT created_at FROM meta WHERE id = 1')
|
||||
|
||||
const count = await driver.execute('SELECT COUNT(*) AS n FROM meta')
|
||||
const count = driver.execute('SELECT COUNT(*) AS n FROM meta')
|
||||
expect(count.rows[0].n).toBe(1)
|
||||
expect(second.rows[0].created_at).toBe(first.rows[0].created_at)
|
||||
})
|
||||
|
||||
it('rejects opening an index that belongs to a different base (anti-mismount guard, §4.1)', async () => {
|
||||
await ensureIndexMeta(driver, META_INPUT)
|
||||
it('rejects opening an index that belongs to a different base (anti-mismount guard, §4.1)', () => {
|
||||
ensureIndexMeta(driver, META_INPUT)
|
||||
|
||||
await expect(ensureIndexMeta(driver, { ...META_INPUT, baseId: 'kb-OTHER' })).rejects.toThrow(
|
||||
/belongs to a different base/
|
||||
)
|
||||
expect(() => ensureIndexMeta(driver, { ...META_INPUT, baseId: 'kb-OTHER' })).toThrow(/belongs to a different base/)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -67,25 +65,25 @@ describe('index content diagnostics', () => {
|
||||
let tempDir: string
|
||||
let driver: BetterSqlite3Driver
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-meta-'))
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
await driver.close()
|
||||
afterEach(() => {
|
||||
driver.close()
|
||||
rmSync(tempDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
it('hasAnyMaterial is false on a fresh index and true once a material row exists', async () => {
|
||||
expect(await hasAnyMaterial(driver)).toBe(false)
|
||||
it('hasAnyMaterial is false on a fresh index and true once a material row exists', () => {
|
||||
expect(hasAnyMaterial(driver)).toBe(false)
|
||||
|
||||
await driver.execute(
|
||||
driver.execute(
|
||||
`INSERT INTO material (material_id, relative_path, created_at, updated_at)
|
||||
VALUES ('m1', 'doc.md', 1, 1)`
|
||||
)
|
||||
|
||||
expect(await hasAnyMaterial(driver)).toBe(true)
|
||||
expect(hasAnyMaterial(driver)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -20,16 +20,16 @@ describe('knowledge index schema', () => {
|
||||
let tempDir: string
|
||||
let driver: BetterSqlite3Driver
|
||||
|
||||
beforeEach(async () => {
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), 'cs-knowledge-index-'))
|
||||
// openBetterSqlite3IndexDriver enables foreign keys per-connection (for CASCADE)
|
||||
// and loads sqlite-vec (for vec_distance_cosine).
|
||||
driver = await openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
await createKnowledgeIndexSchema(driver)
|
||||
driver = openBetterSqlite3IndexDriver(join(tempDir, 'index.sqlite'))
|
||||
createKnowledgeIndexSchema(driver)
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
await driver?.close()
|
||||
afterEach(() => {
|
||||
driver?.close()
|
||||
rmSync(tempDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
@@ -58,15 +58,15 @@ describe('knowledge index schema', () => {
|
||||
)
|
||||
|
||||
/** Every schema object (tables, triggers, indexes, FTS shadow tables), as stable `type:name` keys. */
|
||||
const listSchemaObjects = async () => {
|
||||
const result = await driver.execute(`SELECT type, name FROM sqlite_master ORDER BY type, name`)
|
||||
const listSchemaObjects = () => {
|
||||
const result = driver.execute(`SELECT type, name FROM sqlite_master ORDER BY type, name`)
|
||||
return result.rows.map((row) => `${row.type}:${row.name}`)
|
||||
}
|
||||
|
||||
describe('schema creation', () => {
|
||||
it('creates all 7 schema objects', async () => {
|
||||
it('creates all 7 schema objects', () => {
|
||||
const expected = ['meta', 'content', 'material', 'search_unit', 'search_text', 'embedding', 'search_text_fts']
|
||||
const result = await driver.execute(
|
||||
const result = driver.execute(
|
||||
`SELECT name FROM sqlite_master WHERE name IN (${expected.map(() => '?').join(', ')})`,
|
||||
expected
|
||||
)
|
||||
@@ -76,10 +76,10 @@ describe('knowledge index schema', () => {
|
||||
}
|
||||
})
|
||||
|
||||
it('is idempotent: re-applying through the driver leaves the object set unchanged', async () => {
|
||||
const objectsBefore = await listSchemaObjects()
|
||||
await expect(createKnowledgeIndexSchema(driver)).resolves.toBeUndefined()
|
||||
expect(await listSchemaObjects()).toEqual(objectsBefore)
|
||||
it('is idempotent: re-applying through the driver leaves the object set unchanged', () => {
|
||||
const objectsBefore = listSchemaObjects()
|
||||
expect(createKnowledgeIndexSchema(driver)).toBeUndefined()
|
||||
expect(listSchemaObjects()).toEqual(objectsBefore)
|
||||
})
|
||||
|
||||
it('exposes a static, parameterless statement list', () => {
|
||||
@@ -98,60 +98,60 @@ describe('knowledge index schema', () => {
|
||||
[id, TS, TS]
|
||||
)
|
||||
|
||||
it('accepts the single id = 1 row', async () => {
|
||||
await expect(insertMeta(1)).resolves.toBeDefined()
|
||||
it('accepts the single id = 1 row', () => {
|
||||
expect(insertMeta(1)).toBeDefined()
|
||||
})
|
||||
|
||||
it('rejects id != 1', async () => {
|
||||
await expect(insertMeta(2)).rejects.toThrow()
|
||||
it('rejects id != 1', () => {
|
||||
expect(() => insertMeta(2)).toThrow()
|
||||
})
|
||||
|
||||
it('rejects a second row', async () => {
|
||||
await insertMeta(1)
|
||||
await expect(insertMeta(1)).rejects.toThrow()
|
||||
it('rejects a second row', () => {
|
||||
insertMeta(1)
|
||||
expect(() => insertMeta(1)).toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('material constraints', () => {
|
||||
it('accepts a valid material', async () => {
|
||||
await expect(insertMaterial('m1', 'docs/paper.md')).resolves.toBeDefined()
|
||||
it('accepts a valid material', () => {
|
||||
expect(insertMaterial('m1', 'docs/paper.md')).toBeDefined()
|
||||
})
|
||||
|
||||
it('rejects an absolute relative_path', async () => {
|
||||
await expect(insertMaterial('m1', '/abs/paper.md')).rejects.toThrow()
|
||||
it('rejects an absolute relative_path', () => {
|
||||
expect(() => insertMaterial('m1', '/abs/paper.md')).toThrow()
|
||||
})
|
||||
|
||||
it('rejects a reserved .cherry relative_path', async () => {
|
||||
await expect(insertMaterial('m1', '.cherry/index.sqlite')).rejects.toThrow()
|
||||
it('rejects a reserved .cherry relative_path', () => {
|
||||
expect(() => insertMaterial('m1', '.cherry/index.sqlite')).toThrow()
|
||||
})
|
||||
|
||||
it('enforces unique relative_path', async () => {
|
||||
await insertMaterial('m1', 'a.md')
|
||||
await expect(insertMaterial('m2', 'a.md')).rejects.toThrow()
|
||||
it('enforces unique relative_path', () => {
|
||||
insertMaterial('m1', 'a.md')
|
||||
expect(() => insertMaterial('m2', 'a.md')).toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('foreign keys', () => {
|
||||
it('cascades search_unit deletion when its material is deleted', async () => {
|
||||
await insertContent('h1', 'hello')
|
||||
await insertMaterial('m1', 'a.md', 'h1')
|
||||
await insertSearchUnit('u1', 'm1', 'h1')
|
||||
it('cascades search_unit deletion when its material is deleted', () => {
|
||||
insertContent('h1', 'hello')
|
||||
insertMaterial('m1', 'a.md', 'h1')
|
||||
insertSearchUnit('u1', 'm1', 'h1')
|
||||
|
||||
await driver.execute(`DELETE FROM material WHERE material_id = ?`, ['m1'])
|
||||
driver.execute(`DELETE FROM material WHERE material_id = ?`, ['m1'])
|
||||
|
||||
const remaining = await driver.execute(`SELECT COUNT(*) AS n FROM search_unit`)
|
||||
const remaining = driver.execute(`SELECT COUNT(*) AS n FROM search_unit`)
|
||||
expect(remaining.rows[0].n).toBe(0)
|
||||
})
|
||||
|
||||
it('rejects a search_unit referencing a missing material', async () => {
|
||||
await insertContent('h1', 'hello')
|
||||
await expect(insertSearchUnit('u1', 'missing-material', 'h1')).rejects.toThrow()
|
||||
it('rejects a search_unit referencing a missing material', () => {
|
||||
insertContent('h1', 'hello')
|
||||
expect(() => insertSearchUnit('u1', 'missing-material', 'h1')).toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('FTS5 (trigram, external content)', () => {
|
||||
const matchBody = async (term: string) => {
|
||||
const result = await driver.execute(
|
||||
const matchBody = (term: string) => {
|
||||
const result = driver.execute(
|
||||
`SELECT st.search_text_id AS id
|
||||
FROM search_text_fts
|
||||
JOIN search_text st ON st.fts_rowid = search_text_fts.rowid
|
||||
@@ -161,52 +161,50 @@ describe('knowledge index schema', () => {
|
||||
return result.rows.map((row) => row.id as string)
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
await insertContent('h1', 'body content')
|
||||
await insertMaterial('m1', 'a.md', 'h1')
|
||||
await insertSearchUnit('u1', 'm1', 'h1')
|
||||
beforeEach(() => {
|
||||
insertContent('h1', 'body content')
|
||||
insertMaterial('m1', 'a.md', 'h1')
|
||||
insertSearchUnit('u1', 'm1', 'h1')
|
||||
})
|
||||
|
||||
it('indexes inserted search_text and matches by term', async () => {
|
||||
await insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
|
||||
expect(await matchBody('knowledge')).toEqual(['st1'])
|
||||
it('indexes inserted search_text and matches by term', () => {
|
||||
insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
|
||||
expect(matchBody('knowledge')).toEqual(['st1'])
|
||||
})
|
||||
|
||||
it('removes the FTS entry when search_text is deleted (ad trigger)', async () => {
|
||||
await insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
|
||||
expect(await matchBody('knowledge')).toEqual(['st1'])
|
||||
it('removes the FTS entry when search_text is deleted (ad trigger)', () => {
|
||||
insertSearchText('st1', 'u1', 'the quick brown fox jumps over knowledge base', 'eh1')
|
||||
expect(matchBody('knowledge')).toEqual(['st1'])
|
||||
|
||||
await driver.execute(`DELETE FROM search_text WHERE search_text_id = ?`, ['st1'])
|
||||
expect(await matchBody('knowledge')).toEqual([])
|
||||
driver.execute(`DELETE FROM search_text WHERE search_text_id = ?`, ['st1'])
|
||||
expect(matchBody('knowledge')).toEqual([])
|
||||
})
|
||||
|
||||
it('re-syncs the FTS entry when search_text.text is updated (au trigger)', async () => {
|
||||
it('re-syncs the FTS entry when search_text.text is updated (au trigger)', () => {
|
||||
// Production rebuilds are delete + insert, so this UPDATE path has no caller
|
||||
// today; the trigger is kept defensively and this test pins its behavior.
|
||||
await insertSearchText('st1', 'u1', 'alpha knowledge base', 'eh1')
|
||||
expect(await matchBody('knowledge')).toEqual(['st1'])
|
||||
const ftsRowidBefore = (
|
||||
await driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
|
||||
).rows[0].fts_rowid
|
||||
insertSearchText('st1', 'u1', 'alpha knowledge base', 'eh1')
|
||||
expect(matchBody('knowledge')).toEqual(['st1'])
|
||||
const ftsRowidBefore = driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
|
||||
.rows[0].fts_rowid
|
||||
|
||||
await driver.execute(`UPDATE search_text SET text = ? WHERE search_text_id = ?`, ['beta wisdom corpus', 'st1'])
|
||||
driver.execute(`UPDATE search_text SET text = ? WHERE search_text_id = ?`, ['beta wisdom corpus', 'st1'])
|
||||
|
||||
expect(await matchBody('knowledge')).toEqual([])
|
||||
expect(await matchBody('wisdom')).toEqual(['st1'])
|
||||
expect(matchBody('knowledge')).toEqual([])
|
||||
expect(matchBody('wisdom')).toEqual(['st1'])
|
||||
// fts_rowid is stable across a text edit (the au trigger re-keys the FTS row by NEW.fts_rowid,
|
||||
// it does not reassign it) — so the external-content index stays aligned.
|
||||
const ftsRowidAfter = (
|
||||
await driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
|
||||
).rows[0].fts_rowid
|
||||
const ftsRowidAfter = driver.execute(`SELECT fts_rowid FROM search_text WHERE search_text_id = ?`, ['st1'])
|
||||
.rows[0].fts_rowid
|
||||
expect(ftsRowidAfter).toBe(ftsRowidBefore)
|
||||
await expect(
|
||||
expect(
|
||||
driver.execute(`INSERT INTO search_text_fts(search_text_fts, rank) VALUES('integrity-check', 1)`)
|
||||
).resolves.toBeDefined()
|
||||
).toBeDefined()
|
||||
})
|
||||
|
||||
it('exposes a bm25 rank for matches', async () => {
|
||||
await insertSearchText('st1', 'u1', 'knowledge retrieval', 'eh1')
|
||||
const result = await driver.execute(
|
||||
it('exposes a bm25 rank for matches', () => {
|
||||
insertSearchText('st1', 'u1', 'knowledge retrieval', 'eh1')
|
||||
const result = driver.execute(
|
||||
`SELECT bm25(search_text_fts) AS score
|
||||
FROM search_text_fts
|
||||
WHERE search_text_fts MATCH ?`,
|
||||
@@ -218,15 +216,15 @@ describe('knowledge index schema', () => {
|
||||
})
|
||||
|
||||
describe('embedding vector (engine-portability spike, §5.6)', () => {
|
||||
it('computes vec_distance_cosine directly over a plain BLOB column', async () => {
|
||||
it('computes vec_distance_cosine directly over a plain BLOB column', () => {
|
||||
const vector = [0.1, 0.2, 0.3]
|
||||
await driver.execute(`INSERT INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`, [
|
||||
driver.execute(`INSERT INTO embedding (embedding_text_hash, vector_blob, created_at) VALUES (?, ?, ?)`, [
|
||||
'eh_vec',
|
||||
encodeVectorBlob(vector),
|
||||
TS
|
||||
])
|
||||
|
||||
const result = await driver.execute(
|
||||
const result = driver.execute(
|
||||
`SELECT vec_distance_cosine(vector_blob, ?) AS dist
|
||||
FROM embedding
|
||||
WHERE embedding_text_hash = ?`,
|
||||
@@ -241,29 +239,29 @@ describe('knowledge index schema', () => {
|
||||
})
|
||||
|
||||
describe('schema version & rebuild (open-time migration)', () => {
|
||||
it('reports null until a meta row exists, then the current constant', async () => {
|
||||
it('reports null until a meta row exists, then the current constant', () => {
|
||||
// beforeEach created the schema (meta table exists) but no id=1 row yet.
|
||||
expect(await readIndexSchemaVersion(driver)).toBeNull()
|
||||
await ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
expect(await readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
|
||||
expect(readIndexSchemaVersion(driver)).toBeNull()
|
||||
ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
expect(readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
|
||||
})
|
||||
|
||||
it('reports null on a file with no meta table at all', async () => {
|
||||
const fresh = await openBetterSqlite3IndexDriver(join(tempDir, 'no-meta.sqlite'))
|
||||
it('reports null on a file with no meta table at all', () => {
|
||||
const fresh = openBetterSqlite3IndexDriver(join(tempDir, 'no-meta.sqlite'))
|
||||
try {
|
||||
expect(await readIndexSchemaVersion(fresh)).toBeNull()
|
||||
expect(readIndexSchemaVersion(fresh)).toBeNull()
|
||||
} finally {
|
||||
await fresh.close()
|
||||
fresh.close()
|
||||
}
|
||||
})
|
||||
|
||||
it('reports null for a malformed (non-numeric) schema_version cell', async () => {
|
||||
await ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
it('reports null for a malformed (non-numeric) schema_version cell', () => {
|
||||
ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
// A blanked/corrupt version (stored as text under the column's INTEGER affinity) must read as
|
||||
// "unknown" → null, so the open path treats it as fresh and creates rather than silently
|
||||
// mistaking it for a real version. Covers the `typeof === 'number'` guard.
|
||||
await driver.execute(`UPDATE meta SET schema_version = 'corrupt'`)
|
||||
expect(await readIndexSchemaVersion(driver)).toBeNull()
|
||||
driver.execute(`UPDATE meta SET schema_version = 'corrupt'`)
|
||||
expect(readIndexSchemaVersion(driver)).toBeNull()
|
||||
})
|
||||
|
||||
it('pins the current schema version (a deliberate bump-me tripwire)', () => {
|
||||
@@ -273,25 +271,25 @@ describe('knowledge index schema', () => {
|
||||
expect(KNOWLEDGE_INDEX_SCHEMA_VERSION).toBe(2)
|
||||
})
|
||||
|
||||
it('resetKnowledgeIndexSchema wipes data, rebuilds every object, and lets meta restamp the version', async () => {
|
||||
const freshObjects = await listSchemaObjects()
|
||||
it('resetKnowledgeIndexSchema wipes data, rebuilds every object, and lets meta restamp the version', () => {
|
||||
const freshObjects = listSchemaObjects()
|
||||
// Seed a populated index stamped at an older layout version.
|
||||
await ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
await insertContent('h1', 'hello world')
|
||||
await insertMaterial('m1', 'a.md', 'h1')
|
||||
await driver.execute(`UPDATE meta SET schema_version = 1`)
|
||||
expect(await readIndexSchemaVersion(driver)).toBe(1)
|
||||
ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
insertContent('h1', 'hello world')
|
||||
insertMaterial('m1', 'a.md', 'h1')
|
||||
driver.execute(`UPDATE meta SET schema_version = 1`)
|
||||
expect(readIndexSchemaVersion(driver)).toBe(1)
|
||||
|
||||
await resetKnowledgeIndexSchema(driver)
|
||||
resetKnowledgeIndexSchema(driver)
|
||||
|
||||
// Same object set as a fresh schema, but the derived data is gone (rebuildable artifact).
|
||||
expect(await listSchemaObjects()).toEqual(freshObjects)
|
||||
expect((await driver.execute(`SELECT COUNT(*) AS n FROM material`)).rows[0].n).toBe(0)
|
||||
expect((await driver.execute(`SELECT COUNT(*) AS n FROM content`)).rows[0].n).toBe(0)
|
||||
expect(listSchemaObjects()).toEqual(freshObjects)
|
||||
expect(driver.execute(`SELECT COUNT(*) AS n FROM material`).rows[0].n).toBe(0)
|
||||
expect(driver.execute(`SELECT COUNT(*) AS n FROM content`).rows[0].n).toBe(0)
|
||||
// The reset drops meta too, so the version is null until the open path restamps it.
|
||||
expect(await readIndexSchemaVersion(driver)).toBeNull()
|
||||
await ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
expect(await readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
|
||||
expect(readIndexSchemaVersion(driver)).toBeNull()
|
||||
ensureIndexMeta(driver, { baseId: 'base-1' })
|
||||
expect(readIndexSchemaVersion(driver)).toBe(KNOWLEDGE_INDEX_SCHEMA_VERSION)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -20,15 +20,15 @@ export interface IndexMetaInput {
|
||||
* the store-open path logs an error when that happens under a base that already
|
||||
* has completed items (see KnowledgeVectorStoreService).
|
||||
*/
|
||||
export async function ensureIndexMeta(executor: SqliteExecutor, input: IndexMetaInput): Promise<void> {
|
||||
export function ensureIndexMeta(executor: SqliteExecutor, input: IndexMetaInput): void {
|
||||
const now = Date.now()
|
||||
await executor.execute(
|
||||
executor.execute(
|
||||
`INSERT OR IGNORE INTO meta (id, schema_version, base_id, created_at, updated_at)
|
||||
VALUES (1, ?, ?, ?, ?)`,
|
||||
[KNOWLEDGE_INDEX_SCHEMA_VERSION, input.baseId, now, now]
|
||||
)
|
||||
|
||||
const stored = await executor.execute(`SELECT base_id FROM meta WHERE id = 1`)
|
||||
const stored = executor.execute(`SELECT base_id FROM meta WHERE id = 1`)
|
||||
const storedBaseId = stored.rows[0]?.base_id as string | undefined
|
||||
if (storedBaseId !== input.baseId) {
|
||||
throw new Error(
|
||||
@@ -44,18 +44,18 @@ export async function ensureIndexMeta(executor: SqliteExecutor, input: IndexMeta
|
||||
* The store-open path compares this to {@link KNOWLEDGE_INDEX_SCHEMA_VERSION}: a
|
||||
* non-null mismatch means an old layout that must be rebuilt before the DDL is applied.
|
||||
*/
|
||||
export async function readIndexSchemaVersion(executor: SqliteExecutor): Promise<number | null> {
|
||||
const hasMeta = await executor.execute(`SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'meta'`)
|
||||
export function readIndexSchemaVersion(executor: SqliteExecutor): number | null {
|
||||
const hasMeta = executor.execute(`SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'meta'`)
|
||||
if (hasMeta.rows.length === 0) {
|
||||
return null
|
||||
}
|
||||
const result = await executor.execute(`SELECT schema_version FROM meta WHERE id = 1`)
|
||||
const result = executor.execute(`SELECT schema_version FROM meta WHERE id = 1`)
|
||||
const version = result.rows[0]?.schema_version
|
||||
return typeof version === 'number' ? version : null
|
||||
}
|
||||
|
||||
/** Whether the index database holds at least one material row (store-open diagnostics probe). */
|
||||
export async function hasAnyMaterial(executor: SqliteExecutor): Promise<boolean> {
|
||||
const result = await executor.execute(`SELECT 1 FROM material LIMIT 1`)
|
||||
export function hasAnyMaterial(executor: SqliteExecutor): boolean {
|
||||
const result = executor.execute(`SELECT 1 FROM material LIMIT 1`)
|
||||
return result.rows.length > 0
|
||||
}
|
||||
|
||||
@@ -214,9 +214,9 @@ export const KNOWLEDGE_INDEX_SCHEMA_STATEMENTS: readonly string[] = [
|
||||
* Does NOT insert the `meta` row — that requires a runtime value (the base id)
|
||||
* and is owned by the store-open path.
|
||||
*/
|
||||
export async function createKnowledgeIndexSchema(executor: SqliteExecutor): Promise<void> {
|
||||
export function createKnowledgeIndexSchema(executor: SqliteExecutor): void {
|
||||
for (const statement of KNOWLEDGE_INDEX_SCHEMA_STATEMENTS) {
|
||||
await executor.execute(statement)
|
||||
executor.execute(statement)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,9 +249,9 @@ export const KNOWLEDGE_INDEX_DROP_STATEMENTS: readonly string[] = [
|
||||
* the base simply re-indexes. Drops run first (autocommit per statement), then the
|
||||
* current DDL is reapplied. Callers must restamp `meta` afterwards (the DROP removes it).
|
||||
*/
|
||||
export async function resetKnowledgeIndexSchema(executor: SqliteExecutor): Promise<void> {
|
||||
export function resetKnowledgeIndexSchema(executor: SqliteExecutor): void {
|
||||
for (const statement of KNOWLEDGE_INDEX_DROP_STATEMENTS) {
|
||||
await executor.execute(statement)
|
||||
executor.execute(statement)
|
||||
}
|
||||
await createKnowledgeIndexSchema(executor)
|
||||
createKnowledgeIndexSchema(executor)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,14 @@
|
||||
* sqlite-vec today) with zero user migration. Only the driver and VectorIndex
|
||||
* adapters are engine-specific; everything above them — the schema DDL
|
||||
* (schema.ts) and the store's queries — is shared, engine-neutral SQL.
|
||||
*
|
||||
* Synchronous by design, mirroring `DbService.withWriteTx` (src/main/data/db/DbService.ts):
|
||||
* better-sqlite3 is a single synchronous connection with no I/O wait, so an `async`
|
||||
* surface here would be pure libsql-era residue. It is not just cosmetic — a
|
||||
* transaction callback that actually awaits real async work would yield the event
|
||||
* loop while `BEGIN`..`COMMIT` is open, letting an unrelated read on the same
|
||||
* connection observe uncommitted rows. A synchronous `transaction<T>(fn: (tx) => T): T`
|
||||
* makes that categorically impossible: `fn` runs to completion in one JS turn.
|
||||
*/
|
||||
|
||||
/** A value bindable to a statement parameter or read back from a result column. */
|
||||
@@ -13,11 +21,13 @@ export type SqlValue = string | number | bigint | boolean | Uint8Array | ArrayBu
|
||||
|
||||
export interface SqlQueryResult {
|
||||
rows: Array<Record<string, SqlValue>>
|
||||
/** Rows inserted/updated/deleted by this statement (0 for a read). */
|
||||
changes: number
|
||||
}
|
||||
|
||||
/** Runs a single statement. Implemented by both the driver and a transaction handle. */
|
||||
export interface SqliteExecutor {
|
||||
execute(sql: string, args?: SqlValue[]): Promise<SqlQueryResult>
|
||||
execute(sql: string, args?: SqlValue[]): SqlQueryResult
|
||||
}
|
||||
|
||||
/** A handle valid only inside SqliteDriver.transaction(); same surface as the driver. */
|
||||
@@ -33,11 +43,13 @@ export interface SqliteReclaimOutcome {
|
||||
|
||||
export interface SqliteDriver extends SqliteExecutor {
|
||||
/**
|
||||
* Run `fn` inside a single write transaction. Commits when `fn` resolves,
|
||||
* rolls back and rethrows when it rejects — preserving the atomic-replace
|
||||
* Run `fn` inside a single write transaction. Commits when `fn` returns,
|
||||
* rolls back and rethrows when it throws — preserving the atomic-replace
|
||||
* semantics rebuildMaterial relies on (no mixed old/new rows ever visible).
|
||||
* `fn` MUST be synchronous — the better-sqlite3 backing implementation throws
|
||||
* if it returns a Promise (see BetterSqlite3Driver).
|
||||
*/
|
||||
transaction<T>(fn: (tx: SqliteTransaction) => Promise<T>): Promise<T>
|
||||
transaction<T>(fn: (tx: SqliteTransaction) => T): T
|
||||
/**
|
||||
* Return free space left by deletes to the OS: always checkpoint+truncate the
|
||||
* WAL, and VACUUM the main file when its freelist has grown large (both a big
|
||||
@@ -49,7 +61,7 @@ export interface SqliteDriver extends SqliteExecutor {
|
||||
* driver's writes; the VACUUM blocks the calling thread for the whole-file
|
||||
* rewrite, which is why the threshold gates it to large deletes.
|
||||
*/
|
||||
reclaim(preVacuumStatements?: readonly string[]): Promise<SqliteReclaimOutcome>
|
||||
reclaim(preVacuumStatements?: readonly string[]): SqliteReclaimOutcome
|
||||
/**
|
||||
* Whether {@link close} has been called. Lets a caller tell an operation that
|
||||
* failed because the store was closed mid-flight (concurrent base deletion or
|
||||
@@ -57,7 +69,7 @@ export interface SqliteDriver extends SqliteExecutor {
|
||||
* instead of leaking an opaque driver error.
|
||||
*/
|
||||
isClosed(): boolean
|
||||
close(): Promise<void>
|
||||
close(): void
|
||||
}
|
||||
|
||||
/** One brute-force vector match: an embedding row and its distance to the query. */
|
||||
|
||||
@@ -173,6 +173,16 @@ const MessageList = () => {
|
||||
messageListRef.current?.scrollToBottom('instant')
|
||||
}, [])
|
||||
|
||||
// Navigation buttons scroll through the virtua-aware runtime handle (smooth,
|
||||
// remeasure-safe) rather than a raw scrollTo on the virtualized scroller.
|
||||
const navigateToTop = useCallback(() => {
|
||||
messageListRef.current?.scrollToTop('smooth')
|
||||
}, [])
|
||||
|
||||
const navigateToBottom = useCallback(() => {
|
||||
messageListRef.current?.scrollToBottom('smooth')
|
||||
}, [])
|
||||
|
||||
const scrollToMessageById = useCallback((messageId: string) => {
|
||||
const target = messageByIdRef.current.get(messageId)
|
||||
if (!target) return
|
||||
@@ -559,7 +569,13 @@ const MessageList = () => {
|
||||
<MessageOutline message={activeOutlineMessage} multiModelMessageStyle={activeOutline.multiModelMessageStyle} />
|
||||
)}
|
||||
{messageNavigation === 'buttons' && (
|
||||
<MessageNavigation containerId="messages" messages={messages} scrollToMessageId={scrollToMessageById} />
|
||||
<MessageNavigation
|
||||
containerId="messages"
|
||||
messages={messages}
|
||||
scrollToMessageId={scrollToMessageById}
|
||||
scrollToTop={navigateToTop}
|
||||
scrollToBottom={navigateToBottom}
|
||||
/>
|
||||
)}
|
||||
{meta.selectionLayer && (
|
||||
<SelectionBox
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
} from '../types'
|
||||
|
||||
const scrollToBottom = vi.fn()
|
||||
const scrollToTop = vi.fn()
|
||||
const scrollToKey = vi.fn()
|
||||
const messageVirtualListMocks = vi.hoisted(() => ({
|
||||
deferScrollContainerReady: false,
|
||||
@@ -179,6 +180,7 @@ vi.mock('../list/MessageVirtualList', async () => {
|
||||
handleRef as Ref<MessageVirtualListHandle>,
|
||||
() => ({
|
||||
scrollToBottom,
|
||||
scrollToTop,
|
||||
scrollToKey,
|
||||
isAtBottom: () => false,
|
||||
getScrollElement: () => messageVirtualListMocks.scrollElement
|
||||
@@ -259,6 +261,7 @@ const renderMessageList = (messages: MessageListItem[]) =>
|
||||
describe('MessageList', () => {
|
||||
beforeEach(() => {
|
||||
scrollToBottom.mockClear()
|
||||
scrollToTop.mockClear()
|
||||
scrollToKey.mockClear()
|
||||
vi.mocked(captureScrollable).mockReset()
|
||||
vi.mocked(captureScrollableAsDataURL).mockReset()
|
||||
|
||||
@@ -30,24 +30,14 @@ describe('useMessageLeafCapabilities', () => {
|
||||
mockSafeOpen.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('loads external apps for ordinary text parts that mention inline absolute paths', () => {
|
||||
const partsByMessageId: Record<string, CherryMessagePart[]> = {
|
||||
message: [{ type: 'text', text: 'Open `/Users/example/project/App.tsx`.' } as CherryMessagePart]
|
||||
}
|
||||
|
||||
renderHook(() => useMessageLeafCapabilities({ partsByMessageId }))
|
||||
|
||||
expect(mockUseExternalApps).toHaveBeenCalledWith({ enabled: true })
|
||||
})
|
||||
|
||||
it('does not load external apps for text parts without local path hints', () => {
|
||||
it('loads external apps for the message list regardless of inline path hints', () => {
|
||||
const partsByMessageId: Record<string, CherryMessagePart[]> = {
|
||||
message: [{ type: 'text', text: 'plain response' } as CherryMessagePart]
|
||||
}
|
||||
|
||||
renderHook(() => useMessageLeafCapabilities({ partsByMessageId }))
|
||||
|
||||
expect(mockUseExternalApps).toHaveBeenCalledWith({ enabled: false })
|
||||
expect(mockUseExternalApps).toHaveBeenCalledWith()
|
||||
})
|
||||
|
||||
it('opens shared attachment files through safeOpen', async () => {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useQuery } from '@data/hooks/useDataApi'
|
||||
import type { MessageListActions, MessageListState } from '@renderer/components/chat/messages/types'
|
||||
import { containsInlineFilePath } from '@renderer/components/chat/messages/utils/filePath'
|
||||
import { useAttachment } from '@renderer/hooks/useAttachment'
|
||||
import { useExternalApps } from '@renderer/hooks/useExternalApps'
|
||||
import type { FileMetadata } from '@renderer/types/file'
|
||||
@@ -51,14 +50,6 @@ function isMcpToolPart(part: CherryMessagePart): boolean {
|
||||
return tool?.type === 'mcp'
|
||||
}
|
||||
|
||||
function hasExternalEditorPathHint(part: CherryMessagePart): boolean {
|
||||
const partType = (part as { type?: string }).type
|
||||
if (partType === 'dynamic-tool' || !!partType?.startsWith('tool-')) return true
|
||||
if (partType !== 'text') return false
|
||||
|
||||
return containsInlineFilePath((part as { text?: string }).text)
|
||||
}
|
||||
|
||||
function fileMetadataToHandle(file: FileMetadata): FileHandle {
|
||||
if (file.path) {
|
||||
try {
|
||||
@@ -111,12 +102,8 @@ export function useMessageLeafCapabilities({
|
||||
() => Object.values(partsByMessageId).some((parts) => parts.some(isMcpToolPart)),
|
||||
[partsByMessageId]
|
||||
)
|
||||
const hasExternalEditorPathHints = useMemo(
|
||||
() => Object.values(partsByMessageId).some((parts) => parts.some(hasExternalEditorPathHint)),
|
||||
[partsByMessageId]
|
||||
)
|
||||
const { data: mcpServersData } = useQuery('/mcp-servers', { enabled: hasMcpToolParts })
|
||||
const { data: externalApps } = useExternalApps({ enabled: hasExternalEditorPathHints })
|
||||
const { data: externalApps } = useExternalApps()
|
||||
const mcpServers = useMemo(() => mcpServersData?.items ?? [], [mcpServersData])
|
||||
const externalCodeEditors = useMemo(
|
||||
() => externalApps?.filter((app) => app.tags.includes('code-editor')) ?? [],
|
||||
|
||||
@@ -35,6 +35,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
|
||||
const userName = renderConfig.userName
|
||||
const assistantProfile = meta.assistantProfile
|
||||
const avatar = meta.userProfile?.avatar ?? ''
|
||||
const { updateMessageUiState } = actions
|
||||
const { setTimeoutTimer } = useTimer()
|
||||
|
||||
const messagesListRef = useRef<HTMLDivElement>(null)
|
||||
@@ -101,7 +102,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
|
||||
const groupMessages = messages.filter((m) => m.parentId === message.parentId)
|
||||
if (groupMessages.length > 1) {
|
||||
for (const m of groupMessages) {
|
||||
actions.updateMessageUiState?.(m.id, { foldSelected: m.id === message.id })
|
||||
updateMessageUiState?.(m.id, { foldSelected: m.id === message.id })
|
||||
}
|
||||
|
||||
setTimeoutTimer(
|
||||
@@ -116,7 +117,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
|
||||
)
|
||||
}
|
||||
},
|
||||
[actions.updateMessageUiState, messages, setTimeoutTimer]
|
||||
[messages, setTimeoutTimer, updateMessageUiState]
|
||||
)
|
||||
|
||||
const scrollToMessage = useCallback(
|
||||
@@ -125,7 +126,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
|
||||
const siblings = messages.filter((m) => m.role === 'assistant' && m.parentId === message.parentId)
|
||||
if (siblings.length > 1) {
|
||||
for (const sibling of siblings) {
|
||||
actions.updateMessageUiState?.(sibling.id, { foldSelected: sibling.id === message.id })
|
||||
updateMessageUiState?.(sibling.id, { foldSelected: sibling.id === message.id })
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,7 +147,7 @@ const MessageAnchorLine: FC<MessageLineProps> = ({
|
||||
}
|
||||
scrollIntoView(messageElement, { behavior: 'smooth', block: 'start', container: 'nearest' })
|
||||
},
|
||||
[actions, messages, scrollToMessageId, setSelectedMessage]
|
||||
[messages, scrollToMessageId, setSelectedMessage, updateMessageUiState]
|
||||
)
|
||||
|
||||
const scrollToBottom = useCallback(() => {
|
||||
@@ -328,14 +329,14 @@ const MessageLineContainer = ({
|
||||
<div
|
||||
ref={ref}
|
||||
className={[
|
||||
'group fixed right-[13px] z-999 flex w-[14px] translate-y-[-50%] select-none items-center justify-end overflow-hidden text-[5px] hover:w-[500px] hover:overflow-y-hidden hover:overflow-x-visible',
|
||||
'group absolute right-3.25 z-20 flex w-3.5 translate-y-[-50%] select-none items-center justify-end overflow-hidden text-[5px] hover:w-125 hover:overflow-y-hidden hover:overflow-x-visible',
|
||||
className
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(' ')}
|
||||
style={{
|
||||
top: 'calc(50% - var(--status-bar-height) - 10px)',
|
||||
maxHeight: $height ? `${$height - 20}px` : 'calc(100% - var(--status-bar-height) * 2 - 20px)',
|
||||
top: '50%',
|
||||
maxHeight: $height ? `${$height - 20}px` : 'calc(100% - 20px)',
|
||||
...style
|
||||
}}
|
||||
{...props}
|
||||
|
||||
@@ -23,13 +23,21 @@ interface MessageNavigationProps {
|
||||
containerId: string
|
||||
messages: MessageListItem[]
|
||||
scrollToMessageId: (messageId: string) => void
|
||||
scrollToTop: () => void
|
||||
scrollToBottom: () => void
|
||||
}
|
||||
|
||||
const getScrollContainer = (container: HTMLElement | null): HTMLElement | null => {
|
||||
return container?.querySelector<HTMLElement>('[data-message-virtual-list-scroller]') ?? container
|
||||
}
|
||||
|
||||
const MessageNavigation: FC<MessageNavigationProps> = ({ containerId, messages, scrollToMessageId }) => {
|
||||
const MessageNavigation: FC<MessageNavigationProps> = ({
|
||||
containerId,
|
||||
messages,
|
||||
scrollToMessageId,
|
||||
scrollToTop,
|
||||
scrollToBottom
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [isVisible, setIsVisible] = useState(false)
|
||||
const timerKey = 'hide'
|
||||
@@ -73,16 +81,6 @@ const MessageNavigation: FC<MessageNavigationProps> = ({ containerId, messages,
|
||||
scheduleHide(500)
|
||||
}, [scheduleHide])
|
||||
|
||||
const scrollToTop = () => {
|
||||
const scrollContainer = getScrollContainer(document.getElementById(containerId))
|
||||
scrollContainer?.scrollTo({ top: 0, behavior: 'smooth' })
|
||||
}
|
||||
|
||||
const scrollToBottom = () => {
|
||||
const scrollContainer = getScrollContainer(document.getElementById(containerId))
|
||||
scrollContainer?.scrollTo({ top: scrollContainer.scrollHeight, behavior: 'smooth' })
|
||||
}
|
||||
|
||||
const getCurrentVisibleIndex = (direction: 'up' | 'down') => {
|
||||
const userMessages = messages.filter((message) => message.role === 'user')
|
||||
const assistantMessages = messages.filter((message) => message.role === 'assistant')
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
// @vitest-environment jsdom
|
||||
import '@testing-library/jest-dom/vitest'
|
||||
|
||||
import { render } from '@testing-library/react'
|
||||
import type { ReactNode } from 'react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type * as MessageTypes from '../../types'
|
||||
import type { MessageListItem } from '../../types'
|
||||
import MessageAnchorLine from '../MessageAnchorLine'
|
||||
|
||||
vi.mock('@cherrystudio/ui', () => ({
|
||||
Avatar: ({ children, ...props }: { children?: ReactNode }) => <div {...props}>{children}</div>,
|
||||
AvatarFallback: ({ children }: { children?: ReactNode }) => <span>{children}</span>,
|
||||
AvatarImage: () => null,
|
||||
EmojiAvatar: ({ children, ...props }: { children?: ReactNode }) => <span {...props}>{children}</span>
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useTheme', () => ({
|
||||
useTheme: () => ({ theme: 'light' })
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useTimer', () => ({
|
||||
useTimer: () => ({
|
||||
setTimeoutTimer: (_key: string, callback: () => void) => callback()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/model', () => ({
|
||||
getModelLogo: () => null
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/naming', () => ({
|
||||
firstLetter: (value: string) => value[0] ?? '',
|
||||
isEmoji: () => false,
|
||||
removeLeadingEmoji: (value: string) => value
|
||||
}))
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({ t: (key: string) => key })
|
||||
}))
|
||||
|
||||
vi.mock('../../blocks', () => ({
|
||||
usePartsMap: () => ({})
|
||||
}))
|
||||
|
||||
vi.mock('../../MessageListProvider', async () => {
|
||||
const { defaultMessageRenderConfig } = await vi.importActual<typeof MessageTypes>('../../types')
|
||||
|
||||
return {
|
||||
useMessageListActions: () => ({}),
|
||||
useMessageListMeta: () => ({}),
|
||||
useMessageRenderConfig: () => defaultMessageRenderConfig
|
||||
}
|
||||
})
|
||||
|
||||
const messages: MessageListItem[] = [
|
||||
{
|
||||
id: 'message-1',
|
||||
role: 'user',
|
||||
topicId: 'topic-1',
|
||||
parentId: null,
|
||||
createdAt: '2026-07-02T00:00:00.000Z',
|
||||
status: 'success'
|
||||
}
|
||||
]
|
||||
|
||||
describe('MessageAnchorLine', () => {
|
||||
it('keeps the anchor rail scoped inside the message list layer', () => {
|
||||
const { container } = render(<MessageAnchorLine messages={messages} />)
|
||||
|
||||
const anchorRail = container.firstElementChild
|
||||
expect(anchorRail).toHaveClass('absolute', 'z-20')
|
||||
expect(anchorRail).not.toHaveClass('fixed', 'z-999')
|
||||
})
|
||||
})
|
||||
@@ -50,6 +50,46 @@ const setRect = (element: Element, rect: Partial<DOMRect>) => {
|
||||
}))
|
||||
}
|
||||
|
||||
const renderNavigation = (messages: MessageListItem[], visibleMessageIds: string[] = []) => {
|
||||
const scrollToMessageId = vi.fn()
|
||||
const scrollToTop = vi.fn()
|
||||
const scrollToBottom = vi.fn()
|
||||
|
||||
const { container } = render(
|
||||
<>
|
||||
<div id="messages">
|
||||
<div data-message-virtual-list-scroller>
|
||||
{messages.map((message) => (
|
||||
<div key={message.id} id={`message-${message.id}`} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<MessageNavigation
|
||||
containerId="messages"
|
||||
messages={messages}
|
||||
scrollToMessageId={scrollToMessageId}
|
||||
scrollToTop={scrollToTop}
|
||||
scrollToBottom={scrollToBottom}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
|
||||
setRect(container.querySelector('[data-message-virtual-list-scroller]') as HTMLElement, {
|
||||
bottom: 500,
|
||||
height: 500,
|
||||
top: 0
|
||||
})
|
||||
for (const message of messages) {
|
||||
setRect(document.getElementById(`message-${message.id}`) as HTMLElement, {
|
||||
bottom: visibleMessageIds.includes(message.id) ? 220 : -100,
|
||||
height: 100,
|
||||
top: visibleMessageIds.includes(message.id) ? 120 : -200
|
||||
})
|
||||
}
|
||||
|
||||
return { scrollToBottom, scrollToMessageId, scrollToTop }
|
||||
}
|
||||
|
||||
describe('MessageNavigation', () => {
|
||||
it('scrolls to message ids from the full message list, not only rendered DOM nodes', () => {
|
||||
const scrollToMessageId = vi.fn()
|
||||
@@ -68,7 +108,13 @@ describe('MessageNavigation', () => {
|
||||
<div id="message-user-2" />
|
||||
</div>
|
||||
</div>
|
||||
<MessageNavigation containerId="messages" messages={messages} scrollToMessageId={scrollToMessageId} />
|
||||
<MessageNavigation
|
||||
containerId="messages"
|
||||
messages={messages}
|
||||
scrollToMessageId={scrollToMessageId}
|
||||
scrollToTop={vi.fn()}
|
||||
scrollToBottom={vi.fn()}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
|
||||
@@ -87,4 +133,78 @@ describe('MessageNavigation', () => {
|
||||
|
||||
expect(scrollToMessageId).toHaveBeenCalledWith('user-3')
|
||||
})
|
||||
|
||||
it('delegates the top and bottom buttons to the runtime scroll callbacks', () => {
|
||||
const scrollToTop = vi.fn()
|
||||
const scrollToBottom = vi.fn()
|
||||
|
||||
render(
|
||||
<>
|
||||
<div id="messages">
|
||||
<div data-message-virtual-list-scroller />
|
||||
</div>
|
||||
<MessageNavigation
|
||||
containerId="messages"
|
||||
messages={[createMessage('user-1', 'user')]}
|
||||
scrollToMessageId={vi.fn()}
|
||||
scrollToTop={scrollToTop}
|
||||
scrollToBottom={scrollToBottom}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.top' }))
|
||||
expect(scrollToTop).toHaveBeenCalledTimes(1)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.bottom' }))
|
||||
expect(scrollToBottom).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it.each([
|
||||
{
|
||||
name: 'when there are no messages',
|
||||
messages: []
|
||||
},
|
||||
{
|
||||
name: 'when no message is visible',
|
||||
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')]
|
||||
},
|
||||
{
|
||||
name: 'when the first user message is already visible',
|
||||
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')],
|
||||
visibleMessageIds: ['user-1']
|
||||
}
|
||||
])('delegates next-message fallback to runtime scrollToBottom $name', ({ messages, visibleMessageIds }) => {
|
||||
const { scrollToBottom, scrollToMessageId, scrollToTop } = renderNavigation(messages, visibleMessageIds)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.next' }))
|
||||
|
||||
expect(scrollToBottom).toHaveBeenCalledTimes(1)
|
||||
expect(scrollToTop).not.toHaveBeenCalled()
|
||||
expect(scrollToMessageId).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it.each([
|
||||
{
|
||||
name: 'when there are no messages',
|
||||
messages: []
|
||||
},
|
||||
{
|
||||
name: 'when no message is visible',
|
||||
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')]
|
||||
},
|
||||
{
|
||||
name: 'when the last user message is already visible',
|
||||
messages: [createMessage('user-1', 'user'), createMessage('user-2', 'user')],
|
||||
visibleMessageIds: ['user-2']
|
||||
}
|
||||
])('delegates prev-message fallback to runtime scrollToTop $name', ({ messages, visibleMessageIds }) => {
|
||||
const { scrollToBottom, scrollToMessageId, scrollToTop } = renderNavigation(messages, visibleMessageIds)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'chat.navigation.prev' }))
|
||||
|
||||
expect(scrollToTop).toHaveBeenCalledTimes(1)
|
||||
expect(scrollToBottom).not.toHaveBeenCalled()
|
||||
expect(scrollToMessageId).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -443,6 +443,202 @@ describe('useChatVirtualizerRuntime', () => {
|
||||
expect(runtime!.isScrollToBottomButtonVisible).toBe(false)
|
||||
})
|
||||
|
||||
it('scrolls to top instantly and releases the top anchor', () => {
|
||||
const callbacks: ResizeObserverCallback[] = []
|
||||
const restoreResizeObserver = installResizeObserverMock(callbacks)
|
||||
const raf = installQueuedAnimationFrame()
|
||||
|
||||
try {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
let handle: MessageVirtualListHandle | null = null
|
||||
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
|
||||
handle = nextHandle
|
||||
}
|
||||
let scrollTop = 0
|
||||
const view = render(
|
||||
<RuntimeDomProbe
|
||||
items={['message-a']}
|
||||
handleRef={handleRef}
|
||||
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
|
||||
/>
|
||||
)
|
||||
const scroller = runtime!.scrollerRef.current!
|
||||
Object.defineProperty(scroller, 'scrollTop', {
|
||||
configurable: true,
|
||||
get: () => scrollTop,
|
||||
set: (value) => {
|
||||
scrollTop = value
|
||||
}
|
||||
})
|
||||
Object.defineProperty(scroller, 'scrollHeight', { configurable: true, get: () => 700 })
|
||||
Object.defineProperty(scroller, 'clientHeight', { configurable: true, get: () => 400 })
|
||||
runtime!.vlistHandleRef.current = createHandle({ getItemOffset: vi.fn(() => 300) })
|
||||
|
||||
view.rerender(
|
||||
<RuntimeDomProbe
|
||||
items={['message-a']}
|
||||
handleRef={handleRef}
|
||||
preserveScrollAnchor
|
||||
scrollToTopKey="message-a"
|
||||
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
|
||||
/>
|
||||
)
|
||||
raf.tick()
|
||||
expect(runtime!.wrappedItems.some((item) => item.kind === 'spacer')).toBe(true)
|
||||
|
||||
act(() => {
|
||||
scrollTop = 300
|
||||
handle!.scrollToTop('instant')
|
||||
})
|
||||
expect(scrollTop).toBe(0)
|
||||
|
||||
act(() => callbacks[0]?.([], {} as ResizeObserver))
|
||||
|
||||
expect(scrollTop).toBe(0)
|
||||
} finally {
|
||||
restoreResizeObserver()
|
||||
raf.restore()
|
||||
}
|
||||
})
|
||||
|
||||
it('scrolls to top smoothly with the RAF-driven runtime scroller', () => {
|
||||
const raf = installQueuedAnimationFrame()
|
||||
|
||||
try {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
let handle: MessageVirtualListHandle | null = null
|
||||
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
|
||||
handle = nextHandle
|
||||
}
|
||||
let scrollTop = 500
|
||||
render(
|
||||
<RuntimeDomProbe
|
||||
items={['message-a']}
|
||||
handleRef={handleRef}
|
||||
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
|
||||
/>
|
||||
)
|
||||
const scroller = runtime!.scrollerRef.current!
|
||||
Object.defineProperty(scroller, 'scrollTop', {
|
||||
configurable: true,
|
||||
get: () => scrollTop,
|
||||
set: (value) => {
|
||||
scrollTop = value
|
||||
}
|
||||
})
|
||||
|
||||
act(() => {
|
||||
handle!.scrollToTop('smooth')
|
||||
})
|
||||
expect(scrollTop).toBe(500)
|
||||
|
||||
raf.tick()
|
||||
expect(scrollTop).toBeGreaterThan(0)
|
||||
expect(scrollTop).toBeLessThan(500)
|
||||
|
||||
raf.tick(50)
|
||||
expect(scrollTop).toBe(0)
|
||||
} finally {
|
||||
raf.restore()
|
||||
}
|
||||
})
|
||||
|
||||
it('replaces an in-flight smooth scroll when scrolling to top', () => {
|
||||
const raf = installQueuedAnimationFrame()
|
||||
|
||||
try {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
let handle: MessageVirtualListHandle | null = null
|
||||
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
|
||||
handle = nextHandle
|
||||
}
|
||||
let scrollTop = 0
|
||||
render(
|
||||
<RuntimeDomProbe
|
||||
items={['message-a']}
|
||||
handleRef={handleRef}
|
||||
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
|
||||
/>
|
||||
)
|
||||
const scroller = runtime!.scrollerRef.current!
|
||||
Object.defineProperty(scroller, 'scrollTop', {
|
||||
configurable: true,
|
||||
get: () => scrollTop,
|
||||
set: (value) => {
|
||||
scrollTop = value
|
||||
}
|
||||
})
|
||||
setElementMetric(scroller, 'scrollHeight', () => 1200)
|
||||
setElementMetric(scroller, 'clientHeight', () => 400)
|
||||
|
||||
act(() => {
|
||||
handle!.scrollToBottom('smooth')
|
||||
})
|
||||
raf.tick()
|
||||
const bottomScrollProgress = scrollTop
|
||||
expect(bottomScrollProgress).toBeGreaterThan(0)
|
||||
|
||||
act(() => {
|
||||
handle!.scrollToTop('smooth')
|
||||
})
|
||||
raf.tick()
|
||||
expect(scrollTop).toBeLessThan(bottomScrollProgress)
|
||||
|
||||
raf.tick(50)
|
||||
expect(scrollTop).toBe(0)
|
||||
} finally {
|
||||
raf.restore()
|
||||
}
|
||||
})
|
||||
|
||||
it('replaces an in-flight smooth scroll when scrolling to bottom', () => {
|
||||
const raf = installQueuedAnimationFrame()
|
||||
|
||||
try {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
let handle: MessageVirtualListHandle | null = null
|
||||
const handleRef: Ref<MessageVirtualListHandle> = (nextHandle) => {
|
||||
handle = nextHandle
|
||||
}
|
||||
let scrollTop = 800
|
||||
render(
|
||||
<RuntimeDomProbe
|
||||
items={['message-a']}
|
||||
handleRef={handleRef}
|
||||
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
|
||||
/>
|
||||
)
|
||||
const scroller = runtime!.scrollerRef.current!
|
||||
Object.defineProperty(scroller, 'scrollTop', {
|
||||
configurable: true,
|
||||
get: () => scrollTop,
|
||||
set: (value) => {
|
||||
scrollTop = value
|
||||
}
|
||||
})
|
||||
setElementMetric(scroller, 'scrollHeight', () => 1200)
|
||||
setElementMetric(scroller, 'clientHeight', () => 400)
|
||||
|
||||
act(() => {
|
||||
handle!.scrollToTop('smooth')
|
||||
})
|
||||
raf.tick()
|
||||
const topScrollProgress = scrollTop
|
||||
expect(topScrollProgress).toBeLessThan(800)
|
||||
|
||||
act(() => {
|
||||
handle!.scrollToBottom('smooth')
|
||||
})
|
||||
raf.tick()
|
||||
expect(scrollTop).toBeGreaterThan(topScrollProgress)
|
||||
|
||||
raf.tick(50)
|
||||
expect(scrollTop).toBe(800)
|
||||
} finally {
|
||||
raf.restore()
|
||||
}
|
||||
})
|
||||
|
||||
it('resets bottom-follow state when pinning a message to the viewport top', () => {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
let handle: MessageVirtualListHandle | null = null
|
||||
@@ -479,6 +675,46 @@ describe('useChatVirtualizerRuntime', () => {
|
||||
expect(handle!.isAtBottom()).toBe(false)
|
||||
})
|
||||
|
||||
it('does not pin a follow-up steered into a still-streaming turn', () => {
|
||||
const callbacks: ResizeObserverCallback[] = []
|
||||
const restoreResizeObserver = installResizeObserverMock(callbacks)
|
||||
const raf = installQueuedAnimationFrame()
|
||||
|
||||
try {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
// Render 1: a turn is already streaming (preserveScrollAnchor is true) and no
|
||||
// new user-message key has arrived yet.
|
||||
const view = render(
|
||||
<RuntimeDomProbe items={['user-a']} preserveScrollAnchor onRuntime={(nextRuntime) => (runtime = nextRuntime)} />
|
||||
)
|
||||
const getSpacerHeight = () => runtime!.wrappedItems.find((item) => item.kind === 'spacer')?.height ?? 0
|
||||
const scroller = runtime!.scrollerRef.current!
|
||||
Object.defineProperty(scroller, 'scrollTop', { configurable: true, get: () => 0 })
|
||||
Object.defineProperty(scroller, 'scrollHeight', { configurable: true, get: () => 900 + getSpacerHeight() })
|
||||
Object.defineProperty(scroller, 'clientHeight', { configurable: true, get: () => 400 })
|
||||
runtime!.vlistHandleRef.current = createHandle({ getItemOffset: vi.fn(() => 300) })
|
||||
|
||||
// Render 2: a queued follow-up is steered into the live turn — a new user
|
||||
// message (`user-b`) arrives while streaming continues. Because a turn was
|
||||
// already streaming just before it, the message must NOT pin to the top, so
|
||||
// no anchor spacer is created (the pin path is what created the instability).
|
||||
view.rerender(
|
||||
<RuntimeDomProbe
|
||||
items={['user-a', 'user-b']}
|
||||
preserveScrollAnchor
|
||||
scrollToTopKey="user-b"
|
||||
onRuntime={(nextRuntime) => (runtime = nextRuntime)}
|
||||
/>
|
||||
)
|
||||
raf.tick()
|
||||
|
||||
expect(getSpacerHeight()).toBe(0)
|
||||
} finally {
|
||||
restoreResizeObserver()
|
||||
raf.restore()
|
||||
}
|
||||
})
|
||||
|
||||
it('keeps bottom-follow suppressed while the user is still pinned to the top', () => {
|
||||
let runtime: ChatVirtualizerRuntime<string> | undefined
|
||||
let handle: MessageVirtualListHandle | null = null
|
||||
|
||||
@@ -36,6 +36,7 @@ import { useSmoothScrollAnimation } from './useSmoothScrollAnimation'
|
||||
|
||||
export interface MessageVirtualListHandle {
|
||||
scrollToBottom(behavior?: ScrollBehavior): void
|
||||
scrollToTop(behavior?: ScrollBehavior): void
|
||||
scrollToKey(key: string, align?: 'start' | 'center' | 'end'): void
|
||||
isAtBottom(): boolean
|
||||
getScrollElement(): HTMLElement | null
|
||||
@@ -402,6 +403,11 @@ export function useChatVirtualizerRuntime<T>({
|
||||
|
||||
const lastScrollToTopKeyRef = useRef<string | undefined>(undefined)
|
||||
const didMountForScrollKeyRef = useRef(false)
|
||||
// The committed `preserveScrollAnchor` from the previous render — i.e. whether a
|
||||
// turn was already streaming just before the current commit. Lets the pin effect
|
||||
// tell a fresh idle→new-turn send from a mid-stream insertion. A trailing effect
|
||||
// (below) keeps it in sync AFTER the pin effect has read the prior value.
|
||||
const wasStreamingBeforeUserMessageRef = useRef(preserveScrollAnchor)
|
||||
|
||||
useEffect(() => {
|
||||
const previous = lastScrollToTopKeyRef.current
|
||||
@@ -411,6 +417,12 @@ export function useChatVirtualizerRuntime<T>({
|
||||
return
|
||||
}
|
||||
if (!scrollToTopKey || scrollToTopKey === previous) return
|
||||
// A new user message appeared. Only pin it to the top when it STARTS a fresh
|
||||
// turn (the topic was idle just before it). If a turn was already streaming —
|
||||
// a queued follow-up steered into the live turn — pinning the new message to
|
||||
// the top would yank the view and fight the previous assistant's still-growing
|
||||
// response (the instability we're fixing). Leave scroll to bottom-follow.
|
||||
if (wasStreamingBeforeUserMessageRef.current) return
|
||||
const idx = findDataIndexByKey(scrollToTopKey)
|
||||
if (idx < 0) return
|
||||
anchor.pinTo(idx)
|
||||
@@ -420,6 +432,13 @@ export function useChatVirtualizerRuntime<T>({
|
||||
userTookControlRef.current = false
|
||||
}, [anchor, atBottom, findDataIndexByKey, scrollToTopKey])
|
||||
|
||||
// Sync the "was a turn already streaming" marker AFTER the pin effect above has
|
||||
// read the previous render's value. Runs every commit so the next new-user-
|
||||
// message commit sees whether streaming was in progress when it arrived.
|
||||
useEffect(() => {
|
||||
wasStreamingBeforeUserMessageRef.current = preserveScrollAnchor
|
||||
})
|
||||
|
||||
// Initial scroll on mount is owned by `useScrollPositionMemory` above: it
|
||||
// restores the saved anchor for this topic, or scrolls to the newest message
|
||||
// when there is nothing to restore.
|
||||
@@ -566,12 +585,10 @@ export function useChatVirtualizerRuntime<T>({
|
||||
if (!el) return
|
||||
const target = getRealBottom(el, anchor.spacerHeight)
|
||||
if (behavior === 'smooth') {
|
||||
if (!smoothScroll.isAnimating()) {
|
||||
smoothScroll.scrollTo(() => {
|
||||
const current = scrollerRef.current
|
||||
return current ? getRealBottom(current, bottomFollowInsetRef.current) : 0
|
||||
})
|
||||
}
|
||||
smoothScroll.scrollTo(() => {
|
||||
const current = scrollerRef.current
|
||||
return current ? getRealBottom(current, bottomFollowInsetRef.current) : 0
|
||||
})
|
||||
} else {
|
||||
smoothScroll.cancel()
|
||||
el.scrollTop = target
|
||||
@@ -582,10 +599,31 @@ export function useChatVirtualizerRuntime<T>({
|
||||
[anchor, atBottom, hideScrollToBottomButton, smoothScroll]
|
||||
)
|
||||
|
||||
const scrollToTop = useCallback(
|
||||
(behavior: ScrollBehavior = 'instant') => {
|
||||
// Explicit scroll-to-top releases any anchor pin — the caller wants the
|
||||
// absolute top of the loaded content, not the pinned user-message position.
|
||||
anchor.release()
|
||||
const el = scrollerRef.current
|
||||
if (!el) return
|
||||
if (behavior === 'smooth') {
|
||||
// Drive the scroll frame-by-frame (RAF) rather than native
|
||||
// `behavior: 'smooth'`: virtua remeasures items entering the viewport
|
||||
// and compensates scrollTop, which cancels a native animation mid-flight.
|
||||
smoothScroll.scrollTo(() => 0)
|
||||
} else {
|
||||
smoothScroll.cancel()
|
||||
el.scrollTop = 0
|
||||
}
|
||||
},
|
||||
[anchor, smoothScroll]
|
||||
)
|
||||
|
||||
useImperativeHandle(
|
||||
handleRef,
|
||||
(): MessageVirtualListHandle => ({
|
||||
scrollToBottom,
|
||||
scrollToTop,
|
||||
scrollToKey: (key, align = 'start') => {
|
||||
const handle = vlistHandleRef.current
|
||||
const idx = findDataIndexByKey(key)
|
||||
@@ -596,7 +634,7 @@ export function useChatVirtualizerRuntime<T>({
|
||||
isAtBottom: atBottom.isAtBottom,
|
||||
getScrollElement: () => scrollerRef.current
|
||||
}),
|
||||
[anchor, atBottom.isAtBottom, findDataIndexByKey, scrollToBottom]
|
||||
[anchor, atBottom.isAtBottom, findDataIndexByKey, scrollToBottom, scrollToTop]
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
@@ -43,9 +43,3 @@ export function isInlineFilePath(value: string): boolean {
|
||||
WORKSPACE_RELATIVE_FILE_PATH_PATTERN.test(normalizedPath)
|
||||
)
|
||||
}
|
||||
|
||||
export function containsInlineFilePath(value: string | undefined): boolean {
|
||||
if (!value) return false
|
||||
|
||||
return value.split(/\s+/).some((token) => isInlineFilePath(token))
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { usePreference } from '@data/hooks/usePreference'
|
||||
import { loggerService } from '@logger'
|
||||
import type { ResolvedAction } from '@renderer/components/chat/actions/actionTypes'
|
||||
import {
|
||||
@@ -17,7 +18,7 @@ import { usePins } from '@renderer/hooks/usePins'
|
||||
import { mapApiTopicToRendererTopic, useTopicMutations } from '@renderer/hooks/useTopic'
|
||||
import type { Topic } from '@renderer/types/topic'
|
||||
import { formatErrorMessageWithPrefix } from '@renderer/utils/error'
|
||||
import { Bot, Edit3, PinIcon, PinOffIcon, Plus, Trash2 } from 'lucide-react'
|
||||
import { Bot, Edit3, PinIcon, PinOffIcon, Plus, Tags, Trash2 } from 'lucide-react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
@@ -27,6 +28,7 @@ const logger = loggerService.withContext('AssistantResourceList')
|
||||
|
||||
const ASSISTANT_ENTITY_EDIT_ACTION_ID = 'assistant-entity.edit'
|
||||
const ASSISTANT_ENTITY_TOGGLE_PIN_ACTION_ID = 'assistant-entity.toggle-pin'
|
||||
const ASSISTANT_ENTITY_TOGGLE_TAG_GROUPING_ACTION_ID = 'assistant-entity.toggle-tag-grouping'
|
||||
const ASSISTANT_ENTITY_DELETE_ACTION_ID = 'assistant-entity.delete'
|
||||
|
||||
type AssistantResourceListProps = {
|
||||
@@ -52,6 +54,8 @@ export function AssistantResourceList({
|
||||
onActiveAssistantDeleted
|
||||
}: AssistantResourceListProps) {
|
||||
const { t } = useTranslation()
|
||||
const [assistantSortType, setAssistantSortType] = usePreference('assistant.tab.sort_type')
|
||||
const isTagGrouping = assistantSortType === 'tags'
|
||||
const {
|
||||
assistants,
|
||||
isLoading: isAssistantsLoading,
|
||||
@@ -95,6 +99,7 @@ export function AssistantResourceList({
|
||||
name: assistant.name,
|
||||
orderKey: assistant.orderKey,
|
||||
pinned: assistantPinnedIdSet.has(assistant.id),
|
||||
tag: assistant.tags?.[0]?.name,
|
||||
icon: assistant.emoji ? (
|
||||
<EmojiIcon emoji={assistant.emoji} size={24} fontSize={14} className="mr-0" />
|
||||
) : (
|
||||
@@ -227,6 +232,15 @@ export function AssistantResourceList({
|
||||
availability: { visible: true, enabled: !isAssistantPinActionDisabled },
|
||||
children: []
|
||||
},
|
||||
{
|
||||
id: ASSISTANT_ENTITY_TOGGLE_TAG_GROUPING_ACTION_ID,
|
||||
label: isTagGrouping ? t('assistants.tags.ungroup') : t('assistants.tags.group_by'),
|
||||
icon: <Tags size={14} />,
|
||||
order: 25,
|
||||
danger: false,
|
||||
availability: { visible: true, enabled: true },
|
||||
children: []
|
||||
},
|
||||
{
|
||||
id: ASSISTANT_ENTITY_DELETE_ACTION_ID,
|
||||
label: t('assistants.delete.title'),
|
||||
@@ -239,7 +253,7 @@ export function AssistantResourceList({
|
||||
}
|
||||
]
|
||||
},
|
||||
[assistantPinnedIdSet, deletingAssistantId, isAssistantPinActionDisabled, t]
|
||||
[assistantPinnedIdSet, deletingAssistantId, isAssistantPinActionDisabled, isTagGrouping, t]
|
||||
)
|
||||
|
||||
const handleContextMenuAction = useCallback(
|
||||
@@ -252,11 +266,15 @@ export function AssistantResourceList({
|
||||
void handleToggleAssistantPin(item.id)
|
||||
return
|
||||
}
|
||||
if (action.id === ASSISTANT_ENTITY_TOGGLE_TAG_GROUPING_ACTION_ID) {
|
||||
void setAssistantSortType(isTagGrouping ? 'list' : 'tags')
|
||||
return
|
||||
}
|
||||
if (action.id === ASSISTANT_ENTITY_DELETE_ACTION_ID) {
|
||||
void handleDeleteAssistant(item.id)
|
||||
}
|
||||
},
|
||||
[handleDeleteAssistant, handleToggleAssistantPin, openAssistantEditor]
|
||||
[handleDeleteAssistant, handleToggleAssistantPin, isTagGrouping, openAssistantEditor, setAssistantSortType]
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -268,6 +286,7 @@ export function AssistantResourceList({
|
||||
status={listStatus}
|
||||
ariaLabel={t('assistants.abbr')}
|
||||
defaultGroupLabel={t('assistants.abbr')}
|
||||
groupByTag={isTagGrouping}
|
||||
addIcon={<Plus />}
|
||||
addLabel={t('chat.add.assistant.title')}
|
||||
onAdd={onAddAssistant ?? (() => onStartDraftAssistant(null))}
|
||||
|
||||
@@ -26,6 +26,8 @@ export type ResourceEntityRailItem = {
|
||||
* It does not affect visibility — an entity with no resources stays hidden whether pinned or not.
|
||||
*/
|
||||
pinned?: boolean
|
||||
/** Single user tag name. Only consulted when the rail runs with `groupByTag`; undefined → "未分组". */
|
||||
tag?: string
|
||||
}
|
||||
|
||||
// Pinned entities float into a "已固定" section at the top; the rest sit under the "助手" / "智能体"
|
||||
@@ -36,6 +38,27 @@ const ENTITY_RAIL_PINNED_SECTION_ID = 'resource-entity-rail:section:pinned'
|
||||
const ENTITY_RAIL_DEFAULT_SECTION_ID = 'resource-entity-rail:section:default'
|
||||
const ENTITY_RAIL_PINNED_GROUP_ID = 'resource-entity-rail:group:pinned'
|
||||
const ENTITY_RAIL_DEFAULT_GROUP_ID = 'resource-entity-rail:group:default'
|
||||
// When `groupByTag` is on, each tag name becomes its own collapsible section below the pinned one;
|
||||
// untagged entities collapse together under a distinct internal bucket.
|
||||
const ENTITY_RAIL_TAG_SECTION_PREFIX = 'resource-entity-rail:section:'
|
||||
const ENTITY_RAIL_TAG_GROUP_PREFIX = 'resource-entity-rail:group:'
|
||||
const ENTITY_RAIL_UNTAGGED_KEY = JSON.stringify(['untagged'])
|
||||
|
||||
function getEntityRailTagBucketKey(tag: string | undefined) {
|
||||
return tag ? JSON.stringify(['tag', tag]) : ENTITY_RAIL_UNTAGGED_KEY
|
||||
}
|
||||
|
||||
function getEntityRailTagGroupingRank(item: ResourceEntityRailItem) {
|
||||
if (item.pinned) return 0
|
||||
return item.tag ? 2 : 1
|
||||
}
|
||||
|
||||
function sortEntityRailItemsForTagGrouping<T extends ResourceEntityRailItem>(items: readonly T[]): T[] {
|
||||
return items
|
||||
.map((item, index) => ({ item, index, rank: getEntityRailTagGroupingRank(item) }))
|
||||
.sort((a, b) => a.rank - b.rank || a.index - b.index)
|
||||
.map(({ item }) => item)
|
||||
}
|
||||
|
||||
export type ResourceEntityRailProps<T extends ResourceEntityRailItem, TActionContext = unknown> = {
|
||||
addIcon?: ReactNode
|
||||
@@ -43,6 +66,12 @@ export type ResourceEntityRailProps<T extends ResourceEntityRailItem, TActionCon
|
||||
ariaLabel: string
|
||||
/** Header for the non-pinned group ("助手" for assistants, "智能体" for agents). */
|
||||
defaultGroupLabel?: string
|
||||
/**
|
||||
* Group the non-pinned entities by their `tag` into collapsible sections (the pinned section stays
|
||||
* on top). Drag-reorder is disabled while on, since `orderKey` is a single flat order. Off → the
|
||||
* flat "助手"/"智能体" section.
|
||||
*/
|
||||
groupByTag?: boolean
|
||||
emptyFallback?: ReactNode
|
||||
getContextMenuActions?: (item: T) => readonly ResolvedAction<TActionContext>[]
|
||||
listRef?: RefObject<HTMLDivElement | null>
|
||||
@@ -82,6 +111,7 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
|
||||
addLabel,
|
||||
ariaLabel,
|
||||
defaultGroupLabel,
|
||||
groupByTag = false,
|
||||
emptyFallback,
|
||||
getContextMenuActions,
|
||||
listRef,
|
||||
@@ -96,6 +126,9 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
|
||||
items
|
||||
}: ResourceEntityRailProps<T, TActionContext>) {
|
||||
const { t } = useTranslation()
|
||||
// Tag grouping splits the flat order across sections, so dragging an item between tags would have
|
||||
// no meaningful `orderKey` target — disable reorder entirely while grouping by tag.
|
||||
const reorderEnabled = !!onReorder && !groupByTag
|
||||
const fallbackListRef = useRef<HTMLDivElement>(null)
|
||||
const effectiveListRef = listRef ?? fallbackListRef
|
||||
const runContextMenuAction = useCallback(
|
||||
@@ -173,23 +206,39 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
|
||||
[getContextMenuActions, onContextMenuAction, onSelect, runContextMenuAction, t]
|
||||
)
|
||||
const empty = useMemo(() => emptyFallback ?? <div className="min-h-0 flex-1" />, [emptyFallback])
|
||||
const providerItems = useMemo(
|
||||
() => (groupByTag ? sortEntityRailItemsForTagGrouping(items) : items),
|
||||
[groupByTag, items]
|
||||
)
|
||||
// Collapsible sections matching the modern layout's left assistant/agent layout (minus the nested
|
||||
// topics/sessions): pinned entities float into "已固定" at the top, the rest sit under the
|
||||
// "助手" / "智能体" section below. Section headers stay flush-left; the entity rows keep their
|
||||
// avatar and read as indented beneath. The single-section case (nothing pinned) renders the flat
|
||||
// list with no header, exactly like the modern layout.
|
||||
const sectionBy = useMemo<(item: T) => ResourceListSection>(
|
||||
() => (item) =>
|
||||
item.pinned
|
||||
? { id: ENTITY_RAIL_PINNED_SECTION_ID, label: t('selector.common.pinned_title') }
|
||||
: { id: ENTITY_RAIL_DEFAULT_SECTION_ID, label: defaultGroupLabel ?? '' },
|
||||
[defaultGroupLabel, t]
|
||||
() => (item) => {
|
||||
if (item.pinned) return { id: ENTITY_RAIL_PINNED_SECTION_ID, label: t('selector.common.pinned_title') }
|
||||
if (groupByTag) {
|
||||
const tagBucketKey = getEntityRailTagBucketKey(item.tag)
|
||||
return item.tag
|
||||
? { id: `${ENTITY_RAIL_TAG_SECTION_PREFIX}${tagBucketKey}`, label: item.tag }
|
||||
: { id: `${ENTITY_RAIL_TAG_SECTION_PREFIX}${tagBucketKey}`, label: t('assistants.tags.untagged') }
|
||||
}
|
||||
return { id: ENTITY_RAIL_DEFAULT_SECTION_ID, label: defaultGroupLabel ?? '' }
|
||||
},
|
||||
[defaultGroupLabel, groupByTag, t]
|
||||
)
|
||||
// Header-less groups (one per section, distinct ids) keep entity avatars visible and stop
|
||||
// drag-reorder from crossing the pinned/non-pinned boundary.
|
||||
// drag-reorder from crossing the pinned/non-pinned (or per-tag) boundary.
|
||||
const groupBy = useMemo<(item: T) => ResourceListGroup>(
|
||||
() => (item) => ({ id: item.pinned ? ENTITY_RAIL_PINNED_GROUP_ID : ENTITY_RAIL_DEFAULT_GROUP_ID, label: '' }),
|
||||
[]
|
||||
() => (item) => {
|
||||
if (item.pinned) return { id: ENTITY_RAIL_PINNED_GROUP_ID, label: '' }
|
||||
if (groupByTag) {
|
||||
return { id: `${ENTITY_RAIL_TAG_GROUP_PREFIX}${getEntityRailTagBucketKey(item.tag)}`, label: '' }
|
||||
}
|
||||
return { id: ENTITY_RAIL_DEFAULT_GROUP_ID, label: '' }
|
||||
},
|
||||
[groupByTag]
|
||||
)
|
||||
|
||||
// Alias the compound provider to a local before rendering — same pattern as TopicResourceList/SessionResourceList.
|
||||
@@ -200,7 +249,7 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
|
||||
return (
|
||||
<Provider
|
||||
variant={variant}
|
||||
items={items}
|
||||
items={providerItems}
|
||||
selectedId={selectedId}
|
||||
status={status}
|
||||
groupBy={groupBy}
|
||||
@@ -208,15 +257,15 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
|
||||
defaultGroupVisibleCount={Number.POSITIVE_INFINITY}
|
||||
dragCapabilities={{
|
||||
groups: false,
|
||||
items: !!onReorder,
|
||||
itemSameGroup: !!onReorder,
|
||||
items: reorderEnabled,
|
||||
itemSameGroup: reorderEnabled,
|
||||
itemCrossGroup: false
|
||||
}}
|
||||
canDragItem={({ item }) => !!onReorder && !item.pinned}
|
||||
canDragItem={({ item }) => reorderEnabled && !item.pinned}
|
||||
canDropItem={({ activeItem, targetGroupId }) =>
|
||||
!!onReorder && !activeItem.pinned && targetGroupId !== ENTITY_RAIL_PINNED_GROUP_ID
|
||||
reorderEnabled && !activeItem.pinned && targetGroupId !== ENTITY_RAIL_PINNED_GROUP_ID
|
||||
}
|
||||
onReorder={onReorder}>
|
||||
onReorder={reorderEnabled ? onReorder : undefined}>
|
||||
<ResourceList.Frame className="h-full min-h-0" data-testid={`${variant}-entity-rail`}>
|
||||
<ResourceList.Header className="gap-1">
|
||||
<ResourceList.HeaderItem
|
||||
@@ -241,7 +290,7 @@ export function ResourceEntityRail<T extends ResourceEntityRailItem, TActionCont
|
||||
</ResourceList.Header>
|
||||
<ResourceList.Body<T>
|
||||
listRef={effectiveListRef}
|
||||
draggable={!!onReorder}
|
||||
draggable={reorderEnabled}
|
||||
ariaLabel={ariaLabel}
|
||||
virtualClassName="pt-1 pb-3"
|
||||
errorFallback={<ResourceList.ErrorState message={t('error.boundary.default.message')} />}
|
||||
|
||||
@@ -17,12 +17,21 @@ const agentDataMocks = vi.hoisted(() => ({
|
||||
refetchAgents: vi.fn()
|
||||
}))
|
||||
|
||||
const preferenceMocks = vi.hoisted(() => ({
|
||||
sortType: 'list' as 'list' | 'tags',
|
||||
setSortType: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@data/hooks/usePreference', () => ({
|
||||
usePreference: () => [preferenceMocks.sortType, preferenceMocks.setSortType]
|
||||
}))
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
@@ -192,6 +201,8 @@ vi.mock('@renderer/utils/error', () => ({
|
||||
|
||||
describe('classic layout entity resource list actions', () => {
|
||||
beforeEach(() => {
|
||||
preferenceMocks.sortType = 'list'
|
||||
preferenceMocks.setSortType.mockClear()
|
||||
assistantDataMocks.deleteAssistant.mockResolvedValue(undefined)
|
||||
assistantDataMocks.refreshTopics.mockResolvedValue(undefined)
|
||||
assistantDataMocks.refetchAssistants.mockResolvedValue(undefined)
|
||||
@@ -239,6 +250,33 @@ describe('classic layout entity resource list actions', () => {
|
||||
expect(onStartDraftAssistant).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('toggles assistant tag grouping from the context menu (list → tags)', () => {
|
||||
render(
|
||||
<AssistantResourceList activeAssistantId="assistant-1" onSelectTopic={vi.fn()} onStartDraftAssistant={vi.fn()} />
|
||||
)
|
||||
|
||||
// sort_type === 'list' → the menu offers "group by tag".
|
||||
const menu = screen.getByTestId('assistant-1-context-menu')
|
||||
expect(menu).toHaveTextContent('assistants.tags.group_by')
|
||||
expect(menu).not.toHaveTextContent('assistants.tags.ungroup')
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'assistants.tags.group_by' })[0])
|
||||
expect(preferenceMocks.setSortType).toHaveBeenCalledWith('tags')
|
||||
})
|
||||
|
||||
it('offers turning tag grouping off when already grouping (tags → list)', () => {
|
||||
preferenceMocks.sortType = 'tags'
|
||||
|
||||
render(
|
||||
<AssistantResourceList activeAssistantId="assistant-1" onSelectTopic={vi.fn()} onStartDraftAssistant={vi.fn()} />
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('assistant-1-context-menu')).toHaveTextContent('assistants.tags.ungroup')
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'assistants.tags.ungroup' })[0])
|
||||
expect(preferenceMocks.setSortType).toHaveBeenCalledWith('list')
|
||||
})
|
||||
|
||||
it('uses delete-agent actions for the classic layout agent context and more menus', async () => {
|
||||
const onStartMissingAgentDraft = vi.fn()
|
||||
const onActiveAgentDeleted = vi.fn()
|
||||
|
||||
@@ -65,7 +65,8 @@ vi.mock('@renderer/components/VirtualList', () => {
|
||||
return rows
|
||||
}
|
||||
|
||||
const GroupedVirtualList = ({
|
||||
const GroupedVirtualListContent = ({
|
||||
dragEnabled,
|
||||
ref,
|
||||
className,
|
||||
groups,
|
||||
@@ -91,6 +92,7 @@ vi.mock('@renderer/components/VirtualList', () => {
|
||||
}}
|
||||
role={role}
|
||||
className={className}
|
||||
data-draggable={dragEnabled ? 'true' : 'false'}
|
||||
{...scrollerProps}>
|
||||
{rows.map((row, index) => {
|
||||
if (row.type === 'group-header') {
|
||||
@@ -114,8 +116,8 @@ vi.mock('@renderer/components/VirtualList', () => {
|
||||
return {
|
||||
buildGroupedVirtualRows,
|
||||
DynamicVirtualList: () => null,
|
||||
GroupedSortableVirtualList: GroupedVirtualList,
|
||||
GroupedVirtualList
|
||||
GroupedSortableVirtualList: (props) => <GroupedVirtualListContent {...props} dragEnabled />,
|
||||
GroupedVirtualList: (props) => <GroupedVirtualListContent {...props} dragEnabled={false} />
|
||||
}
|
||||
})
|
||||
|
||||
@@ -261,6 +263,88 @@ describe('ResourceEntityRail', () => {
|
||||
expect(screen.getByTestId('assistant-a-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('groups non-pinned entities into per-tag sections while keeping pinned on top', () => {
|
||||
render(
|
||||
<ResourceEntityRail
|
||||
addLabel="New"
|
||||
ariaLabel="Assistants list"
|
||||
defaultGroupLabel="Assistants"
|
||||
groupByTag
|
||||
items={[
|
||||
{ id: 'pinned-tagged', name: 'Pinned Tagged', icon: <span />, pinned: true, tag: 'work' },
|
||||
{ id: 'work-a', name: 'Work A', icon: <span data-testid="work-a-icon" />, tag: 'work' },
|
||||
{ id: 'home-a', name: 'Home A', icon: <span />, tag: 'home' },
|
||||
{ id: 'loose', name: 'Loose', icon: <span />, tag: undefined }
|
||||
]}
|
||||
variant="assistant"
|
||||
onAdd={vi.fn()}
|
||||
onReorder={vi.fn()}
|
||||
onSelect={vi.fn()}
|
||||
/>
|
||||
)
|
||||
|
||||
// Pinned section stays on top; non-pinned entities split into tag sections + an untagged section.
|
||||
expect(screen.getByText('selector.common.pinned_title')).toBeInTheDocument()
|
||||
expect(screen.getByText('work')).toBeInTheDocument()
|
||||
expect(screen.getByText('home')).toBeInTheDocument()
|
||||
expect(screen.getByText('assistants.tags.untagged')).toBeInTheDocument()
|
||||
expect(
|
||||
Array.from(
|
||||
screen.getByRole('listbox', { name: 'Assistants list' }).querySelectorAll('button[aria-expanded]')
|
||||
).map((header) => header.textContent)
|
||||
).toEqual(['selector.common.pinned_title', 'assistants.tags.untagged', 'work', 'home'])
|
||||
// A pinned entity stays under the pinned section even though it carries a tag — its tag must not
|
||||
// spawn a second "work" header.
|
||||
expect(screen.getAllByText('work')).toHaveLength(1)
|
||||
// The flat default "Assistants" header never appears while grouping by tag.
|
||||
expect(screen.queryByText('Assistants')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('work-a-icon')).toBeInTheDocument()
|
||||
expect(screen.getByRole('listbox', { name: 'Assistants list' })).toHaveAttribute('data-draggable', 'false')
|
||||
})
|
||||
|
||||
it('keeps a real tag named like the untagged sentinel separate from untagged entities', () => {
|
||||
render(
|
||||
<ResourceEntityRail
|
||||
addLabel="New"
|
||||
ariaLabel="Assistants list"
|
||||
defaultGroupLabel="Assistants"
|
||||
groupByTag
|
||||
items={[
|
||||
{ id: 'sentinel-tagged', name: 'Sentinel Tagged', icon: <span />, tag: '__untagged__' },
|
||||
{ id: 'loose', name: 'Loose', icon: <span />, tag: undefined }
|
||||
]}
|
||||
variant="assistant"
|
||||
onAdd={vi.fn()}
|
||||
onSelect={vi.fn()}
|
||||
/>
|
||||
)
|
||||
|
||||
expect(screen.getByText('__untagged__')).toBeInTheDocument()
|
||||
expect(screen.getByText('assistants.tags.untagged')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('ignores entity tags when groupByTag is off', () => {
|
||||
render(
|
||||
<ResourceEntityRail
|
||||
addLabel="New"
|
||||
ariaLabel="Assistants list"
|
||||
defaultGroupLabel="Assistants"
|
||||
items={[
|
||||
{ id: 'work-a', name: 'Work A', icon: <span />, tag: 'work' },
|
||||
{ id: 'home-a', name: 'Home A', icon: <span />, tag: 'home' }
|
||||
]}
|
||||
variant="assistant"
|
||||
onAdd={vi.fn()}
|
||||
onReorder={vi.fn()}
|
||||
onSelect={vi.fn()}
|
||||
/>
|
||||
)
|
||||
|
||||
expect(screen.queryByText('work')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('home')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('assistants.tags.untagged')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders a flat list with no section header when nothing is pinned', () => {
|
||||
render(
|
||||
<ResourceEntityRail
|
||||
|
||||
@@ -75,7 +75,7 @@ import {
|
||||
ComposerToolMenuControls
|
||||
} from './shared/ComposerControlScaffolding'
|
||||
import { type AddNewTopicPayload, emptyActions, type ProviderActionHandlers } from './shared/composerProviderActions'
|
||||
import { buildComposerQueuedPayload } from './shared/composerQueuedPayload'
|
||||
import { buildComposerQueuedPayload, hasUnsyncedComposerAttachments } from './shared/composerQueuedPayload'
|
||||
import { useComposerQuoteInsertion } from './shared/composerQuote'
|
||||
import { useComposerFileCapabilities } from './shared/useComposerFileCapabilities'
|
||||
import { useLatest } from './shared/useLatest'
|
||||
@@ -798,6 +798,8 @@ const ChatComposerInner = ({
|
||||
async (draft: ComposerSerializedDraft) => {
|
||||
const tokenIds = getComposerTokenIds(draft.tokens)
|
||||
const payloadFiles = files.filter((file) => tokenIds.has(chatComposerTokenId.file(file)))
|
||||
if (hasUnsyncedComposerAttachments(files, payloadFiles)) return null
|
||||
|
||||
const originalFilePartsByTokenId = editingOriginalFilePartsByTokenIdRef.current
|
||||
|
||||
const newFiles = payloadFiles.filter((file) => !originalFilePartsByTokenId.has(chatComposerTokenId.file(file)))
|
||||
@@ -840,6 +842,8 @@ const ChatComposerInner = ({
|
||||
}
|
||||
|
||||
const editedParts = await buildEditedMessageParts(draft)
|
||||
if (!editedParts) return
|
||||
|
||||
try {
|
||||
await chatWrite.forkAndResend(editingMessageForCurrentTopic.message.id, editedParts)
|
||||
restoreSavedDraft()
|
||||
|
||||
@@ -1256,6 +1256,46 @@ describe('AgentComposer', () => {
|
||||
expect(mocks.setFiles).toHaveBeenLastCalledWith([])
|
||||
})
|
||||
|
||||
it('does not send while only some attached file tokens are reflected in the editor', async () => {
|
||||
const secondFile = {
|
||||
id: 'file-2',
|
||||
fileTokenSourceId: 'source-file-2',
|
||||
name: 'summary.md',
|
||||
origin_name: 'summary.md',
|
||||
path: '/tmp/summary.md'
|
||||
} as FileMetadata
|
||||
mocks.files = [file, secondFile]
|
||||
mocks.draftTokens = [
|
||||
{
|
||||
id: `file:${file.fileTokenSourceId}`,
|
||||
kind: 'file',
|
||||
label: file.name,
|
||||
payload: file,
|
||||
index: 0,
|
||||
textOffset: mocks.draftText.length
|
||||
} as ComposerSerializedToken
|
||||
]
|
||||
|
||||
render(
|
||||
<AgentComposer
|
||||
agentId="agent-1"
|
||||
sessionId="session-1"
|
||||
sendMessage={mocks.sendMessage}
|
||||
stop={mocks.stop}
|
||||
isStreaming={false}
|
||||
/>
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByText('send'))
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
})
|
||||
|
||||
expect(mocks.sendMessage).not.toHaveBeenCalled()
|
||||
expect(mocks.setFiles).not.toHaveBeenCalledWith([])
|
||||
expect(mocks.files).toEqual([file, secondFile])
|
||||
})
|
||||
|
||||
it('blocks sends while the parent session is switching', () => {
|
||||
render(
|
||||
<AgentComposer
|
||||
@@ -1340,13 +1380,20 @@ describe('AgentComposer', () => {
|
||||
it('restores the current draft, files, and skill tokens when sending a new agent message fails', async () => {
|
||||
mocks.availableSkills = [pdfSkill]
|
||||
mocks.draftText = 'draft message'
|
||||
mocks.draftTokens = [
|
||||
{
|
||||
...pdfSkillToken,
|
||||
index: 0,
|
||||
textOffset: 0
|
||||
}
|
||||
]
|
||||
const skillToken = {
|
||||
...pdfSkillToken,
|
||||
index: 0,
|
||||
textOffset: 0
|
||||
}
|
||||
const fileToken = {
|
||||
id: `file:${file.fileTokenSourceId}`,
|
||||
kind: 'file',
|
||||
label: file.name,
|
||||
payload: file,
|
||||
index: 1,
|
||||
textOffset: mocks.draftText.length
|
||||
} as ComposerSerializedToken
|
||||
mocks.draftTokens = [skillToken, fileToken]
|
||||
mocks.files = [file]
|
||||
mocks.sendMessage.mockRejectedValueOnce(new Error('send failed'))
|
||||
|
||||
@@ -1365,7 +1412,7 @@ describe('AgentComposer', () => {
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mocks.surfaceProps?.draftTokens).toEqual(mocks.draftTokens)
|
||||
expect(mocks.surfaceProps?.draftTokens).toEqual([skillToken])
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('send'))
|
||||
@@ -1378,24 +1425,12 @@ describe('AgentComposer', () => {
|
||||
expect(mocks.setFiles).toHaveBeenCalledWith([])
|
||||
expect(mocks.setFiles).toHaveBeenLastCalledWith([file])
|
||||
expect(mocks.surfaceProps?.text).toBe('draft message')
|
||||
expect(mocks.surfaceProps?.draftTokens).toEqual([
|
||||
{
|
||||
...pdfSkillToken,
|
||||
index: 0,
|
||||
textOffset: 0
|
||||
}
|
||||
])
|
||||
expect(mocks.surfaceProps?.draftTokens).toEqual([skillToken])
|
||||
expect(cacheService.setCasual).toHaveBeenLastCalledWith(
|
||||
'agent-session-draft-agent-1',
|
||||
{
|
||||
text: 'draft message',
|
||||
tokens: [
|
||||
{
|
||||
...pdfSkillToken,
|
||||
index: 0,
|
||||
textOffset: 0
|
||||
}
|
||||
]
|
||||
tokens: [skillToken]
|
||||
},
|
||||
86400000
|
||||
)
|
||||
|
||||
@@ -1097,6 +1097,64 @@ describe('ChatComposer', () => {
|
||||
expect(mocks.surfaceProps?.sendDisabled).toBe(false)
|
||||
})
|
||||
|
||||
it('does not submit a file-only draft before the file token is reflected in the editor', async () => {
|
||||
mocks.files = [{ fileTokenSourceId: 'src-1', name: 'doc.pdf', path: '/tmp/doc.pdf' } as any]
|
||||
const onSend = vi.fn().mockResolvedValue(undefined)
|
||||
|
||||
render(<ChatComposer topic={topic} onSend={onSend} />)
|
||||
|
||||
await act(async () => {
|
||||
await mocks.surfaceProps?.onSendDraft({ text: '', tokens: [] })
|
||||
})
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mocks.toastError).not.toHaveBeenCalledWith('chat.input.send_failed')
|
||||
})
|
||||
|
||||
it('does not submit a text draft before a newly attached file token is reflected in the editor', async () => {
|
||||
mocks.files = [{ fileTokenSourceId: 'src-1', name: 'doc.pdf', path: '/tmp/doc.pdf' } as any]
|
||||
const onSend = vi.fn().mockResolvedValue(undefined)
|
||||
|
||||
render(<ChatComposer topic={topic} onSend={onSend} />)
|
||||
|
||||
await act(async () => {
|
||||
await mocks.surfaceProps?.onSendDraft({ text: 'summarize this', tokens: [] })
|
||||
})
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mocks.files).toHaveLength(1)
|
||||
expect(mocks.toastError).not.toHaveBeenCalledWith('chat.input.send_failed')
|
||||
})
|
||||
|
||||
it('does not submit a text draft while only some attached file tokens are reflected in the editor', async () => {
|
||||
const syncedFile = { fileTokenSourceId: 'src-1', name: 'first.pdf', path: '/tmp/first.pdf' } as any
|
||||
const unsyncedFile = { fileTokenSourceId: 'src-2', name: 'second.pdf', path: '/tmp/second.pdf' } as any
|
||||
mocks.files = [syncedFile, unsyncedFile]
|
||||
const onSend = vi.fn().mockResolvedValue(undefined)
|
||||
|
||||
render(<ChatComposer topic={topic} onSend={onSend} />)
|
||||
|
||||
await act(async () => {
|
||||
await mocks.surfaceProps?.onSendDraft({
|
||||
text: 'summarize these',
|
||||
tokens: [
|
||||
{
|
||||
id: 'file:src-1',
|
||||
kind: 'file',
|
||||
label: 'first.pdf',
|
||||
payload: syncedFile,
|
||||
index: 0,
|
||||
textOffset: 0
|
||||
} as ComposerSerializedToken
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
expect(onSend).not.toHaveBeenCalled()
|
||||
expect(mocks.files).toEqual([syncedFile, unsyncedFile])
|
||||
expect(mocks.toastError).not.toHaveBeenCalledWith('chat.input.send_failed')
|
||||
})
|
||||
|
||||
it('keeps a steered follow-up in the dock and toasts when its manual send fails', async () => {
|
||||
mocks.topicPending = true
|
||||
const onSend = vi.fn().mockResolvedValue(undefined)
|
||||
@@ -2293,6 +2351,98 @@ describe('ChatComposer', () => {
|
||||
await waitFor(() => expect(mocks.surfaceProps?.editingState).toBeUndefined())
|
||||
})
|
||||
|
||||
it('does not fork and resend an edited file-only draft before the file token is reflected in the editor', async () => {
|
||||
const editMessage = vi.fn().mockResolvedValue(undefined)
|
||||
const resend = vi.fn().mockResolvedValue(undefined)
|
||||
const forkAndResend = vi.fn().mockResolvedValue(undefined)
|
||||
mocks.chatWrite = { pause: vi.fn(), editMessage, resend, forkAndResend }
|
||||
const message = {
|
||||
id: 'message-1',
|
||||
role: 'user',
|
||||
topicId: topic.id,
|
||||
createdAt: '2026-01-01T00:00:00.000Z',
|
||||
status: 'success'
|
||||
} as const
|
||||
|
||||
render(
|
||||
<MessageEditingProvider>
|
||||
<StartEditingOnMount message={message as any} parts={[{ type: 'text', text: 'old' }] as any} />
|
||||
<ChatComposer topic={topic} onSend={vi.fn()} />
|
||||
</MessageEditingProvider>
|
||||
)
|
||||
|
||||
await waitFor(() => expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1'))
|
||||
|
||||
act(() => {
|
||||
mocks.files = [{ fileTokenSourceId: 'src-1', name: 'doc.pdf', path: '/tmp/doc.pdf' } as any]
|
||||
mocks.surfaceProps?.onTextChange('')
|
||||
})
|
||||
await waitFor(() => expect(mocks.surfaceProps?.text).toBe(''))
|
||||
|
||||
await act(async () => {
|
||||
await mocks.surfaceProps?.onSendDraft({ text: '', tokens: [] })
|
||||
})
|
||||
|
||||
expect(forkAndResend).not.toHaveBeenCalled()
|
||||
expect(editMessage).not.toHaveBeenCalled()
|
||||
expect(resend).not.toHaveBeenCalled()
|
||||
expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1')
|
||||
expect(mocks.toastError).not.toHaveBeenCalledWith('message.error.operation_unavailable')
|
||||
})
|
||||
|
||||
it('does not fork and resend an edited draft while only some attached file tokens are reflected in the editor', async () => {
|
||||
const editMessage = vi.fn().mockResolvedValue(undefined)
|
||||
const resend = vi.fn().mockResolvedValue(undefined)
|
||||
const forkAndResend = vi.fn().mockResolvedValue(undefined)
|
||||
mocks.chatWrite = { pause: vi.fn(), editMessage, resend, forkAndResend }
|
||||
const message = {
|
||||
id: 'message-1',
|
||||
role: 'user',
|
||||
topicId: topic.id,
|
||||
createdAt: '2026-01-01T00:00:00.000Z',
|
||||
status: 'success'
|
||||
} as const
|
||||
const syncedFile = { fileTokenSourceId: 'src-1', name: 'first.pdf', path: '/tmp/first.pdf' } as any
|
||||
const unsyncedFile = { fileTokenSourceId: 'src-2', name: 'second.pdf', path: '/tmp/second.pdf' } as any
|
||||
|
||||
render(
|
||||
<MessageEditingProvider>
|
||||
<StartEditingOnMount message={message as any} parts={[{ type: 'text', text: 'old' }] as any} />
|
||||
<ChatComposer topic={topic} onSend={vi.fn()} />
|
||||
</MessageEditingProvider>
|
||||
)
|
||||
|
||||
await waitFor(() => expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1'))
|
||||
|
||||
act(() => {
|
||||
mocks.files = [syncedFile, unsyncedFile]
|
||||
mocks.surfaceProps?.onTextChange('')
|
||||
})
|
||||
await waitFor(() => expect(mocks.surfaceProps?.text).toBe(''))
|
||||
|
||||
await act(async () => {
|
||||
await mocks.surfaceProps?.onSendDraft({
|
||||
text: '',
|
||||
tokens: [
|
||||
{
|
||||
id: 'file:src-1',
|
||||
kind: 'file',
|
||||
label: 'first.pdf',
|
||||
payload: syncedFile,
|
||||
index: 0,
|
||||
textOffset: 0
|
||||
} as ComposerSerializedToken
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
expect(forkAndResend).not.toHaveBeenCalled()
|
||||
expect(editMessage).not.toHaveBeenCalled()
|
||||
expect(resend).not.toHaveBeenCalled()
|
||||
expect(mocks.surfaceProps?.editingState?.messageId).toBe('message-1')
|
||||
expect(mocks.toastError).not.toHaveBeenCalledWith('message.error.operation_unavailable')
|
||||
})
|
||||
|
||||
it('keeps editing when the edited message fork and resend fails', async () => {
|
||||
const editMessage = vi.fn().mockResolvedValue(undefined)
|
||||
const resend = vi.fn().mockResolvedValue(undefined)
|
||||
|
||||
@@ -39,17 +39,42 @@ describe('buildComposerQueuedPayload', () => {
|
||||
expect(result?.userMessageParts).toEqual([{ type: 'text', text: '' }])
|
||||
})
|
||||
|
||||
it('attaches only files still present as draft tokens', () => {
|
||||
const kept = file('a')
|
||||
const removed = file('b')
|
||||
it('returns null for a file-only draft whose file token has not reached the editor draft yet', () => {
|
||||
const result = buildComposerQueuedPayload(draft('', []), { files: [file('a')], fileTokenId })
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null for a text draft whose file token has not reached the editor draft yet', () => {
|
||||
const result = buildComposerQueuedPayload(draft('summarize this', []), { files: [file('a')], fileTokenId })
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null when only some current file tokens have reached the editor draft', () => {
|
||||
const synced = file('a')
|
||||
const unsynced = file('b')
|
||||
|
||||
const result = buildComposerQueuedPayload(draft('hi', ['file:a']), {
|
||||
files: [kept, removed],
|
||||
files: [synced, unsynced],
|
||||
fileTokenId,
|
||||
requireText: true
|
||||
})
|
||||
|
||||
expect(result?.attachments).toEqual([kept])
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('attaches files when every current file is present as a draft token', () => {
|
||||
const first = file('a')
|
||||
const second = file('b')
|
||||
|
||||
const result = buildComposerQueuedPayload(draft('hi', ['file:a', 'file:b']), {
|
||||
files: [first, second],
|
||||
fileTokenId,
|
||||
requireText: true
|
||||
})
|
||||
|
||||
expect(result?.attachments).toEqual([first, second])
|
||||
expect(result?.userMessageParts).toEqual([{ type: 'text', text: 'hi' }])
|
||||
})
|
||||
|
||||
|
||||
@@ -31,10 +31,11 @@ export function buildComposerQueuedPayload(
|
||||
{ files, fileTokenId, requireText = false, extra }: BuildComposerQueuedPayloadOptions
|
||||
): ComposerQueuedMessagePayload | null {
|
||||
const text = draft.text.trim()
|
||||
if (requireText ? !text : !text && files.length === 0) return null
|
||||
|
||||
const tokenIds = getComposerTokenIds(draft.tokens)
|
||||
const attachedFiles = files.filter((file) => tokenIds.has(fileTokenId(file)))
|
||||
if (hasUnsyncedComposerAttachments(files, attachedFiles)) return null
|
||||
if (requireText ? !text : !text && attachedFiles.length === 0) return null
|
||||
|
||||
const userMessageParts = createComposerUserMessageParts(draft)
|
||||
|
||||
return {
|
||||
@@ -44,3 +45,7 @@ export function buildComposerQueuedPayload(
|
||||
...extra?.(tokenIds, attachedFiles)
|
||||
}
|
||||
}
|
||||
|
||||
export function hasUnsyncedComposerAttachments(files: ComposerAttachment[], attachedFiles: ComposerAttachment[]) {
|
||||
return attachedFiles.length !== files.length
|
||||
}
|
||||
|
||||
@@ -27,8 +27,8 @@ const AssistantEditDialog = lazy(() =>
|
||||
/**
|
||||
* Row shape the selector operates on — derived from the Assistant DTO. `selectionType: 'item'`
|
||||
* returns values of this shape (not the raw Assistant) so the selector never leaks DB columns the
|
||||
* caller didn't ask about. User tag names may be present so the selector can filter by assistant
|
||||
* tags.
|
||||
* caller didn't ask about. A user tag name may be present so the selector can filter by assistant
|
||||
* tag.
|
||||
*/
|
||||
export type AssistantSelectorItem = ResourceSelectorShellItem
|
||||
|
||||
@@ -118,7 +118,7 @@ export function AssistantSelector(props: AssistantSelectorProps) {
|
||||
name: a.name,
|
||||
emoji: a.emoji,
|
||||
description: a.description,
|
||||
tags: (a.tags ?? []).map((tag) => tag.name)
|
||||
tag: a.tags?.[0]?.name
|
||||
})),
|
||||
[data]
|
||||
)
|
||||
@@ -126,7 +126,8 @@ export function AssistantSelector(props: AssistantSelectorProps) {
|
||||
const tags = useMemo<ResourceSelectorShellTag[]>(() => {
|
||||
const byName = new Map<string, string | undefined>()
|
||||
for (const assistant of data?.items ?? []) {
|
||||
for (const tag of assistant.tags ?? []) {
|
||||
const tag = assistant.tags?.[0]
|
||||
if (tag) {
|
||||
if (!byName.has(tag.name)) {
|
||||
byName.set(tag.name, tag.color ?? undefined)
|
||||
}
|
||||
@@ -200,7 +201,7 @@ export function AssistantSelector(props: AssistantSelectorProps) {
|
||||
name: created.name,
|
||||
emoji: created.emoji,
|
||||
description: created.description,
|
||||
tags: (created.tags ?? []).map((tag) => tag.name)
|
||||
tag: created.tags?.[0]?.name
|
||||
})
|
||||
} else {
|
||||
props.onChange(created.id)
|
||||
|
||||
@@ -30,7 +30,7 @@ export type ResourceSelectorShellItem = {
|
||||
name: string
|
||||
emoji?: string
|
||||
description?: string
|
||||
tags?: string[]
|
||||
tag?: string
|
||||
disabled?: boolean
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
|
||||
)
|
||||
|
||||
const [searchValue, setSearchValue] = useState('')
|
||||
const [selectedTagIds, setSelectedTagIds] = useState<string[]>([])
|
||||
const [selectedTagName, setSelectedTagName] = useState<string | null>(null)
|
||||
const listboxId = useId()
|
||||
const listRef = useRef<HTMLDivElement>(null)
|
||||
const searchInputRef = useRef<HTMLInputElement>(null)
|
||||
@@ -320,9 +320,8 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
|
||||
|
||||
const { pinnedItems, unpinnedItems } = useMemo(() => {
|
||||
let filtered = items
|
||||
if (selectedTagIds.length > 0) {
|
||||
const wanted = new Set(selectedTagIds)
|
||||
filtered = filtered.filter((item) => item.tags?.some((tag) => wanted.has(tag)))
|
||||
if (selectedTagName) {
|
||||
filtered = filtered.filter((item) => item.tag === selectedTagName)
|
||||
}
|
||||
|
||||
const query = searchValue.trim().toLowerCase()
|
||||
@@ -338,7 +337,7 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
|
||||
const unpinned = filtered.filter((item) => !pinnedSet.has(item.id))
|
||||
const pinnedOrdered = pinnedIds.map((id) => pinned.find((item) => item.id === id)).filter(Boolean) as T[]
|
||||
return { pinnedItems: pinnedOrdered, unpinnedItems: unpinned }
|
||||
}, [items, pinnedIds, pinnedSet, searchValue, selectedTagIds])
|
||||
}, [items, pinnedIds, pinnedSet, searchValue, selectedTagName])
|
||||
|
||||
const sections = useMemo<ResourceSelectorSection<T>[]>(() => {
|
||||
const nextSections: ResourceSelectorSection<T>[] = []
|
||||
@@ -569,18 +568,14 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
|
||||
<>
|
||||
{labels.tagFilter ? <span className="mr-1 text-[10px] text-muted-foreground">{labels.tagFilter}</span> : null}
|
||||
{tagOptions.map((tag) => {
|
||||
const active = selectedTagIds.includes(tag.name)
|
||||
const active = selectedTagName === tag.name
|
||||
return (
|
||||
<ResourceTagChip
|
||||
key={tag.name}
|
||||
tag={tag.name}
|
||||
color={tag.color}
|
||||
active={active}
|
||||
onClick={() =>
|
||||
setSelectedTagIds((prev) =>
|
||||
prev.includes(tag.name) ? prev.filter((value) => value !== tag.name) : [...prev, tag.name]
|
||||
)
|
||||
}
|
||||
onClick={() => setSelectedTagName((prev) => (prev === tag.name ? null : tag.name))}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
@@ -607,16 +602,13 @@ export function ResourceSelectorShell<T extends ResourceSelectorShellItem>(props
|
||||
<span className="flex size-5 shrink-0 items-center justify-center">{fallbackIcon}</span>
|
||||
) : null
|
||||
|
||||
const trailing =
|
||||
item.tags && item.tags.length > 0 ? (
|
||||
<div
|
||||
className="ml-2 flex h-4 max-w-[48%] shrink-0 items-center justify-end gap-1 overflow-hidden"
|
||||
data-resource-selector-tags={item.id}>
|
||||
{item.tags.map((tag) => (
|
||||
<ResourceTagChip key={`${item.id}-${tag}`} tag={tag} color={tagColorByName.get(tag)} />
|
||||
))}
|
||||
</div>
|
||||
) : null
|
||||
const trailing = item.tag ? (
|
||||
<div
|
||||
className="ml-2 flex h-4 max-w-[48%] shrink-0 items-center justify-end gap-1 overflow-hidden"
|
||||
data-resource-selector-tags={item.id}>
|
||||
<ResourceTagChip tag={item.tag} color={tagColorByName.get(item.tag)} />
|
||||
</div>
|
||||
) : null
|
||||
|
||||
return (
|
||||
<div key={item.id} className="py-0.5">
|
||||
|
||||
@@ -92,7 +92,7 @@ vi.mock('react-i18next', async (importOriginal) => {
|
||||
'library.config.basic.model_pick': 'Pick model',
|
||||
'library.config.basic.model_not_found': 'Model {{id}} is unavailable.',
|
||||
'library.config.basic.tag_empty': 'No tags',
|
||||
'library.config.basic.tag_placeholder': 'Select tags',
|
||||
'library.config.basic.tag_placeholder': 'Select tag',
|
||||
'library.config.basic.tag_search': 'Search tags',
|
||||
'library.config.prompt.label': 'Prompt',
|
||||
'library.config.prompt.placeholder': 'Tell this assistant how to respond',
|
||||
|
||||
@@ -636,7 +636,7 @@ describe('ResourceSelectorShell', () => {
|
||||
|
||||
describe('edit button', () => {
|
||||
it('places edit and pin together in the row action area', () => {
|
||||
const taggedItems: Item[] = [{ ...ITEMS[0], tags: ['Cherry', 'DEV'] }, ...ITEMS.slice(1)]
|
||||
const taggedItems: Item[] = [{ ...ITEMS[0], tag: 'Cherry' }, ...ITEMS.slice(1)]
|
||||
|
||||
render(
|
||||
<ResourceSelectorShell
|
||||
@@ -891,5 +891,35 @@ describe('ResourceSelectorShell', () => {
|
||||
expect(screen.queryByRole('option', { name: /Alpha/ })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('option', { name: /Beta/ })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('uses a single active tag filter at a time', () => {
|
||||
render(
|
||||
<ResourceSelectorShell
|
||||
trigger={<button type="button">Open</button>}
|
||||
items={[
|
||||
{ ...ITEMS[0], tag: 'Cherry' },
|
||||
{ ...ITEMS[1], tag: 'DEV' },
|
||||
{ ...ITEMS[2], tag: 'Cherry' }
|
||||
]}
|
||||
tags={['Cherry', 'DEV']}
|
||||
pinnedIds={[]}
|
||||
onTogglePin={vi.fn()}
|
||||
labels={LABELS}
|
||||
value={null}
|
||||
onChange={vi.fn()}
|
||||
/>
|
||||
)
|
||||
openPopover()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Cherry' }))
|
||||
expect(screen.queryByRole('option', { name: /Alpha/ })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('option', { name: /Gamma/ })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('option', { name: /Beta/ })).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'DEV' }))
|
||||
expect(screen.queryByRole('option', { name: /Beta/ })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('option', { name: /Alpha/ })).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('option', { name: /Gamma/ })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -295,8 +295,9 @@ vi.mock('react-i18next', async (importOriginal) => {
|
||||
'library.config.basic.model_not_found': 'Model {{id}} is unavailable.',
|
||||
'library.config.basic.precise': 'Precise',
|
||||
'library.config.basic.stream_output': 'Stream output',
|
||||
'library.config.basic.tags': 'Tags',
|
||||
'library.config.basic.tag_empty': 'No tags',
|
||||
'library.config.basic.tag_placeholder': 'Select tags',
|
||||
'library.config.basic.tag_placeholder': 'Select tag',
|
||||
'library.config.basic.tag_search': 'Search tags',
|
||||
'library.config.basic.mcp_mode': 'MCP Mode',
|
||||
'library.config.basic.temperature': 'Temperature',
|
||||
@@ -558,13 +559,10 @@ async function expectVariablesHelpOnHover() {
|
||||
await waitFor(() => expect(screen.getAllByText('{{date}}').length).toBeGreaterThan(0))
|
||||
}
|
||||
|
||||
function openTagCombobox() {
|
||||
const removeTagButton = screen.getByRole('button', { name: 'Remove work' })
|
||||
const combobox = removeTagButton.closest('[role="combobox"]')
|
||||
if (!combobox) {
|
||||
throw new Error('Tag combobox trigger not found')
|
||||
}
|
||||
fireEvent.click(combobox)
|
||||
function openTagSelect() {
|
||||
const select = screen.getByRole('combobox', { name: 'Tags' })
|
||||
fireEvent.pointerDown(select)
|
||||
fireEvent.click(select)
|
||||
}
|
||||
|
||||
describe('edit dialogs', () => {
|
||||
@@ -616,57 +614,45 @@ describe('edit dialogs', () => {
|
||||
})
|
||||
|
||||
it('submits assistant tag changes through ensureTags', async () => {
|
||||
ensureTagsMock.mockResolvedValueOnce([
|
||||
{ id: 'tag-work', name: 'work', color: '#8b5cf6' },
|
||||
{ id: 'tag-personal', name: 'personal', color: '#10b981' }
|
||||
])
|
||||
ensureTagsMock.mockResolvedValueOnce([{ id: 'tag-personal', name: 'personal', color: '#10b981' }])
|
||||
render(<AssistantEditDialog open resource={ASSISTANT} onOpenChange={vi.fn()} onSaved={vi.fn()} />)
|
||||
|
||||
openTagCombobox()
|
||||
fireEvent.click(await screen.findByText('personal'))
|
||||
openTagSelect()
|
||||
fireEvent.click(await screen.findByRole('option', { name: 'personal' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Save' }))
|
||||
|
||||
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['work', 'personal']))
|
||||
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['personal']))
|
||||
expect(updateAssistantMock).toHaveBeenCalledWith({
|
||||
body: expect.objectContaining({
|
||||
tagIds: ['tag-work', 'tag-personal']
|
||||
tagIds: ['tag-personal']
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('creates and binds a new tag typed in assistant editing', async () => {
|
||||
ensureTagsMock.mockResolvedValueOnce([
|
||||
{ id: 'tag-work', name: 'work', color: '#8b5cf6' },
|
||||
{ id: 'tag-new', name: 'new-tag', color: '#10b981' }
|
||||
])
|
||||
it('clears the assistant tag from the single-select tag field', async () => {
|
||||
ensureTagsMock.mockResolvedValueOnce([])
|
||||
render(<AssistantEditDialog open resource={ASSISTANT} onOpenChange={vi.fn()} onSaved={vi.fn()} />)
|
||||
|
||||
openTagCombobox()
|
||||
fireEvent.change(screen.getByPlaceholderText('Search tags'), { target: { value: 'new-tag' } })
|
||||
fireEvent.click(await screen.findByText('new-tag'))
|
||||
const clearButton = screen.getByRole('button', { name: 'Tags Clear' })
|
||||
expect(clearButton).toHaveClass('focus-visible:pointer-events-auto', 'focus-visible:opacity-100')
|
||||
fireEvent.click(clearButton)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Save' }))
|
||||
|
||||
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['work', 'new-tag']))
|
||||
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith([]))
|
||||
expect(updateAssistantMock).toHaveBeenCalledWith({
|
||||
body: expect.objectContaining({
|
||||
tagIds: ['tag-work', 'tag-new']
|
||||
tagIds: []
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('does not expose typed tag names beyond the server-side max length', async () => {
|
||||
it('limits assistant tag editing to existing tags', async () => {
|
||||
render(<AssistantEditDialog open resource={ASSISTANT} onOpenChange={vi.fn()} onSaved={vi.fn()} />)
|
||||
// Open the combobox popover (CommandInput only mounts once the trigger is active)
|
||||
openTagCombobox()
|
||||
const atLimit = 'y'.repeat(64) // TagNameSchema.max(64)
|
||||
const tooLong = 'x'.repeat(65) // one over the limit — server would reject
|
||||
const searchInput = screen.getByPlaceholderText('Search tags')
|
||||
|
||||
fireEvent.change(searchInput, { target: { value: tooLong } })
|
||||
expect(screen.queryByText(tooLong)).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.change(searchInput, { target: { value: atLimit } })
|
||||
expect(await screen.findByText(atLimit)).toBeInTheDocument()
|
||||
openTagSelect()
|
||||
expect(screen.queryByPlaceholderText('Search tags')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('option', { name: 'No tag' })).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('new-tag')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('submits agent instructions and model changes as a PATCH', async () => {
|
||||
|
||||
@@ -1,53 +1,80 @@
|
||||
import { Combobox, type ComboboxOption } from '@cherrystudio/ui'
|
||||
import { TagNameSchema } from '@shared/data/types/tag'
|
||||
import { Button, Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@cherrystudio/ui'
|
||||
import { cn } from '@renderer/utils/style'
|
||||
import { X } from 'lucide-react'
|
||||
import type { FC } from 'react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
interface Props {
|
||||
value: string[]
|
||||
onChange: (tags: string[]) => void
|
||||
value: string | null
|
||||
onChange: (tag: string | null) => void
|
||||
allTagNames: string[]
|
||||
disabled?: boolean
|
||||
portalContainer?: HTMLElement | null
|
||||
}
|
||||
|
||||
const TAG_SELECT_VALUE_PREFIX = 'tag:'
|
||||
|
||||
function encodeTagSelectValue(name: string) {
|
||||
return `${TAG_SELECT_VALUE_PREFIX}${name}`
|
||||
}
|
||||
|
||||
function decodeTagSelectValue(value: string) {
|
||||
if (!value.startsWith(TAG_SELECT_VALUE_PREFIX)) return null
|
||||
return value.slice(TAG_SELECT_VALUE_PREFIX.length)
|
||||
}
|
||||
|
||||
export const TagSelector: FC<Props> = ({ value, onChange, allTagNames, disabled, portalContainer }) => {
|
||||
const { t } = useTranslation()
|
||||
const [search, setSearch] = useState('')
|
||||
|
||||
// `value` may contain names not present in `/tags` yet, for example while a
|
||||
// caller waits for SWR refresh. Keep selected names visible in the options.
|
||||
const tagOptions = useMemo<ComboboxOption[]>(() => {
|
||||
const trimmedSearch = search.trim()
|
||||
const names = new Set([...allTagNames, ...value])
|
||||
// Mirror the server-side TagNameSchema (z.string().trim().min(1).max(64))
|
||||
// so a user cannot select a name the create endpoint would reject.
|
||||
if (trimmedSearch && TagNameSchema.safeParse(trimmedSearch).success) {
|
||||
names.add(trimmedSearch)
|
||||
}
|
||||
// `value` may be a name not present in `/tags` yet, for example while a
|
||||
// caller waits for SWR refresh. Keep the selected name visible in the options.
|
||||
const tagNames = useMemo(() => {
|
||||
const names = new Set(allTagNames)
|
||||
if (value) names.add(value)
|
||||
|
||||
const sortedNames = Array.from(names)
|
||||
sortedNames.sort((a, b) => a.localeCompare(b, 'zh'))
|
||||
return sortedNames.map((name) => ({
|
||||
value: name,
|
||||
label: name
|
||||
}))
|
||||
}, [allTagNames, search, value])
|
||||
return sortedNames
|
||||
}, [allTagNames, value])
|
||||
|
||||
return (
|
||||
<Combobox
|
||||
multiple
|
||||
size="sm"
|
||||
disabled={disabled}
|
||||
options={tagOptions}
|
||||
value={value}
|
||||
onChange={(v) => onChange(Array.isArray(v) ? v : v ? [v] : [])}
|
||||
onSearch={setSearch}
|
||||
placeholder={t('library.config.basic.tag_placeholder')}
|
||||
searchPlaceholder={t('library.config.basic.tag_search')}
|
||||
emptyText={t('library.config.basic.tag_empty')}
|
||||
portalContainer={portalContainer ?? undefined}
|
||||
/>
|
||||
<div className="group/tag-select relative flex w-full min-w-0 items-center">
|
||||
<Select
|
||||
disabled={disabled}
|
||||
value={value ? encodeTagSelectValue(value) : ''}
|
||||
onValueChange={(selectedValue) => onChange(decodeTagSelectValue(selectedValue))}>
|
||||
<SelectTrigger
|
||||
size="sm"
|
||||
className={cn(
|
||||
'w-full',
|
||||
value &&
|
||||
'[&_svg]:transition-opacity group-focus-within/tag-select:[&_svg]:opacity-0 group-hover/tag-select:[&_svg]:opacity-0'
|
||||
)}
|
||||
aria-label={t('library.config.basic.tags')}>
|
||||
<SelectValue placeholder={t('library.config.basic.tag_placeholder')} />
|
||||
</SelectTrigger>
|
||||
<SelectContent portalContainer={portalContainer ?? undefined}>
|
||||
{tagNames.map((name) => (
|
||||
<SelectItem key={name} value={encodeTagSelectValue(name)}>
|
||||
{name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{value && !disabled ? (
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
aria-label={`${t('library.config.basic.tags')} ${t('common.clear')}`}
|
||||
onClick={(event) => {
|
||||
event.stopPropagation()
|
||||
onChange(null)
|
||||
}}
|
||||
className="-translate-y-1/2 pointer-events-none absolute top-1/2 right-2.5 flex size-5 min-h-0 shrink-0 items-center justify-center rounded-full bg-transparent p-0 text-muted-foreground/70 opacity-0 shadow-none transition-[background-color,color,opacity] hover:bg-muted hover:text-foreground focus-visible:pointer-events-auto focus-visible:opacity-100 active:bg-muted group-focus-within/tag-select:pointer-events-auto group-focus-within/tag-select:opacity-100 group-hover/tag-select:pointer-events-auto group-hover/tag-select:opacity-100">
|
||||
<X size={12} />
|
||||
</Button>
|
||||
) : null}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ type AssistantEditFormValues = {
|
||||
name: string
|
||||
description: string
|
||||
modelId: UniqueModelId | null
|
||||
tags: string[]
|
||||
tagName: string | null
|
||||
prompt: string
|
||||
temperature: number
|
||||
enableTemperature: boolean
|
||||
@@ -99,7 +99,7 @@ function defaultValuesForAssistant(resource: AssistantEditDialogResource): Assis
|
||||
name: form.name,
|
||||
description: form.description,
|
||||
modelId: form.modelId ?? null,
|
||||
tags: form.tags,
|
||||
tagName: form.tagName,
|
||||
prompt: form.prompt,
|
||||
temperature: form.temperature,
|
||||
enableTemperature: form.enableTemperature,
|
||||
@@ -132,7 +132,7 @@ function buildAssistantFormState(baseline: AssistantFormState, values: Assistant
|
||||
name: values.name,
|
||||
description: values.description,
|
||||
modelId: values.modelId,
|
||||
tags: values.tags,
|
||||
tagName: values.tagName,
|
||||
prompt: values.prompt,
|
||||
temperature: values.temperature,
|
||||
enableTemperature: values.enableTemperature,
|
||||
@@ -361,7 +361,7 @@ function AssistantBasicFields({
|
||||
</div>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="tags"
|
||||
name="tagName"
|
||||
render={({ field }) => (
|
||||
<FormItem className="min-w-0">
|
||||
<FormLabel>{t('library.config.basic.tags')}</FormLabel>
|
||||
|
||||
@@ -69,9 +69,9 @@ describe('initialAssistantFormState', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('extracts tag names from embedded tag rows', () => {
|
||||
it('extracts a single tag name from embedded tag rows', () => {
|
||||
const assistant = createAssistant({ tags: [tag('t1', 'alpha', '#f00'), tag('t2', 'beta', '#0f0')] })
|
||||
expect(initialAssistantFormState(assistant).tags).toEqual(['alpha', 'beta'])
|
||||
expect(initialAssistantFormState(assistant).tagName).toBe('alpha')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -128,22 +128,24 @@ describe('diffAssistantUpdate', () => {
|
||||
expect(result?.dto.settings).toMatchObject({ reasoning_effort: 'high' })
|
||||
})
|
||||
|
||||
it('flags tag changes and passes form.tags through as tagNames', () => {
|
||||
it('flags tag changes and passes one form tag through as tagNames', () => {
|
||||
const assistant = createAssistant({ tags: [tag('t1', 'alpha', '#f00')] })
|
||||
const baseline = initialAssistantFormState(assistant)
|
||||
const form = { ...baseline, tags: ['alpha', 'new'] }
|
||||
const form = { ...baseline, tagName: 'new' }
|
||||
|
||||
const result = diffAssistantUpdate(form, baseline, assistant)
|
||||
expect(result?.tagsChanged).toBe(true)
|
||||
expect(result?.tagNames).toEqual(['alpha', 'new'])
|
||||
expect(result?.tagNames).toEqual(['new'])
|
||||
})
|
||||
|
||||
it('treats tag reorder (same set) as unchanged', () => {
|
||||
it('flags clearing the assistant tag', () => {
|
||||
const assistant = createAssistant({ tags: [tag('t1', 'alpha', '#f00'), tag('t2', 'beta', '#0f0')] })
|
||||
const baseline = initialAssistantFormState(assistant)
|
||||
const form = { ...baseline, tags: ['beta', 'alpha'] }
|
||||
const form = { ...baseline, tagName: null }
|
||||
|
||||
expect(diffAssistantUpdate(form, baseline, assistant)).toBeNull()
|
||||
const result = diffAssistantUpdate(form, baseline, assistant)
|
||||
expect(result?.tagsChanged).toBe(true)
|
||||
expect(result?.tagNames).toEqual([])
|
||||
})
|
||||
|
||||
it('emits knowledgeBaseIds only when the set changes, ignoring order', () => {
|
||||
@@ -187,12 +189,12 @@ describe('diffAssistantSaveIntent', () => {
|
||||
it('wraps update diffs for the edit dialog save handler', () => {
|
||||
const assistant = createAssistant({ tags: [tag('t1', 'alpha')] })
|
||||
const baseline = initialAssistantFormState(assistant)
|
||||
const form = { ...baseline, tags: ['alpha', 'beta'] }
|
||||
const form = { ...baseline, tagName: 'beta' }
|
||||
|
||||
expect(diffAssistantSaveIntent(form, baseline, assistant)).toEqual({
|
||||
kind: 'update',
|
||||
payload: {},
|
||||
tagNames: ['alpha', 'beta'],
|
||||
tagNames: ['beta'],
|
||||
tagsChanged: true
|
||||
})
|
||||
})
|
||||
|
||||
@@ -21,9 +21,9 @@ const UI_DEFAULT_MAX_TOOL_CALLS = 20
|
||||
* Flat form state for the Assistant edit dialog. Every editable field lives
|
||||
* here so the dialog commits in a single PATCH.
|
||||
*
|
||||
* `tags` stores user-facing names, not ids — tag-id resolution happens
|
||||
* at save time via `ensureTags` so the user can freely type new tags without
|
||||
* paying a network round-trip per keystroke.
|
||||
* `tagName` stores one user-facing name, not an id — tag-id resolution happens
|
||||
* at save time via `ensureTags`, keeping the form state independent from
|
||||
* backend tag ids.
|
||||
*/
|
||||
export interface AssistantFormState {
|
||||
// columns
|
||||
@@ -46,11 +46,15 @@ export interface AssistantFormState {
|
||||
customParameters: CustomParameter[]
|
||||
mcpMode: AssistantSettings['mcpMode']
|
||||
// relations
|
||||
tags: string[]
|
||||
tagName: string | null
|
||||
knowledgeBaseIds: string[]
|
||||
mcpServerIds: string[]
|
||||
}
|
||||
|
||||
function normalizeAssistantTagName(tags: readonly string[]): string | null {
|
||||
return tags[0] ?? null
|
||||
}
|
||||
|
||||
function buildAssistantSettingsFromForm(
|
||||
form: AssistantFormState,
|
||||
baseSettings: AssistantSettings = DEFAULT_ASSISTANT_SETTINGS
|
||||
@@ -90,7 +94,7 @@ export function initialAssistantFormState(assistant: Assistant): AssistantFormSt
|
||||
enableMaxToolCalls: settings.enableMaxToolCalls ?? true,
|
||||
customParameters: settings.customParameters ?? [],
|
||||
mcpMode: settings.mcpMode ?? 'auto',
|
||||
tags: (assistant.tags ?? []).map((t) => t.name),
|
||||
tagName: normalizeAssistantTagName((assistant.tags ?? []).map((t) => t.name)),
|
||||
knowledgeBaseIds: assistant.knowledgeBaseIds ?? [],
|
||||
mcpServerIds: assistant.mcpServerIds ?? []
|
||||
}
|
||||
@@ -160,7 +164,7 @@ export function diffAssistantUpdate(
|
||||
baseline.mcpMode !== form.mcpMode ||
|
||||
customParametersChanged
|
||||
|
||||
const tagsChanged = !sameStringSet(baseline.tags, form.tags)
|
||||
const tagsChanged = baseline.tagName !== form.tagName
|
||||
const knowledgeBaseIdsChanged = !sameIdSet(baseline.knowledgeBaseIds, form.knowledgeBaseIds)
|
||||
const mcpServerIdsChanged = !sameIdSet(baseline.mcpServerIds, form.mcpServerIds)
|
||||
|
||||
@@ -183,7 +187,7 @@ export function diffAssistantUpdate(
|
||||
...(mcpServerIdsChanged ? { mcpServerIds: form.mcpServerIds } : {})
|
||||
}
|
||||
|
||||
return { dto, tagsChanged, tagNames: form.tags }
|
||||
return { dto, tagsChanged, tagNames: form.tagName ? [form.tagName] : [] }
|
||||
}
|
||||
|
||||
export function diffAssistantSaveIntent(
|
||||
@@ -208,9 +212,3 @@ function sameIdSet(a: readonly string[], b: readonly string[]): boolean {
|
||||
const set = new Set(a)
|
||||
return b.every((id) => set.has(id))
|
||||
}
|
||||
|
||||
function sameStringSet(a: readonly string[], b: readonly string[]): boolean {
|
||||
if (a.length !== b.length) return false
|
||||
const set = new Set(a)
|
||||
return b.every((v) => set.has(v))
|
||||
}
|
||||
|
||||
@@ -1333,12 +1333,14 @@
|
||||
"add": "Add Tag",
|
||||
"delete": "Delete Tag",
|
||||
"deleteConfirm": "Are you sure to delete this tag?",
|
||||
"group_by": "Group by tag",
|
||||
"manage": "Tag Management",
|
||||
"modify": "Modify Tag",
|
||||
"none": "No tags",
|
||||
"settings": {
|
||||
"title": "Tag Settings"
|
||||
},
|
||||
"ungroup": "Turn off tag grouping",
|
||||
"untagged": "Untagged"
|
||||
},
|
||||
"title": "Assistants",
|
||||
@@ -3154,9 +3156,9 @@
|
||||
"stream_output": "Stream output",
|
||||
"tag_empty": "No tags available",
|
||||
"tag_hint": "To add a new tag, use the \"+ Tag\" entry in the library top bar",
|
||||
"tag_placeholder": "Select tags",
|
||||
"tag_placeholder": "Select tag",
|
||||
"tag_search": "Search tags",
|
||||
"tags": "Tags",
|
||||
"tags": "Tag",
|
||||
"temperature": "Temperature",
|
||||
"title": "Basic settings",
|
||||
"top_p": "Top-P",
|
||||
|
||||
@@ -1333,12 +1333,14 @@
|
||||
"add": "添加标签",
|
||||
"delete": "删除标签",
|
||||
"deleteConfirm": "确定要删除这个标签吗?",
|
||||
"group_by": "按照标签分组",
|
||||
"manage": "标签管理",
|
||||
"modify": "修改标签",
|
||||
"none": "暂无标签",
|
||||
"settings": {
|
||||
"title": "标签设置"
|
||||
},
|
||||
"ungroup": "关闭标签分组",
|
||||
"untagged": "未分组"
|
||||
},
|
||||
"title": "助手",
|
||||
|
||||
@@ -1333,12 +1333,14 @@
|
||||
"add": "新增標籤",
|
||||
"delete": "刪除標籤",
|
||||
"deleteConfirm": "確定要刪除這個標籤嗎?",
|
||||
"group_by": "依標籤分組",
|
||||
"manage": "標籤管理",
|
||||
"modify": "修改標籤",
|
||||
"none": "暫無標籤",
|
||||
"settings": {
|
||||
"title": "標籤設定"
|
||||
},
|
||||
"ungroup": "關閉標籤分組",
|
||||
"untagged": "未分組"
|
||||
},
|
||||
"title": "助手",
|
||||
|
||||
@@ -65,8 +65,10 @@ function buildTags(resources: ResourceItem[], backendTags: Tag[], filterType?: R
|
||||
for (const tag of r.raw.tags ?? []) {
|
||||
if (!backendTagByName.has(tag.name)) backendTagByName.set(tag.name, tag)
|
||||
}
|
||||
if (r.tag) {
|
||||
tagMap.set(r.tag, (tagMap.get(r.tag) || 0) + 1)
|
||||
}
|
||||
}
|
||||
r.tags.forEach((t) => tagMap.set(t, (tagMap.get(t) || 0) + 1))
|
||||
})
|
||||
return Array.from(tagMap.entries())
|
||||
.sort((a, b) => b[1] - a[1])
|
||||
@@ -155,7 +157,6 @@ export default function LibraryPage() {
|
||||
[tagList.tags]
|
||||
)
|
||||
|
||||
const noop = useCallback(() => {}, [])
|
||||
const handleClosePromptDialog = useCallback(() => {
|
||||
setPromptDialog(null)
|
||||
}, [])
|
||||
@@ -474,7 +475,6 @@ export default function LibraryPage() {
|
||||
// rows; binding stays inside card/dialog tag hooks.
|
||||
await ensureTags([tagName])
|
||||
}}
|
||||
onUpdateResourceTags={noop /* binding is executed inside FixedCardMenu via the tag hooks */}
|
||||
allTagNames={allTagNames}
|
||||
allTags={tagList.tags}
|
||||
assistantCatalog={assistantCatalogProp}
|
||||
|
||||
@@ -516,7 +516,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Assistant to duplicate',
|
||||
description: '',
|
||||
avatar: '💬',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'assistant-to-duplicate', name: 'Assistant to duplicate', tags: [] }
|
||||
@@ -630,7 +629,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Selector Agent',
|
||||
description: '',
|
||||
avatar: '',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'agent-from-selector' }
|
||||
@@ -652,7 +650,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Selector Agent',
|
||||
description: '',
|
||||
avatar: '',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'agent-from-selector' }
|
||||
@@ -675,7 +672,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Selector Assistant',
|
||||
description: '',
|
||||
avatar: '💬',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'assistant-from-selector' }
|
||||
@@ -697,7 +693,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Stale Assistant',
|
||||
description: '',
|
||||
avatar: '💬',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'assistant-stale-tags', name: 'Stale Assistant' }
|
||||
@@ -717,7 +712,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Grid Prompt',
|
||||
description: '',
|
||||
avatar: 'Aa',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'prompt-from-grid' }
|
||||
@@ -746,7 +740,6 @@ describe('LibraryPage create flow', () => {
|
||||
name: 'Grid Skill',
|
||||
description: '',
|
||||
avatar: 'S',
|
||||
tags: [],
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
updatedAt: '2024-01-01T00:00:00.000Z',
|
||||
raw: { id: 'skill-from-grid', name: 'Grid Skill' }
|
||||
|
||||
@@ -78,7 +78,7 @@ describe('useAssistantMutations', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('forwards tag ids to the create endpoint when duplicating an assistant', async () => {
|
||||
it('forwards only one tag id to the create endpoint when duplicating an assistant', async () => {
|
||||
const created = createAssistant({ id: 'ast-copy', tags: [] })
|
||||
createTriggerMock.mockResolvedValue(created)
|
||||
|
||||
@@ -106,7 +106,7 @@ describe('useAssistantMutations', () => {
|
||||
settings: source.settings,
|
||||
mcpServerIds: ['mcp-1'],
|
||||
knowledgeBaseIds: ['kb-1'],
|
||||
tagIds: ['tag-1', 'tag-2']
|
||||
tagIds: ['tag-1']
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -66,6 +66,7 @@ export function useAssistantMutations() {
|
||||
const duplicateAssistant = useCallback(
|
||||
async (source: Assistant): Promise<Assistant> => {
|
||||
const duplicateName = t('library.duplicate_name', { name: source.name })
|
||||
const tagId = source.tags[0]?.id
|
||||
|
||||
return createTrigger({
|
||||
body: {
|
||||
@@ -77,7 +78,7 @@ export function useAssistantMutations() {
|
||||
settings: source.settings,
|
||||
mcpServerIds: source.mcpServerIds,
|
||||
knowledgeBaseIds: source.knowledgeBaseIds,
|
||||
tagIds: source.tags.map((tag) => tag.id)
|
||||
tagIds: tagId ? [tagId] : []
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { ResourceType } from '../types'
|
||||
export interface ResourceListQuery {
|
||||
/** Free-text match against name OR description (passed through to the API). */
|
||||
search?: string
|
||||
/** Union (OR) tag filter — kept if the resource is bound to ANY of these tag ids. */
|
||||
/** Backend tag-id filter transport shape; current assistant UI passes at most one id. */
|
||||
tagIds?: string[]
|
||||
limit?: number
|
||||
offset?: number
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Input,
|
||||
MenuDivider,
|
||||
MenuItem,
|
||||
@@ -12,8 +11,8 @@ import {
|
||||
} from '@cherrystudio/ui'
|
||||
import { loggerService } from '@logger'
|
||||
import { useEnsureTags, useTagList } from '@renderer/hooks/useTags'
|
||||
import { ChevronDown, Copy, Download, Plus, Tag, Trash2 } from 'lucide-react'
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { Check, ChevronDown, Copy, Download, Plus, Tag, Trash2 } from 'lucide-react'
|
||||
import { type KeyboardEvent, useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { useAssistantMutationsById } from '../adapters/assistantAdapter'
|
||||
@@ -32,7 +31,6 @@ interface ResourceCardMenuProps {
|
||||
onDuplicate: (r: ResourceItem) => void
|
||||
onDelete: (r: ResourceItem) => void
|
||||
onExport: (r: ResourceItem) => void
|
||||
onUpdateResourceTags: (resourceId: string, tags: string[]) => void
|
||||
allTagNames: string[]
|
||||
}
|
||||
|
||||
@@ -42,16 +40,19 @@ export function ResourceCardMenu({
|
||||
onDuplicate,
|
||||
onDelete,
|
||||
onExport,
|
||||
onUpdateResourceTags,
|
||||
allTagNames
|
||||
}: ResourceCardMenuProps) {
|
||||
const { t } = useTranslation()
|
||||
const [showTagPicker, setShowTagPicker] = useState(false)
|
||||
const [localTags, setLocalTags] = useState<string[]>(resource.tags)
|
||||
const [localTag, setLocalTag] = useState<string | null>(() =>
|
||||
resource.type === 'assistant' ? (resource.tag ?? null) : null
|
||||
)
|
||||
const [tagInput, setTagInput] = useState('')
|
||||
const [bindingError, setBindingError] = useState<string | null>(null)
|
||||
const [bindingPending, setBindingPending] = useState(false)
|
||||
const bindingPendingRef = useRef(false)
|
||||
const tagOptionRefs = useRef<Array<HTMLDivElement | null>>([])
|
||||
const [activeTagIndex, setActiveTagIndex] = useState(0)
|
||||
|
||||
const { ensureTags } = useEnsureTags({ getDefaultColor: getRandomTagColor })
|
||||
const { updateAssistant } = useAssistantMutationsById(resource.id)
|
||||
@@ -65,22 +66,28 @@ export function ResourceCardMenu({
|
||||
const tagList = useTagList()
|
||||
const colorFor = (name: string): string => tagList.tags.find((tag) => tag.name === name)?.color ?? DEFAULT_TAG_COLOR
|
||||
|
||||
const persistTags = useCallback(
|
||||
async (nextNames: string[], previousNames: string[]) => {
|
||||
useEffect(() => {
|
||||
if (!showTagPicker) return
|
||||
const selectedIndex = localTag ? allTagNames.indexOf(localTag) : -1
|
||||
setActiveTagIndex(selectedIndex >= 0 ? selectedIndex : 0)
|
||||
}, [allTagNames, localTag, showTagPicker])
|
||||
|
||||
const persistTag = useCallback(
|
||||
async (nextName: string | null, previousName: string | null) => {
|
||||
if (!canBindTags) return
|
||||
if (bindingPendingRef.current) return
|
||||
bindingPendingRef.current = true
|
||||
setBindingPending(true)
|
||||
try {
|
||||
const nextNames = nextName ? [nextName] : []
|
||||
const tags = await ensureTags(nextNames)
|
||||
const tagIds = tags.map((tag) => tag.id)
|
||||
if (resource.type === 'assistant') {
|
||||
await updateAssistant({ tagIds })
|
||||
}
|
||||
onUpdateResourceTags(resource.id, nextNames)
|
||||
} catch (e) {
|
||||
// Roll back optimistic state on failure.
|
||||
setLocalTags(previousNames)
|
||||
setLocalTag(previousName)
|
||||
const message = e instanceof Error ? e.message : t('library.tag_sync_failed')
|
||||
setBindingError(message)
|
||||
// The inline error text only renders while the popup is open. Toast +
|
||||
@@ -96,31 +103,60 @@ export function ResourceCardMenu({
|
||||
setBindingPending(false)
|
||||
}
|
||||
},
|
||||
[canBindTags, ensureTags, updateAssistant, onUpdateResourceTags, resource.id, resource.type, t]
|
||||
[canBindTags, ensureTags, updateAssistant, resource.id, resource.type, t]
|
||||
)
|
||||
|
||||
const toggleTag = (tag: string) => {
|
||||
if (bindingPendingRef.current) return
|
||||
const prev = localTags
|
||||
const next = prev.includes(tag) ? prev.filter((item) => item !== tag) : [...prev, tag]
|
||||
setLocalTags(next)
|
||||
const prev = localTag
|
||||
const next = prev === tag ? null : tag
|
||||
setLocalTag(next)
|
||||
setBindingError(null)
|
||||
void persistTags(next, prev)
|
||||
void persistTag(next, prev)
|
||||
}
|
||||
|
||||
const focusTagOption = (index: number) => {
|
||||
setActiveTagIndex(index)
|
||||
tagOptionRefs.current[index]?.focus()
|
||||
}
|
||||
|
||||
const handleTagOptionKeyDown = (e: KeyboardEvent<HTMLDivElement>, index: number, tag: string) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault()
|
||||
if (!bindingPending) toggleTag(tag)
|
||||
return
|
||||
}
|
||||
|
||||
if (allTagNames.length === 0) return
|
||||
|
||||
if (e.key === 'ArrowDown') {
|
||||
e.preventDefault()
|
||||
focusTagOption((index + 1) % allTagNames.length)
|
||||
} else if (e.key === 'ArrowUp') {
|
||||
e.preventDefault()
|
||||
focusTagOption((index - 1 + allTagNames.length) % allTagNames.length)
|
||||
} else if (e.key === 'Home') {
|
||||
e.preventDefault()
|
||||
focusTagOption(0)
|
||||
} else if (e.key === 'End') {
|
||||
e.preventDefault()
|
||||
focusTagOption(allTagNames.length - 1)
|
||||
}
|
||||
}
|
||||
|
||||
const addNewTag = () => {
|
||||
if (bindingPendingRef.current) return
|
||||
const tag = tagInput.trim()
|
||||
if (!tag || localTags.includes(tag)) {
|
||||
if (!tag || localTag === tag) {
|
||||
setTagInput('')
|
||||
return
|
||||
}
|
||||
const prev = localTags
|
||||
const next = [...prev, tag]
|
||||
setLocalTags(next)
|
||||
const prev = localTag
|
||||
const next = tag
|
||||
setLocalTag(next)
|
||||
setTagInput('')
|
||||
setBindingError(null)
|
||||
void persistTags(next, prev)
|
||||
void persistTag(next, prev)
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -137,9 +173,7 @@ export function ResourceCardMenu({
|
||||
label={t('library.action.manage_tags')}
|
||||
suffix={
|
||||
<>
|
||||
{localTags.length > 0 && (
|
||||
<span className="text-foreground-muted text-xs tabular-nums">{localTags.length}</span>
|
||||
)}
|
||||
{localTag && <span className="text-foreground-muted text-xs tabular-nums">1</span>}
|
||||
<ChevronDown size={8} className={`transition-transform ${showTagPicker ? 'rotate-180' : ''}`} />
|
||||
</>
|
||||
}
|
||||
@@ -174,43 +208,38 @@ export function ResourceCardMenu({
|
||||
)}
|
||||
</div>
|
||||
<Separator className="mx-1 mb-0.5 bg-border-subtle" />
|
||||
<div className="flex-1 overflow-y-auto [&::-webkit-scrollbar-thumb]:bg-border-muted [&::-webkit-scrollbar]:w-0.5">
|
||||
<div
|
||||
role="menu"
|
||||
aria-label={t('library.config.basic.tags')}
|
||||
className="flex-1 overflow-y-auto [&::-webkit-scrollbar-thumb]:bg-border-muted [&::-webkit-scrollbar]:w-0.5">
|
||||
{allTagNames.length === 0 && !tagInput.trim() && (
|
||||
<p className="px-2.5 py-2 text-center text-foreground-muted text-xs">
|
||||
{t('library.tag_picker.no_tags')}
|
||||
</p>
|
||||
)}
|
||||
{allTagNames.map((tag) => {
|
||||
const checked = localTags.includes(tag)
|
||||
{allTagNames.map((tag, index) => {
|
||||
const checked = localTag === tag
|
||||
return (
|
||||
<div
|
||||
key={tag}
|
||||
role="button"
|
||||
tabIndex={bindingPending ? -1 : 0}
|
||||
ref={(node) => {
|
||||
tagOptionRefs.current[index] = node
|
||||
}}
|
||||
role="menuitemradio"
|
||||
aria-checked={checked}
|
||||
tabIndex={!bindingPending && index === activeTagIndex ? 0 : -1}
|
||||
aria-disabled={bindingPending || undefined}
|
||||
onClick={() => toggleTag(tag)}
|
||||
onKeyDown={(e) => {
|
||||
if (bindingPending) return
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault()
|
||||
toggleTag(tag)
|
||||
}
|
||||
}}
|
||||
onFocus={() => setActiveTagIndex(index)}
|
||||
onKeyDown={(e) => handleTagOptionKeyDown(e, index, tag)}
|
||||
className={`flex w-full items-center gap-2 rounded-md px-2.5 py-1 text-foreground-secondary text-xs transition-colors ${
|
||||
bindingPending
|
||||
? 'cursor-not-allowed opacity-60'
|
||||
: 'cursor-pointer hover:bg-accent hover:text-foreground'
|
||||
}`}>
|
||||
<span onClick={(e) => e.stopPropagation()}>
|
||||
<Checkbox
|
||||
size="sm"
|
||||
checked={checked}
|
||||
disabled={bindingPending}
|
||||
onCheckedChange={() => toggleTag(tag)}
|
||||
/>
|
||||
</span>
|
||||
<span className="h-1.5 w-1.5 shrink-0 rounded-full" style={{ backgroundColor: colorFor(tag) }} />
|
||||
<span className="flex-1 truncate text-left">{tag}</span>
|
||||
{checked && <Check size={12} className="shrink-0 text-success" />}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
|
||||
@@ -214,22 +214,13 @@ interface ResourceCardProps {
|
||||
onDuplicate: (resource: ResourceItem) => void
|
||||
onEdit: (resource: ResourceItem) => void
|
||||
onExport: (resource: ResourceItem) => void
|
||||
onUpdateResourceTags: (resourceId: string, tags: string[]) => void
|
||||
}
|
||||
|
||||
function hasOverflowActions(resource: ResourceItem) {
|
||||
return resource.type === 'assistant'
|
||||
}
|
||||
|
||||
export function ResourceCard({
|
||||
resource: r,
|
||||
allTagNames,
|
||||
onDelete,
|
||||
onDuplicate,
|
||||
onEdit,
|
||||
onExport,
|
||||
onUpdateResourceTags
|
||||
}: ResourceCardProps) {
|
||||
export function ResourceCard({ resource: r, allTagNames, onDelete, onDuplicate, onEdit, onExport }: ResourceCardProps) {
|
||||
const { t } = useTranslation()
|
||||
const [menuOpen, setMenuOpen] = useState(false)
|
||||
const cfg = RESOURCE_TYPE_META[r.type]
|
||||
@@ -237,8 +228,7 @@ export function ResourceCard({
|
||||
// other resources keep their own avatar on the neutral accent block.
|
||||
const useTypedAvatarBg = r.type === 'skill'
|
||||
const showOverflowMenu = hasOverflowActions(r)
|
||||
const visibleTags = r.type === 'assistant' ? r.tags.slice(0, 2) : []
|
||||
const extraTagCount = r.type === 'assistant' ? r.tags.length - visibleTags.length : 0
|
||||
const visibleTag = r.type === 'assistant' ? r.tag : undefined
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -259,17 +249,13 @@ export function ResourceCard({
|
||||
<div className="min-w-0 flex-1">
|
||||
<h4 className="truncate font-medium text-foreground text-sm leading-5">{r.name}</h4>
|
||||
<p className="mt-0.5 truncate text-foreground-secondary text-xs leading-4">{r.description}</p>
|
||||
{visibleTags.length > 0 && (
|
||||
{visibleTag && (
|
||||
<div className="mt-1.5 flex min-w-0 items-center gap-1">
|
||||
{visibleTags.map((tag, i) => (
|
||||
<Badge
|
||||
key={`${tag}-${i}`}
|
||||
variant="secondary"
|
||||
className="max-w-24 truncate border-0 bg-secondary px-1.5 py-px text-foreground-secondary text-xs">
|
||||
{tag}
|
||||
</Badge>
|
||||
))}
|
||||
{extraTagCount > 0 && <span className="shrink-0 text-foreground-muted text-xs">+{extraTagCount}</span>}
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="max-w-24 truncate border-0 bg-secondary px-1.5 py-px text-foreground-secondary text-xs">
|
||||
{visibleTag}
|
||||
</Badge>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -299,7 +285,6 @@ export function ResourceCard({
|
||||
onDuplicate={onDuplicate}
|
||||
onDelete={onDelete}
|
||||
onExport={onExport}
|
||||
onUpdateResourceTags={onUpdateResourceTags}
|
||||
allTagNames={allTagNames}
|
||||
/>
|
||||
</PopoverContent>
|
||||
|
||||
@@ -69,8 +69,6 @@ interface Props {
|
||||
onTagFilter: (tagName: string | null) => void
|
||||
/** Create a new tag (POST /tags). Does not bind the tag to any resource. */
|
||||
onAddTag: (tagName: string) => Promise<void> | void
|
||||
/** Replace the tag-name set for a single resource. Caller handles ensure-tag + bind. */
|
||||
onUpdateResourceTags: (resourceId: string, tags: string[]) => Promise<void> | void
|
||||
allTagNames: string[]
|
||||
/** Full backend tag records (id + name + color). Distinct from `allTagNames` (names only). */
|
||||
allTags: BackendTag[]
|
||||
@@ -128,7 +126,6 @@ export const ResourceGrid: FC<Props> = ({
|
||||
activeTag,
|
||||
onTagFilter,
|
||||
onAddTag,
|
||||
onUpdateResourceTags,
|
||||
allTagNames,
|
||||
allTags,
|
||||
assistantCatalog
|
||||
@@ -494,7 +491,6 @@ export const ResourceGrid: FC<Props> = ({
|
||||
onDuplicate={onDuplicate}
|
||||
onEdit={onEdit}
|
||||
onExport={onExport}
|
||||
onUpdateResourceTags={onUpdateResourceTags}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -534,7 +530,6 @@ interface VirtualizedResourceGridProps {
|
||||
onDuplicate: (r: ResourceItem) => void
|
||||
onEdit: (r: ResourceItem) => void
|
||||
onExport: (r: ResourceItem) => void
|
||||
onUpdateResourceTags: (resourceId: string, tags: string[]) => void
|
||||
}
|
||||
|
||||
function VirtualizedResourceGrid({
|
||||
@@ -545,8 +540,7 @@ function VirtualizedResourceGrid({
|
||||
onDelete,
|
||||
onDuplicate,
|
||||
onEdit,
|
||||
onExport,
|
||||
onUpdateResourceTags
|
||||
onExport
|
||||
}: VirtualizedResourceGridProps) {
|
||||
const rows = useMemo(() => {
|
||||
const nextRows: ResourceItem[][] = []
|
||||
@@ -590,7 +584,6 @@ function VirtualizedResourceGrid({
|
||||
onDuplicate={onDuplicate}
|
||||
onEdit={onEdit}
|
||||
onExport={onExport}
|
||||
onUpdateResourceTags={onUpdateResourceTags}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -75,27 +75,6 @@ vi.mock('@cherrystudio/ui', async () => {
|
||||
</button>
|
||||
)
|
||||
},
|
||||
Checkbox: ({
|
||||
checked = false,
|
||||
onCheckedChange,
|
||||
size,
|
||||
...props
|
||||
}: Omit<ComponentProps<'button'>, 'onChange'> & {
|
||||
checked?: boolean
|
||||
onCheckedChange?: (checked: boolean) => void
|
||||
size?: string
|
||||
}) => {
|
||||
void size
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
role="checkbox"
|
||||
aria-checked={checked}
|
||||
onClick={() => onCheckedChange?.(!checked)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
},
|
||||
ConfirmDialog: ({
|
||||
cancelText,
|
||||
confirmText,
|
||||
@@ -276,7 +255,6 @@ function createAssistantResource(overrides: Partial<Extract<ResourceItem, { type
|
||||
name: 'Assistant',
|
||||
description: '',
|
||||
avatar: 'A',
|
||||
tags: [],
|
||||
createdAt: '2026-05-06T00:00:00.000Z',
|
||||
updatedAt: '2026-05-06T00:00:00.000Z',
|
||||
raw: {} as Extract<ResourceItem, { type: 'assistant' }>['raw'],
|
||||
@@ -291,7 +269,6 @@ function createAgentResource(): ResourceItem {
|
||||
name: 'Agent',
|
||||
description: '',
|
||||
avatar: 'A',
|
||||
tags: [],
|
||||
createdAt: '2026-05-06T00:00:00.000Z',
|
||||
updatedAt: '2026-05-06T00:00:00.000Z',
|
||||
raw: {} as Extract<ResourceItem, { type: 'agent' }>['raw']
|
||||
@@ -305,7 +282,6 @@ function createSkillResource(): ResourceItem {
|
||||
name: 'Skill',
|
||||
description: '',
|
||||
avatar: 'S',
|
||||
tags: [],
|
||||
createdAt: '2026-05-06T00:00:00.000Z',
|
||||
updatedAt: '2026-05-06T00:00:00.000Z',
|
||||
raw: {} as Extract<ResourceItem, { type: 'skill' }>['raw']
|
||||
@@ -319,7 +295,6 @@ function createPromptResource(): ResourceItem {
|
||||
name: 'Prompt',
|
||||
description: '',
|
||||
avatar: 'Aa',
|
||||
tags: [],
|
||||
createdAt: '2026-05-06T00:00:00.000Z',
|
||||
updatedAt: '2026-05-06T00:00:00.000Z',
|
||||
raw: {} as Extract<ResourceItem, { type: 'prompt' }>['raw']
|
||||
@@ -344,7 +319,6 @@ function renderResourceGrid(props: Partial<ComponentProps<typeof ResourceGrid>>
|
||||
activeTag={null}
|
||||
onTagFilter={vi.fn()}
|
||||
onAddTag={vi.fn()}
|
||||
onUpdateResourceTags={vi.fn()}
|
||||
allTagNames={[]}
|
||||
allTags={[]}
|
||||
{...props}
|
||||
@@ -359,7 +333,6 @@ function getResourceCardProps(overrides: Partial<ComponentProps<typeof ResourceC
|
||||
onDuplicate: vi.fn(),
|
||||
onEdit: vi.fn(),
|
||||
onExport: vi.fn(),
|
||||
onUpdateResourceTags: vi.fn(),
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
@@ -505,17 +478,12 @@ describe('ResourceGrid card actions', () => {
|
||||
expect(onDelete).toHaveBeenCalledWith(resource)
|
||||
})
|
||||
|
||||
it('keeps assistant tags visible in the compact card layout', () => {
|
||||
render(
|
||||
<ResourceCard
|
||||
resource={createAssistantResource({ tags: ['alpha', 'beta', 'gamma'] })}
|
||||
{...getResourceCardProps()}
|
||||
/>
|
||||
)
|
||||
it('shows only one assistant tag in the compact card layout', () => {
|
||||
render(<ResourceCard resource={createAssistantResource({ tag: 'alpha' })} {...getResourceCardProps()} />)
|
||||
|
||||
expect(screen.getByText('alpha')).toBeInTheDocument()
|
||||
expect(screen.getByText('beta')).toBeInTheDocument()
|
||||
expect(screen.getByText('+1')).toBeInTheDocument()
|
||||
expect(screen.queryByText('beta')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('+2')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -624,7 +592,6 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
const pendingTags = createDeferred<Array<{ id: string; name: string }>>()
|
||||
ensureTagsMock.mockReturnValueOnce(pendingTags.promise)
|
||||
updateAssistantMock.mockResolvedValue({})
|
||||
const onUpdateResourceTags = vi.fn()
|
||||
|
||||
render(
|
||||
<ResourceCardMenu
|
||||
@@ -633,17 +600,20 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
onUpdateResourceTags={onUpdateResourceTags}
|
||||
allTagNames={['alpha', 'beta']}
|
||||
/>
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /library.action.manage_tags/ }))
|
||||
const checkboxes = screen.getAllByRole('checkbox')
|
||||
await user.click(checkboxes[0])
|
||||
expect(screen.getByRole('menu', { name: 'library.config.basic.tags' })).toContainElement(
|
||||
screen.getByRole('menuitemradio', { name: 'alpha' })
|
||||
)
|
||||
await user.click(screen.getByRole('menuitemradio', { name: 'alpha' }))
|
||||
|
||||
await waitFor(() => expect(checkboxes[1]).toBeDisabled())
|
||||
await user.click(checkboxes[1])
|
||||
await waitFor(() =>
|
||||
expect(screen.getByRole('menuitemradio', { name: 'beta' })).toHaveAttribute('aria-disabled', 'true')
|
||||
)
|
||||
await user.click(screen.getByRole('menuitemradio', { name: 'beta' }))
|
||||
expect(ensureTagsMock).toHaveBeenCalledTimes(1)
|
||||
|
||||
pendingTags.resolve([{ id: 'tag-alpha', name: 'alpha' }])
|
||||
@@ -651,10 +621,73 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
await waitFor(() => {
|
||||
expect(updateAssistantMock).toHaveBeenCalledWith({ tagIds: ['tag-alpha'] })
|
||||
})
|
||||
expect(onUpdateResourceTags).toHaveBeenCalledWith('assistant-1', ['alpha'])
|
||||
expect(ensureTagsMock).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('uses roving focus for tag picker radio items', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<ResourceCardMenu
|
||||
resource={createAssistantResource()}
|
||||
onClose={vi.fn()}
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
allTagNames={['alpha', 'beta', 'gamma']}
|
||||
/>
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /library.action.manage_tags/ }))
|
||||
|
||||
const alpha = screen.getByRole('menuitemradio', { name: 'alpha' })
|
||||
const beta = screen.getByRole('menuitemradio', { name: 'beta' })
|
||||
const gamma = screen.getByRole('menuitemradio', { name: 'gamma' })
|
||||
|
||||
expect(alpha).toHaveAttribute('tabindex', '0')
|
||||
expect(beta).toHaveAttribute('tabindex', '-1')
|
||||
|
||||
alpha.focus()
|
||||
expect(alpha).toHaveFocus()
|
||||
|
||||
await user.keyboard('{ArrowDown}')
|
||||
expect(beta).toHaveFocus()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(alpha).toHaveAttribute('tabindex', '-1')
|
||||
expect(beta).toHaveAttribute('tabindex', '0')
|
||||
})
|
||||
|
||||
await user.keyboard('{End}')
|
||||
expect(gamma).toHaveFocus()
|
||||
|
||||
await user.keyboard('{Home}')
|
||||
expect(alpha).toHaveFocus()
|
||||
})
|
||||
|
||||
it('replaces the current assistant tag when a different tag is selected', async () => {
|
||||
const user = userEvent.setup()
|
||||
ensureTagsMock.mockResolvedValueOnce([{ id: 'tag-beta', name: 'beta' }])
|
||||
updateAssistantMock.mockResolvedValue({})
|
||||
|
||||
render(
|
||||
<ResourceCardMenu
|
||||
resource={createAssistantResource({ tag: 'alpha' })}
|
||||
onClose={vi.fn()}
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
allTagNames={['alpha', 'beta']}
|
||||
/>
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /library.action.manage_tags/ }))
|
||||
await user.click(screen.getByRole('menuitemradio', { name: 'beta' }))
|
||||
|
||||
await waitFor(() => expect(ensureTagsMock).toHaveBeenCalledWith(['beta']))
|
||||
expect(updateAssistantMock).toHaveBeenCalledWith({ tagIds: ['tag-beta'] })
|
||||
})
|
||||
|
||||
it('does not expose tag management for agent, skill, or prompt resources', () => {
|
||||
const { rerender } = render(
|
||||
<ResourceCardMenu
|
||||
@@ -663,7 +696,6 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
onUpdateResourceTags={vi.fn()}
|
||||
allTagNames={['alpha', 'beta']}
|
||||
/>
|
||||
)
|
||||
@@ -677,7 +709,6 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
onUpdateResourceTags={vi.fn()}
|
||||
allTagNames={['alpha', 'beta']}
|
||||
/>
|
||||
)
|
||||
@@ -691,7 +722,6 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
onUpdateResourceTags={vi.fn()}
|
||||
allTagNames={['alpha', 'beta']}
|
||||
/>
|
||||
)
|
||||
@@ -707,7 +737,6 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
onUpdateResourceTags={vi.fn()}
|
||||
allTagNames={[]}
|
||||
/>
|
||||
)
|
||||
@@ -724,7 +753,6 @@ describe('ResourceCardMenu tag binding', () => {
|
||||
onDuplicate={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
onExport={vi.fn()}
|
||||
onUpdateResourceTags={vi.fn()}
|
||||
allTagNames={[]}
|
||||
/>
|
||||
)
|
||||
|
||||
@@ -180,7 +180,7 @@ describe('useResourceLibrary model display names', () => {
|
||||
})
|
||||
const skill = result.current.allResources.find((resource) => resource.type === 'skill')
|
||||
|
||||
expect(skill?.tags).toEqual([])
|
||||
expect(skill?.tag).toBeUndefined()
|
||||
})
|
||||
|
||||
it('passes skill search to the backend and ignores activeTag', () => {
|
||||
@@ -277,8 +277,7 @@ describe('useResourceLibrary model display names', () => {
|
||||
type: 'prompt',
|
||||
name: '日报模板',
|
||||
description: '今日完成 ${task}',
|
||||
avatar: 'Aa',
|
||||
tags: []
|
||||
avatar: 'Aa'
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
@@ -84,13 +84,9 @@ export function useResourceLibrary({
|
||||
// the param when nothing resolves rather than sending a 400.
|
||||
const tagIds = useMemo(() => {
|
||||
if (!assistantTagsActive) return undefined
|
||||
const names = [activeTag].filter((x): x is string => Boolean(x))
|
||||
if (names.length === 0) return undefined
|
||||
const ids = names.flatMap((name) => {
|
||||
const id = tagIdByName.get(name)
|
||||
return id ? [id] : []
|
||||
})
|
||||
return ids.length > 0 ? ids : undefined
|
||||
if (!activeTag) return undefined
|
||||
const id = tagIdByName.get(activeTag)
|
||||
return id ? [id] : undefined
|
||||
}, [activeTag, assistantTagsActive, tagIdByName])
|
||||
|
||||
// Defensive guard for the rare race where the user has a chip selected but
|
||||
@@ -112,10 +108,10 @@ export function useResourceLibrary({
|
||||
const filteredPrompts = promptAdapter.useList(promptsVisible ? { search: trimmedSearch } : undefined)
|
||||
|
||||
const buildAssistantItem = useCallback((a: Assistant): ResourceItem => {
|
||||
// Defensive `?? []`: schema declares tags as required, but stale DataApi
|
||||
// cache or a row from a code path that bypasses the embed helper can
|
||||
// still hand us undefined here. `.map` would throw.
|
||||
const tags = a.tags ?? []
|
||||
// Defensive optional access: schema declares tags as required, but stale DataApi
|
||||
// cache or a row from a code path that bypasses the embed helper can still hand
|
||||
// us undefined here.
|
||||
const tag = a.tags?.[0]
|
||||
return {
|
||||
id: a.id,
|
||||
type: 'assistant',
|
||||
@@ -125,7 +121,7 @@ export function useResourceLibrary({
|
||||
// Embedded by AssistantService.list via JOIN on user_model; null when the
|
||||
// bound model row was removed.
|
||||
model: a.modelName ?? undefined,
|
||||
tags: tags.map((t) => t.name),
|
||||
tag: tag?.name,
|
||||
createdAt: a.createdAt,
|
||||
updatedAt: a.updatedAt,
|
||||
raw: a
|
||||
@@ -140,7 +136,6 @@ export function useResourceLibrary({
|
||||
description: a.description ?? '',
|
||||
avatar: getAgentAvatarFromConfiguration(a.configuration),
|
||||
model: a.modelName ?? undefined,
|
||||
tags: [],
|
||||
createdAt: a.createdAt,
|
||||
updatedAt: a.updatedAt,
|
||||
raw: a
|
||||
@@ -157,7 +152,6 @@ export function useResourceLibrary({
|
||||
avatar: '⚡',
|
||||
// Skill metadata tags from SKILL.md live on `sourceTags`; the outer
|
||||
// resource-library user tag concept is assistant-only.
|
||||
tags: [],
|
||||
createdAt: s.createdAt,
|
||||
updatedAt: s.updatedAt,
|
||||
raw: s
|
||||
@@ -171,7 +165,6 @@ export function useResourceLibrary({
|
||||
name: p.title,
|
||||
description: p.content.replace(/\s+/g, ' ').trim(),
|
||||
avatar: 'Aa',
|
||||
tags: [],
|
||||
createdAt: p.createdAt,
|
||||
updatedAt: p.updatedAt,
|
||||
raw: p
|
||||
|
||||
@@ -19,17 +19,16 @@ interface ResourceItemBase<TType extends ResourceType, TRaw> {
|
||||
description: string
|
||||
avatar: string
|
||||
model?: string
|
||||
tags: string[]
|
||||
createdAt: string
|
||||
updatedAt: string
|
||||
raw: TRaw
|
||||
}
|
||||
|
||||
export type ResourceItem =
|
||||
| ResourceItemBase<'assistant', Assistant>
|
||||
| ResourceItemBase<'agent', AgentDetail>
|
||||
| ResourceItemBase<'skill', InstalledSkill>
|
||||
| ResourceItemBase<'prompt', Prompt>
|
||||
| (ResourceItemBase<'assistant', Assistant> & { tag?: string })
|
||||
| (ResourceItemBase<'agent', AgentDetail> & { tag?: never })
|
||||
| (ResourceItemBase<'skill', InstalledSkill> & { tag?: never })
|
||||
| (ResourceItemBase<'prompt', Prompt> & { tag?: never })
|
||||
|
||||
export interface TagItem {
|
||||
id: string
|
||||
|
||||
@@ -60,7 +60,7 @@ describe('assistantTransfer', () => {
|
||||
{
|
||||
name: '写作助手',
|
||||
emoji: '✍️',
|
||||
group: ['写作', '生产力'],
|
||||
group: ['写作'],
|
||||
prompt: 'You are helpful',
|
||||
description: '擅长写作润色',
|
||||
regularPhrases: [],
|
||||
@@ -89,10 +89,7 @@ describe('assistantTransfer', () => {
|
||||
// modelId is intentionally not part of the DTO — the backend fills it from
|
||||
// the `chat.default_model_id` preference during create.
|
||||
expect(draft.dto.modelId).toBeUndefined()
|
||||
expect(draft.tags).toEqual([
|
||||
{ name: '写作', color: null },
|
||||
{ name: '生产力', color: null }
|
||||
])
|
||||
expect(draft.tags).toEqual([{ name: '写作', color: null }])
|
||||
})
|
||||
|
||||
it('ignores v2-only fields from imported content and still uses legacy defaults', () => {
|
||||
|
||||
@@ -55,6 +55,8 @@ function normalizeRecord(record: unknown): ImportedAssistantDraft {
|
||||
throw new AssistantTransferError('invalid_format')
|
||||
}
|
||||
|
||||
const tagName = readStringArray(record.group)[0]
|
||||
|
||||
// `modelId` is intentionally omitted — backend fills it from
|
||||
// `chat.default_model_id` preference. See AssistantService.resolveCreateModelId.
|
||||
return {
|
||||
@@ -65,18 +67,17 @@ function normalizeRecord(record: unknown): ImportedAssistantDraft {
|
||||
description: readString(record.description),
|
||||
settings: DEFAULT_ASSISTANT_SETTINGS
|
||||
},
|
||||
tags: readStringArray(record.group).map((tagName) => ({
|
||||
name: tagName,
|
||||
color: null
|
||||
}))
|
||||
tags: tagName ? [{ name: tagName, color: null }] : []
|
||||
}
|
||||
}
|
||||
|
||||
function buildExportRecord(assistant: Assistant): AssistantExportRecord {
|
||||
const tagName = assistant.tags[0]?.name
|
||||
|
||||
return {
|
||||
name: assistant.name,
|
||||
emoji: assistant.emoji,
|
||||
group: assistant.tags.map((tag) => tag.name),
|
||||
group: tagName ? [tagName] : [],
|
||||
prompt: assistant.prompt,
|
||||
description: assistant.description,
|
||||
regularPhrases: [],
|
||||
|
||||
@@ -122,16 +122,92 @@ describe('useAutoPullOnApiKeyChange', () => {
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not fire when no models exist locally yet', () => {
|
||||
it('fires on first render + key change for API-key providers with no models (auto-sync is disabled)', () => {
|
||||
const onTrigger = vi.fn()
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
|
||||
useModelsMock.mockReturnValue({ models: [] })
|
||||
|
||||
const { rerender } = renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
|
||||
|
||||
// First render with keys + no models: opens pull reconcile.
|
||||
expect(onTrigger).toHaveBeenCalledTimes(1)
|
||||
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-two'))
|
||||
rerender()
|
||||
|
||||
// Key change with no models: opens pull reconcile again.
|
||||
expect(onTrigger).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('does not fire for non-key providers when no models exist (auto-sync handles bootstrap)', () => {
|
||||
const onTrigger = vi.fn()
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
|
||||
useModelsMock.mockReturnValue({ models: [] })
|
||||
useProviderMock.mockReturnValue(providerWithHost('http://localhost:11434', 'ollama'))
|
||||
|
||||
const { rerender } = renderHook(() => useAutoPullOnApiKeyChange('ollama', onTrigger))
|
||||
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-two'))
|
||||
rerender()
|
||||
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('fires on first render for API-key providers with no models and enabled keys', () => {
|
||||
const onTrigger = vi.fn()
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
|
||||
useModelsMock.mockReturnValue({ models: [] })
|
||||
|
||||
renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
|
||||
|
||||
expect(onTrigger).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('does not fire on first render for API-key providers that already have models', () => {
|
||||
const onTrigger = vi.fn()
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
|
||||
useModelsMock.mockReturnValue({ models: [{ id: 'openai::gpt-4o' }] })
|
||||
|
||||
renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
|
||||
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not fire on first render for API-key providers without enabled keys', () => {
|
||||
const onTrigger = vi.fn()
|
||||
useProviderApiKeysMock.mockReturnValue(emptyApiKeys())
|
||||
useModelsMock.mockReturnValue({ models: [] })
|
||||
|
||||
renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
|
||||
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not fire on first render for non-key providers with no models (auto-sync handles it)', () => {
|
||||
const onTrigger = vi.fn()
|
||||
useProviderApiKeysMock.mockReturnValue(emptyApiKeys())
|
||||
useModelsMock.mockReturnValue({ models: [] })
|
||||
useProviderMock.mockReturnValue(providerWithHost('http://localhost:11434', 'ollama'))
|
||||
|
||||
renderHook(() => useAutoPullOnApiKeyChange('ollama', onTrigger))
|
||||
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('waits for models to finish loading before deciding first-render pull reconcile', () => {
|
||||
const onTrigger = vi.fn()
|
||||
// api-keys resolve first, models are still loading.
|
||||
useProviderApiKeysMock.mockReturnValue(apiKeys('sk-one'))
|
||||
useModelsMock.mockReturnValue({ models: [], isLoading: true })
|
||||
|
||||
const { rerender } = renderHook(() => useAutoPullOnApiKeyChange('openai', onTrigger))
|
||||
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
|
||||
// models resolve later — the provider already has local models.
|
||||
useModelsMock.mockReturnValue({ models: [{ id: 'openai::gpt-4o' }], isLoading: false })
|
||||
rerender()
|
||||
|
||||
expect(onTrigger).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
|
||||
@@ -7,16 +7,16 @@ import { providerNeedsApiKeyForModelSync } from './providerModelSyncRequirements
|
||||
|
||||
/**
|
||||
* Fires `onTrigger` once whenever the provider's enabled API-key fingerprint OR
|
||||
* its host (endpoint/baseUrl/authType) changes — but only after the first render
|
||||
* and only when local models already exist (first-time bootstrap is owned by
|
||||
* `useProviderAutoModelSync`). A pull still requires at least one enabled key
|
||||
* for providers whose model sync needs API-key auth, so disabling the only key
|
||||
* never fires for those providers.
|
||||
* its host (endpoint/baseUrl/authType) changes. For API-key providers this also
|
||||
* fires on first render when no local models exist (first-time bootstrap uses
|
||||
* the pull-reconcile sidebar instead of direct auto-sync). A pull still requires
|
||||
* at least one enabled key for providers whose model sync needs API-key auth,
|
||||
* so disabling the only key never fires for those providers.
|
||||
*/
|
||||
export function useAutoPullOnApiKeyChange(providerId: string, onTrigger: () => void | Promise<void>) {
|
||||
const { provider } = useProvider(providerId)
|
||||
const { data: apiKeysData } = useProviderApiKeys(providerId)
|
||||
const { models } = useModels({ providerId })
|
||||
const { models, isLoading } = useModels({ providerId })
|
||||
|
||||
const enabledKeySignature = useMemo(
|
||||
() =>
|
||||
@@ -47,12 +47,12 @@ export function useAutoPullOnApiKeyChange(providerId: string, onTrigger: () => v
|
||||
}, [onTrigger])
|
||||
|
||||
useEffect(() => {
|
||||
// Until provider/api-keys resolve the signature is a cold-cache placeholder;
|
||||
// recording that as the baseline would make the later undefined→loaded
|
||||
// transition look like a user-initiated change and auto-fire the pull.
|
||||
if (!provider || apiKeysData === undefined) return
|
||||
if (!provider || apiKeysData === undefined || isLoading) return
|
||||
if (lastSignatureRef.current === null) {
|
||||
lastSignatureRef.current = changeSignature
|
||||
if (models.length === 0 && requiresApiKeyForModelSync && enabledKeySignature) {
|
||||
void onTriggerRef.current()
|
||||
}
|
||||
return
|
||||
}
|
||||
if (lastSignatureRef.current === changeSignature) {
|
||||
@@ -61,7 +61,15 @@ export function useAutoPullOnApiKeyChange(providerId: string, onTrigger: () => v
|
||||
lastSignatureRef.current = changeSignature
|
||||
// Key-required providers still need an enabled key; disabling the only key must not fire.
|
||||
if (requiresApiKeyForModelSync && !enabledKeySignature) return
|
||||
if (models.length === 0) return
|
||||
if (models.length === 0 && !requiresApiKeyForModelSync) return
|
||||
void onTriggerRef.current()
|
||||
}, [apiKeysData, changeSignature, enabledKeySignature, models.length, provider, requiresApiKeyForModelSync])
|
||||
}, [
|
||||
apiKeysData,
|
||||
changeSignature,
|
||||
enabledKeySignature,
|
||||
isLoading,
|
||||
models.length,
|
||||
provider,
|
||||
requiresApiKeyForModelSync
|
||||
])
|
||||
}
|
||||
|
||||
@@ -80,47 +80,58 @@ describe('useProviderAutoModelSync', () => {
|
||||
expect(useProviderModelSyncMock).toHaveBeenCalledWith('openai', { existingModels: [] })
|
||||
})
|
||||
|
||||
it('enables a disabled provider when auto sync returns at least one model', async () => {
|
||||
syncProviderModelsMock.mockResolvedValueOnce([{ id: 'openai::gpt-4o' }])
|
||||
|
||||
it('skips auto sync for API-key providers (uses pull reconcile instead)', async () => {
|
||||
renderHook(() => useProviderAutoModelSync('openai'))
|
||||
|
||||
await waitFor(() => expect(updateProviderMock).toHaveBeenCalledWith({ isEnabled: true }))
|
||||
await waitFor(() =>
|
||||
expect(loggerInfoMock).toHaveBeenCalledWith('Skipping provider auto model sync', {
|
||||
providerId: 'openai',
|
||||
reason: 'uses_pull_reconcile'
|
||||
})
|
||||
)
|
||||
expect(syncProviderModelsMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps a disabled provider disabled when auto sync returns zero models', async () => {
|
||||
syncProviderModelsMock.mockResolvedValueOnce([])
|
||||
|
||||
renderHook(() => useProviderAutoModelSync('openai'))
|
||||
|
||||
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(1))
|
||||
await Promise.resolve()
|
||||
expect(updateProviderMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not patch an already enabled provider after successful auto sync', async () => {
|
||||
it('auto syncs for non-key providers (e.g. Ollama) when models are missing', async () => {
|
||||
syncProviderModelsMock.mockResolvedValueOnce([{ id: 'ollama::llama3.2' }])
|
||||
useProviderMock.mockReturnValue({
|
||||
provider: {
|
||||
id: 'openai',
|
||||
isEnabled: true,
|
||||
defaultChatEndpoint: 'openai_chat_completions',
|
||||
id: 'ollama',
|
||||
isEnabled: false,
|
||||
defaultChatEndpoint: 'ollama_chat',
|
||||
endpointConfigs: {
|
||||
openai_chat_completions: { baseUrl: 'https://api.openai.com/v1' }
|
||||
ollama_chat: { baseUrl: 'http://localhost:11434' }
|
||||
}
|
||||
},
|
||||
updateProvider: updateProviderMock
|
||||
})
|
||||
syncProviderModelsMock.mockResolvedValueOnce([{ id: 'openai::gpt-4o' }])
|
||||
useProviderApiKeysMock.mockReturnValue({
|
||||
data: { keys: [] }
|
||||
})
|
||||
|
||||
renderHook(() => useProviderAutoModelSync('openai'))
|
||||
renderHook(() => useProviderAutoModelSync('ollama'))
|
||||
|
||||
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(1))
|
||||
await Promise.resolve()
|
||||
expect(updateProviderMock).not.toHaveBeenCalled()
|
||||
await waitFor(() => expect(updateProviderMock).toHaveBeenCalledWith({ isEnabled: true }))
|
||||
})
|
||||
|
||||
it('syncs only once for the same initial eligible configuration', async () => {
|
||||
const { rerender } = renderHook(() => useProviderAutoModelSync('openai'))
|
||||
it('syncs only once for the same initial eligible configuration (non-key provider)', async () => {
|
||||
useProviderMock.mockReturnValue({
|
||||
provider: {
|
||||
id: 'ollama',
|
||||
isEnabled: false,
|
||||
defaultChatEndpoint: 'ollama_chat',
|
||||
endpointConfigs: {
|
||||
ollama_chat: { baseUrl: 'http://localhost:11434' }
|
||||
}
|
||||
},
|
||||
updateProvider: updateProviderMock
|
||||
})
|
||||
useProviderApiKeysMock.mockReturnValue({
|
||||
data: { keys: [] }
|
||||
})
|
||||
syncProviderModelsMock.mockResolvedValue([])
|
||||
|
||||
const { rerender } = renderHook(() => useProviderAutoModelSync('ollama'))
|
||||
|
||||
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(1))
|
||||
|
||||
@@ -154,19 +165,18 @@ describe('useProviderAutoModelSync', () => {
|
||||
expect(syncProviderModelsMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('logs auto sync failures and allows retrying when the same signature becomes eligible again', async () => {
|
||||
const syncError = new Error('sync down')
|
||||
syncProviderModelsMock.mockRejectedValueOnce(syncError).mockResolvedValueOnce([])
|
||||
|
||||
it('skips auto sync for API-key provider even after key rotation (pull reconcile handles it)', async () => {
|
||||
const { rerender } = renderHook(() => useProviderAutoModelSync('openai'))
|
||||
|
||||
// First render with keys → uses_pull_reconcile
|
||||
await waitFor(() =>
|
||||
expect(loggerErrorMock).toHaveBeenCalledWith('Provider auto model sync failed', {
|
||||
expect(loggerInfoMock).toHaveBeenCalledWith('Skipping provider auto model sync', {
|
||||
providerId: 'openai',
|
||||
error: syncError
|
||||
reason: 'uses_pull_reconcile'
|
||||
})
|
||||
)
|
||||
|
||||
// Keys removed
|
||||
useProviderApiKeysMock.mockReturnValue({
|
||||
data: { keys: [] }
|
||||
})
|
||||
@@ -179,12 +189,18 @@ describe('useProviderAutoModelSync', () => {
|
||||
})
|
||||
)
|
||||
|
||||
// Keys restored — still uses pull reconcile, no direct sync
|
||||
useProviderApiKeysMock.mockReturnValue({
|
||||
data: { keys: [{ id: 'key-1', key: 'sk-test', isEnabled: true }] }
|
||||
})
|
||||
rerender()
|
||||
|
||||
await waitFor(() => expect(syncProviderModelsMock).toHaveBeenCalledTimes(2))
|
||||
expect(updateProviderMock).not.toHaveBeenCalled()
|
||||
await waitFor(() =>
|
||||
expect(loggerInfoMock).toHaveBeenCalledWith('Skipping provider auto model sync', {
|
||||
providerId: 'openai',
|
||||
reason: 'uses_pull_reconcile'
|
||||
})
|
||||
)
|
||||
expect(syncProviderModelsMock).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -69,6 +69,13 @@ export function useProviderAutoModelSync(providerId: string) {
|
||||
} as const
|
||||
}
|
||||
|
||||
if (requiresApiKeyForModelSync) {
|
||||
return {
|
||||
shouldSync: false,
|
||||
reason: 'uses_pull_reconcile'
|
||||
} as const
|
||||
}
|
||||
|
||||
if (!initialModelSyncSignature) {
|
||||
return {
|
||||
shouldSync: false,
|
||||
|
||||
@@ -30,12 +30,15 @@ import {
|
||||
} from '@cherrystudio/ui'
|
||||
import { loggerService } from '@logger'
|
||||
import ListItem from '@renderer/components/ListItem'
|
||||
import { WorkspaceSelector } from '@renderer/components/resource'
|
||||
import Scrollbar from '@renderer/components/Scrollbar'
|
||||
import { dataApiService } from '@renderer/data/DataApiService'
|
||||
import { useQuery } from '@renderer/data/hooks/useDataApi'
|
||||
import { useChannels } from '@renderer/hooks/agent/useChannels'
|
||||
import { useCreateTask, useDeleteTask, useRunTask, useTaskLogs, useUpdateTask } from '@renderer/hooks/agent/useTasks'
|
||||
import { useConversationNavigation } from '@renderer/hooks/useConversationNavigation'
|
||||
import { useTheme } from '@renderer/hooks/useTheme'
|
||||
import { AGENT_WORKSPACE_TYPE } from '@shared/data/api/schemas/agentWorkspaces'
|
||||
import type { Trigger } from '@shared/data/api/schemas/jobs'
|
||||
import type {
|
||||
AgentEntity,
|
||||
@@ -47,8 +50,11 @@ import type {
|
||||
import {
|
||||
AlertTriangle,
|
||||
CalendarClock,
|
||||
ChevronDown,
|
||||
CircleSlash,
|
||||
Clock,
|
||||
ExternalLink,
|
||||
Folder,
|
||||
History,
|
||||
Maximize2,
|
||||
MoreHorizontal,
|
||||
@@ -324,6 +330,15 @@ const TaskDetail: FC<{
|
||||
timeoutMinutes: task.timeoutMinutes?.toString() ?? ''
|
||||
})
|
||||
const [channelIds, setChannelIds] = useState<string[]>(task.channelIds ?? [])
|
||||
const [workspaceId, setWorkspaceId] = useState<string | null>(
|
||||
task.workspace.type === AGENT_WORKSPACE_TYPE.USER ? task.workspace.workspaceId : null
|
||||
)
|
||||
const { data: workspaces } = useQuery('/agent-workspaces')
|
||||
|
||||
const isSystemWorkspace = workspaceId === null
|
||||
const workspaceLabel = isSystemWorkspace
|
||||
? t('agent.session.workspace_selector.no_project')
|
||||
: (workspaces?.find((w) => w.id === workspaceId)?.name ?? workspaceId)
|
||||
|
||||
const toggleStatusLabel =
|
||||
task.status === 'active' ? t('agent.cherryClaw.tasks.pause') : t('agent.cherryClaw.tasks.resume')
|
||||
@@ -337,6 +352,7 @@ const TaskDetail: FC<{
|
||||
const next = triggerToFormState(task.trigger)
|
||||
setSchedule({ ...next, timeoutMinutes: task.timeoutMinutes?.toString() ?? '' })
|
||||
setChannelIds(task.channelIds ?? [])
|
||||
setWorkspaceId(task.workspace.type === AGENT_WORKSPACE_TYPE.USER ? task.workspace.workspaceId : null)
|
||||
}, [task])
|
||||
|
||||
const saveField = useCallback(
|
||||
@@ -528,6 +544,36 @@ const TaskDetail: FC<{
|
||||
}}
|
||||
disabled={isCompleted}
|
||||
/>
|
||||
|
||||
{/* Workspace is a secondary detail — scheduled tasks default to "No work directory". */}
|
||||
<div className="flex items-center gap-1.5 text-foreground-muted text-xs">
|
||||
<span>{t('agent.session.display.workdir')}</span>
|
||||
<WorkspaceSelector
|
||||
value={workspaceId}
|
||||
onChange={(nextWorkspaceId) => {
|
||||
setWorkspaceId(nextWorkspaceId)
|
||||
saveField({
|
||||
workspace:
|
||||
nextWorkspaceId === null
|
||||
? { type: AGENT_WORKSPACE_TYPE.SYSTEM }
|
||||
: { type: AGENT_WORKSPACE_TYPE.USER, workspaceId: nextWorkspaceId }
|
||||
})
|
||||
}}
|
||||
disabled={isCompleted}
|
||||
align="start"
|
||||
trigger={
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 gap-1 px-1.5 text-foreground-muted"
|
||||
disabled={isCompleted}>
|
||||
{isSystemWorkspace ? <CircleSlash className="size-3.5" /> : <Folder className="size-3.5" />}
|
||||
<span className="max-w-40 truncate">{workspaceLabel}</span>
|
||||
<ChevronDown className="size-3.5" />
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</SettingGroup>
|
||||
|
||||
@@ -783,14 +829,20 @@ const CreateForm: FC<{
|
||||
const [promptModalOpen, setPromptModalOpen] = useState(false)
|
||||
const [schedule, setSchedule] = useState<ScheduleFormState>({ kind: 'interval', value: '', timeoutMinutes: '' })
|
||||
const [channelIds, setChannelIds] = useState<string[]>([])
|
||||
// TODO(agent-workspace-picker): wire the workspace picker before re-enabling task creation.
|
||||
const [workspaceSource] = useState<CreateTaskRequest['workspace'] | null>(null)
|
||||
// `null` = "No work directory" (system workspace); a string binds the task to that user workspace.
|
||||
const [workspaceId, setWorkspaceId] = useState<string | null>(null)
|
||||
const { data: workspaces } = useQuery('/agent-workspaces')
|
||||
const [saving, setSaving] = useState(false)
|
||||
|
||||
const isValid = agentId && name.trim() && prompt.trim() && schedule.value.trim() && workspaceSource
|
||||
const isSystemWorkspace = workspaceId === null
|
||||
const workspaceLabel = isSystemWorkspace
|
||||
? t('agent.session.workspace_selector.no_project')
|
||||
: (workspaces?.find((w) => w.id === workspaceId)?.name ?? workspaceId)
|
||||
|
||||
const isValid = agentId && name.trim() && prompt.trim() && schedule.value.trim()
|
||||
|
||||
const handleCreate = useCallback(async () => {
|
||||
if (!agentId || !name.trim() || !prompt.trim() || !schedule.value.trim() || !workspaceSource) return
|
||||
if (!agentId || !name.trim() || !prompt.trim() || !schedule.value.trim()) return
|
||||
const trigger = formStateToTrigger(schedule.kind, schedule.value.trim())
|
||||
if (!trigger) return
|
||||
setSaving(true)
|
||||
@@ -800,14 +852,17 @@ const CreateForm: FC<{
|
||||
name: name.trim(),
|
||||
prompt: prompt.trim(),
|
||||
trigger,
|
||||
workspace: workspaceSource,
|
||||
workspace:
|
||||
workspaceId === null
|
||||
? { type: AGENT_WORKSPACE_TYPE.SYSTEM }
|
||||
: { type: AGENT_WORKSPACE_TYPE.USER, workspaceId },
|
||||
timeoutMinutes: timeout && timeout > 0 ? timeout : undefined,
|
||||
channelIds: channelIds.length > 0 ? channelIds : undefined
|
||||
})
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
}, [agentId, name, prompt, schedule, workspaceSource, channelIds, onCreate])
|
||||
}, [agentId, name, prompt, schedule, workspaceId, channelIds, onCreate])
|
||||
|
||||
return (
|
||||
<SettingsContentColumn theme={theme}>
|
||||
@@ -880,6 +935,23 @@ const CreateForm: FC<{
|
||||
<TaskScheduleControls value={schedule} onChange={setSchedule} />
|
||||
<TaskChannelSelector channels={channels} channelIds={channelIds} onChange={setChannelIds} />
|
||||
|
||||
{/* Workspace is a secondary detail — scheduled tasks default to "No work directory". */}
|
||||
<div className="flex items-center gap-1.5 text-foreground-muted text-xs">
|
||||
<span>{t('agent.session.display.workdir')}</span>
|
||||
<WorkspaceSelector
|
||||
value={workspaceId}
|
||||
onChange={setWorkspaceId}
|
||||
align="start"
|
||||
trigger={
|
||||
<Button variant="ghost" size="sm" className="h-6 gap-1 px-1.5 text-foreground-muted">
|
||||
{isSystemWorkspace ? <CircleSlash className="size-3.5" /> : <Folder className="size-3.5" />}
|
||||
<span className="max-w-40 truncate">{workspaceLabel}</span>
|
||||
<ChevronDown className="size-3.5" />
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={onCancel}>
|
||||
{t('agent.cherryClaw.tasks.cancel')}
|
||||
|
||||
@@ -71,6 +71,14 @@ vi.mock('@renderer/hooks/agent/useChannels', () => ({
|
||||
useChannels: () => ({ channels: [] })
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/data/hooks/useDataApi', () => ({
|
||||
useQuery: () => ({ data: [] })
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/components/resource', () => ({
|
||||
WorkspaceSelector: ({ trigger }: { trigger: React.ReactNode }) => <>{trigger}</>
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/agent/useTasks', () => ({
|
||||
useCreateTask: () => ({ createTask: taskMutationMocks.createTask }),
|
||||
useDeleteTask: () => ({ deleteTask: taskMutationMocks.deleteTask }),
|
||||
|
||||
@@ -6,6 +6,8 @@ import {
|
||||
KB_SEARCH_TOOL_NAME,
|
||||
kbListInputSchema,
|
||||
kbListStrictInputSchema,
|
||||
kbManageInputSchema,
|
||||
kbManageStrictInputSchema,
|
||||
kbSearchInputSchema,
|
||||
REPORT_ARTIFACTS_DESCRIPTION,
|
||||
REPORT_ARTIFACTS_TOOL_NAME,
|
||||
@@ -62,6 +64,41 @@ describe('builtin tool contracts', () => {
|
||||
expect(kbListInputSchema.safeParse({ query: 'recipes' }).success).toBe(true)
|
||||
})
|
||||
|
||||
it('keeps kb_manage strict-path fields in `required` so strict providers accept the schema', () => {
|
||||
// Same regression as kb_list above: the AI-SDK path (KnowledgeManageTool) runs strict:true, so
|
||||
// an all-optional object would serialize `required` away to nothing and a strict OpenAI-compatible
|
||||
// provider would reject the whole request. The strict variant makes every optional field
|
||||
// `.nullable()` (null = unused for this action/type) so they all stay in `required`.
|
||||
const json = z.toJSONSchema(kbManageStrictInputSchema) as { required?: unknown }
|
||||
|
||||
expect(Array.isArray(json.required)).toBe(true)
|
||||
expect(json.required).toEqual(
|
||||
expect.arrayContaining(['baseId', 'action', 'type', 'path', 'url', 'content', 'title', 'conceptIds'])
|
||||
)
|
||||
// null is the "unused" signal for every optional field; an explicit all-null payload must still parse.
|
||||
expect(
|
||||
kbManageStrictInputSchema.safeParse({
|
||||
baseId: 'kb-1',
|
||||
action: 'delete',
|
||||
type: null,
|
||||
path: null,
|
||||
url: null,
|
||||
content: null,
|
||||
title: null,
|
||||
conceptIds: null
|
||||
}).success
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('lets the MCP kb_manage path omit unused fields', () => {
|
||||
// The Claude Code bridge parses raw args with kbManageInputSchema; an agent may omit every
|
||||
// field but `baseId`/`action`, so the optional shape must accept that without erroring.
|
||||
expect(kbManageInputSchema.safeParse({ baseId: 'kb-1', action: 'delete' }).success).toBe(true)
|
||||
expect(kbManageInputSchema.safeParse({ baseId: 'kb-1', action: 'add', type: 'note', content: 'hi' }).success).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('validates final report artifacts', () => {
|
||||
const result = reportArtifactsInputSchema.parse({
|
||||
artifacts: [{ path: 'dist/report.pdf', description: 'Final report' }],
|
||||
|
||||
@@ -279,6 +279,11 @@ export const KB_MANAGE_ADD_TYPES = ['file', 'url', 'note'] as const
|
||||
// One flat object, not a discriminated union: which fields apply depends on `action`
|
||||
// (and, for add, on `type`). The core validates the combination and returns a steer
|
||||
// string on a missing field, so the model gets a clear error rather than a schema reject.
|
||||
//
|
||||
// kb_manage is consumed by two paths with conflicting schema needs, same split as kb_list.
|
||||
//
|
||||
// MCP / Claude Code bridge (cherryBuiltinTools): the agent parses raw args with this schema and may
|
||||
// omit any field, so they are `.optional()`.
|
||||
export const kbManageInputSchema = z.object({
|
||||
baseId: z.string().trim().min(1).describe('ID of the knowledge base to modify — a base id from kb_list.'),
|
||||
action: z
|
||||
@@ -319,6 +324,59 @@ export const kbManageInputSchema = z.object({
|
||||
)
|
||||
})
|
||||
|
||||
// AI-SDK path (KnowledgeManageTool) runs with `strict: true` — same `.nullable()` treatment as
|
||||
// `kbListStrictInputSchema` and for the same reason (an all-optional shape serializes `required`
|
||||
// away to nothing, which a strict OpenAI-compatible provider rejects). `manageKnowledge` treats
|
||||
// null (like undefined) as "field not set for this action/type".
|
||||
export const kbManageStrictInputSchema = z.object({
|
||||
baseId: z.string().trim().min(1).describe('ID of the knowledge base to modify — a base id from kb_list.'),
|
||||
action: z
|
||||
.enum(KB_MANAGE_ACTIONS)
|
||||
.describe(
|
||||
'add: import a new source (set `type` + its field). delete: remove documents by `conceptIds`. ' +
|
||||
'refresh: re-index documents by `conceptIds`. All actions modify the base and require user approval.'
|
||||
),
|
||||
type: z
|
||||
.enum(KB_MANAGE_ADD_TYPES)
|
||||
.nullable()
|
||||
.describe(
|
||||
'For action="add" only: the source kind — "file" (set `path`), "url" (set `url`), or "note" (set `content`). ' +
|
||||
'Pass null otherwise.'
|
||||
),
|
||||
path: z
|
||||
.string()
|
||||
.trim()
|
||||
.min(1)
|
||||
.nullable()
|
||||
.describe('For action="add", type="file": absolute local filesystem path of the file to import. Else null.'),
|
||||
url: z
|
||||
.string()
|
||||
.trim()
|
||||
.min(1)
|
||||
.nullable()
|
||||
.describe('For action="add", type="url": the URL to fetch and index. Else null.'),
|
||||
content: z
|
||||
.string()
|
||||
.min(1)
|
||||
.nullable()
|
||||
.describe('For action="add", type="note": the plain-text note content to index. Else null.'),
|
||||
title: z
|
||||
.string()
|
||||
.trim()
|
||||
.min(1)
|
||||
.nullable()
|
||||
.describe(
|
||||
'For action="add", type="note": optional display title (defaults to the note\'s first line). Pass null to omit.'
|
||||
),
|
||||
conceptIds: z
|
||||
.array(z.string().trim().min(1))
|
||||
.nullable()
|
||||
.describe(
|
||||
'For action="delete"/"refresh": Concept IDs (the `conceptId` field of a kb_search hit or a kb_list result) ' +
|
||||
'to operate on. Else null.'
|
||||
)
|
||||
})
|
||||
|
||||
export const kbManageOutputSchema = z.object({
|
||||
action: z.enum(KB_MANAGE_ACTIONS),
|
||||
// add: the source identifiers that were imported (one per add call).
|
||||
@@ -330,7 +388,6 @@ export const kbManageOutputSchema = z.object({
|
||||
notFound: z.array(z.string()).optional()
|
||||
})
|
||||
|
||||
export type KbManageInput = z.infer<typeof kbManageInputSchema>
|
||||
export type KbManageOutput = z.infer<typeof kbManageOutputSchema>
|
||||
|
||||
// ── web_search ───────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user