diff --git a/.gitignore b/.gitignore index f2b1c4c7f8..a3664d9e7b 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ out mcp_server stats.html .eslintcache +resources/binaries/ # ENV .env diff --git a/package.json b/package.json index 525bbec452..08a79b7122 100644 --- a/package.json +++ b/package.json @@ -72,7 +72,6 @@ "bench:renderer": "vitest bench --run --project renderer", "bench:aicore": "vitest bench --run --project aiCore", "bench:shared": "vitest bench --run --project shared", - "postinstall": "tsx scripts/patch-claude-agent-sdk.ts", "prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && prek install", "claude": "dotenv -e .env -- claude", "db:migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts", diff --git a/scripts/__tests__/patch-claude-agent-sdk.test.ts b/scripts/__tests__/patch-claude-agent-sdk.test.ts deleted file mode 100644 index 47a3f2e313..0000000000 --- a/scripts/__tests__/patch-claude-agent-sdk.test.ts +++ /dev/null @@ -1,488 +0,0 @@ -/** - * Unit tests for the patch-claude-agent-sdk.ts postinstall script. - * - * The patch functions are exported from the script and imported directly, - * so tests always exercise the real implementation. - */ - -import { describe, expect, it } from 'vitest' - -import { - applyAllPatches, - patchRemoveCommand as applyPatch2, - patchSpawnCall as applyPatch3, - patchSpawnImport as applyPatch1 -} from '../patch-claude-agent-sdk' - -// --------------------------------------------------------------------------- -// Shared fixture helpers -// --------------------------------------------------------------------------- - -/** - * Build a minimal realistic snippet of minified SDK code with configurable - * variable names so we can exercise different obfuscation scenarios. - */ -function buildSdkSnippet({ - spawnAlias = 'Sq', - fnArg = 'Q', - cmdVar = 'X', - argsVar = 'Y', - cwdVar = '$', - envVar = 'W', - sigVar = 'J', - stderrVar = 'G', - extraBefore = '', - extraAfter = '' -}: { - spawnAlias?: string - fnArg?: string - cmdVar?: string - argsVar?: string - cwdVar?: string - envVar?: string - sigVar?: string - stderrVar?: string - extraBefore?: string - extraAfter?: string -} = {}): string { - return [ - extraBefore, - `import{spawn as ${spawnAlias}}from"child_process"`, - `spawnLocalProcess(${fnArg}){let{command:${cmdVar},args:${argsVar},cwd:${cwdVar},env:${envVar},signal:${sigVar}}=${fnArg}`, - `=${spawnAlias}(${cmdVar},${argsVar},{cwd:${cwdVar},stdio:["pipe","pipe",${stderrVar}],signal:${sigVar},env:${envVar},windowsHide:!0}`, - extraAfter - ] - .filter(Boolean) - .join('\n') -} - -// --------------------------------------------------------------------------- -// Patch 1 – spawn → fork import -// --------------------------------------------------------------------------- - -describe('Patch 1: spawn → fork import replacement', () => { - it('replaces spawn with fork keeping the same alias', () => { - const input = `import{spawn as Sq}from"child_process"` - const { result, matched } = applyPatch1(input) - - expect(matched).toBe(true) - expect(result).toBe(`import{fork as Sq}from"child_process"`) - }) - - it('preserves a single-letter alias', () => { - const input = `import{spawn as X}from"child_process"` - const { result, matched } = applyPatch1(input) - - expect(matched).toBe(true) - expect(result).toBe(`import{fork as X}from"child_process"`) - }) - - it('preserves an underscore-prefixed alias', () => { - const input = `import{spawn as _spawn}from"child_process"` - const { result, matched } = applyPatch1(input) - - expect(matched).toBe(true) - expect(result).toBe(`import{fork as _spawn}from"child_process"`) - }) - - it('does not match when spawn is already replaced by fork', () => { - const input = `import{fork as Sq}from"child_process"` - const { matched } = applyPatch1(input) - - expect(matched).toBe(false) - }) - - it('does not match unrelated child_process imports', () => { - const input = `import{exec as Sq}from"child_process"` - const { matched } = applyPatch1(input) - - expect(matched).toBe(false) - }) - - it('does not match when double-quotes are replaced by single-quotes', () => { - const input = `import{spawn as Sq}from'child_process'` - const { matched } = applyPatch1(input) - - expect(matched).toBe(false) - }) - - it('only replaces the first occurrence (non-global regex)', () => { - const input = [`import{spawn as Sq}from"child_process"`, `import{spawn as Ab}from"child_process"`].join('\n') - const { result } = applyPatch1(input) - - // Only the first line should be changed - expect(result).toContain(`import{fork as Sq}from"child_process"`) - expect(result).toContain(`import{spawn as Ab}from"child_process"`) - }) -}) - -// --------------------------------------------------------------------------- -// Patch 2 – remove command: from destructuring -// --------------------------------------------------------------------------- - -describe('Patch 2: remove command variable from spawnLocalProcess destructuring', () => { - it('removes the command:VAR, segment with standard variable names', () => { - const input = `spawnLocalProcess(Q){let{command:X,args:Y,cwd:$,env:W,signal:J}=Q` - const { result, matched } = applyPatch2(input) - - expect(matched).toBe(true) - expect(result).toBe(`spawnLocalProcess(Q){let{args:Y,cwd:$,env:W,signal:J}=Q`) - expect(result).not.toContain('command:') - }) - - it('works when the function argument uses a dollar-sign variable', () => { - const input = `spawnLocalProcess($){let{command:X,args:Y` - const { result, matched } = applyPatch2(input) - - expect(matched).toBe(true) - expect(result).toContain(`spawnLocalProcess($){let{args:Y`) - }) - - it('works with single-character obfuscated names throughout', () => { - const input = `spawnLocalProcess(a){let{command:b,args:c` - const { result, matched } = applyPatch2(input) - - expect(matched).toBe(true) - expect(result).toContain(`spawnLocalProcess(a){let{args:c`) - }) - - it('works when the args variable uses a dollar-sign', () => { - const input = `spawnLocalProcess(Q){let{command:X,args:$` - const { result, matched } = applyPatch2(input) - - expect(matched).toBe(true) - expect(result).toContain(`let{args:$`) - }) - - it('does not match when command is already absent', () => { - const input = `spawnLocalProcess(Q){let{args:Y` - const { matched } = applyPatch2(input) - - expect(matched).toBe(false) - }) - - it('does not match unrelated destructuring patterns', () => { - const input = `someOtherFunction(Q){let{command:X,args:Y` - const { matched } = applyPatch2(input) - - expect(matched).toBe(false) - }) -}) - -// --------------------------------------------------------------------------- -// Patch 3 – rewrite spawn call to fork with IPC stdio -// --------------------------------------------------------------------------- - -describe('Patch 3: rewrite spawn call to use fork with IPC stdio', () => { - it('rewrites spawn call with standard variable names', () => { - const input = `=Sq(X,Y,{cwd:$,stdio:["pipe","pipe",G],signal:J,env:W,windowsHide:!0}` - const { result, matched } = applyPatch3(input) - - expect(matched).toBe(true) - expect(result).toBe( - `=Sq(Y[0],Y.slice(1),{cwd:$,stdio:G==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"],signal:J,env:W}` - ) - }) - - it('removes windowsHide:!0 from the output', () => { - const input = `=Sq(X,Y,{cwd:$,stdio:["pipe","pipe",G],signal:J,env:W,windowsHide:!0}` - const { result } = applyPatch3(input) - - expect(result).not.toContain('windowsHide') - }) - - it('uses args[0] as the module path for fork', () => { - const input = `=fn(cmd,args,{cwd:c,stdio:["pipe","pipe",s],signal:sig,env:e,windowsHide:!0}` - const { result, matched } = applyPatch3(input) - - expect(matched).toBe(true) - expect(result).toContain('args[0]') - expect(result).toContain('args.slice(1)') - }) - - it('produces conditional IPC stdio based on stderr variable', () => { - const input = `=fn(cmd,args,{cwd:c,stdio:["pipe","pipe",s],signal:sig,env:e,windowsHide:!0}` - const { result } = applyPatch3(input) - - expect(result).toContain(`s==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"]`) - }) - - it('works with dollar-sign variables', () => { - const input = `=$($,$$,{cwd:$$$,stdio:["pipe","pipe",$v],signal:$s,env:$e,windowsHide:!0}` - const { result, matched } = applyPatch3(input) - - expect(matched).toBe(true) - expect(result).toContain('$$[0]') - expect(result).toContain('$$.slice(1)') - }) - - it('does not match when windowsHide:!0 is absent (already patched)', () => { - const input = `=Sq(Y[0],Y.slice(1),{cwd:$,stdio:G==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"],signal:J,env:W}` - const { matched } = applyPatch3(input) - - expect(matched).toBe(false) - }) - - it('does not match patterns without the windowsHide flag', () => { - const input = `=Sq(X,Y,{cwd:$,stdio:["pipe","pipe",G],signal:J,env:W}` - const { matched } = applyPatch3(input) - - expect(matched).toBe(false) - }) -}) - -// --------------------------------------------------------------------------- -// Integration: all three patches applied together -// --------------------------------------------------------------------------- - -describe('applyAllPatches: full end-to-end patch application', () => { - it('applies all 3 patches to a canonical minified snippet', () => { - const input = buildSdkSnippet() - const { result, patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(3) - expect(result).toContain('import{fork as Sq}from"child_process"') - expect(result).not.toContain('command:X') - expect(result).toContain('Y[0]') - expect(result).toContain('Y.slice(1)') - expect(result).toContain('"ipc"') - expect(result).not.toContain('windowsHide') - }) - - it('applies all 3 patches when variables use uncommon names (Sq, Ab, $)', () => { - const input = buildSdkSnippet({ - spawnAlias: 'Ab', - fnArg: 'P', - cmdVar: 'c', - argsVar: 'a', - cwdVar: 'd', - envVar: 'e', - sigVar: 's', - stderrVar: '$' - }) - const { result, patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(3) - expect(result).toContain('import{fork as Ab}from"child_process"') - expect(result).toContain('a[0]') - expect(result).toContain('a.slice(1)') - }) - - it('applies all 3 patches with numeric-suffix alias (e.g. Fn2)', () => { - const input = buildSdkSnippet({ - spawnAlias: 'Fn2', - fnArg: 'r', - cmdVar: 'c', - argsVar: 'a', - cwdVar: 'w', - envVar: 'e', - sigVar: 's', - stderrVar: 'x' - }) - const { result, patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(3) - expect(result).toContain('import{fork as Fn2}from"child_process"') - }) - - it('applies all 3 patches when spawn alias uses a dollar-sign', () => { - const input = buildSdkSnippet({ spawnAlias: '$p' }) - const { patchCount, result } = applyAllPatches(input) - - expect(patchCount).toBe(3) - expect(result).toContain('import{fork as $p}from"child_process"') - }) - - it('applies patches correctly when surrounded by other minified code', () => { - const input = buildSdkSnippet({ - extraBefore: 'var a=1;function b(){return c}', - extraAfter: ';var z=42;' - }) - const { result, patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(3) - // Surrounding code must be preserved - expect(result).toContain('var a=1;function b(){return c}') - expect(result).toContain(';var z=42;') - }) -}) - -// --------------------------------------------------------------------------- -// Idempotency – running on already-patched content -// --------------------------------------------------------------------------- - -describe('Idempotency: re-running on already-patched content', () => { - it('detects already-patched content and returns patchCount=0 + alreadyPatched=true', () => { - const original = buildSdkSnippet() - const { result: patched } = applyAllPatches(original) - - // Second pass - const { patchCount, alreadyPatched } = applyAllPatches(patched) - - expect(patchCount).toBe(0) - expect(alreadyPatched).toBe(true) - }) - - it('does not double-apply patch 1 (fork import stays as fork)', () => { - const original = buildSdkSnippet({ spawnAlias: 'Fn' }) - const { result: firstPass } = applyAllPatches(original) - const { result: secondPass } = applyAllPatches(firstPass) - - // fork should not be turned into something else - expect(secondPass).toContain('import{fork as Fn}from"child_process"') - expect(secondPass).not.toContain('import{spawn as Fn}from"child_process"') - }) - - it('does not re-apply patch 2 (command stays absent)', () => { - const original = buildSdkSnippet() - const { result: firstPass } = applyAllPatches(original) - const { result: secondPass } = applyAllPatches(firstPass) - - expect(secondPass).not.toContain('command:') - }) - - it('does not re-apply patch 3 (windowsHide stays absent, IPC stays present)', () => { - const original = buildSdkSnippet() - const { result: firstPass } = applyAllPatches(original) - const { result: secondPass } = applyAllPatches(firstPass) - - expect(secondPass).not.toContain('windowsHide') - expect(secondPass).toContain('"ipc"') - }) -}) - -// --------------------------------------------------------------------------- -// Partial matches – only some patterns match -// --------------------------------------------------------------------------- - -describe('Partial matches: only subset of patterns match', () => { - it('returns patchCount=1 when only patch 1 matches', () => { - const input = `import{spawn as Sq}from"child_process"\n// no spawnLocalProcess here` - const { patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(1) - }) - - it('returns patchCount=2 when patches 1 and 2 match but not patch 3', () => { - const input = [ - `import{spawn as Sq}from"child_process"`, - `spawnLocalProcess(Q){let{command:X,args:Y`, - `// no spawn call with windowsHide` - ].join('\n') - const { patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(2) - }) - - it('returns patchCount=2 when patches 1 and 3 match but not patch 2', () => { - const input = [ - `import{spawn as Sq}from"child_process"`, - `// no spawnLocalProcess destructuring`, - `=Sq(X,Y,{cwd:$,stdio:["pipe","pipe",G],signal:J,env:W,windowsHide:!0}` - ].join('\n') - const { patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(2) - }) - - it('is NOT flagged as alreadyPatched when patchCount=0 but fork/ipc are absent', () => { - const input = `completely unrelated content without fork or ipc` - const { patchCount, alreadyPatched } = applyAllPatches(input) - - expect(patchCount).toBe(0) - expect(alreadyPatched).toBe(false) - }) -}) - -// --------------------------------------------------------------------------- -// No match – completely unrelated content -// --------------------------------------------------------------------------- - -describe('No match: unrelated content produces no patches', () => { - it('returns patchCount=0 for completely unrelated content', () => { - const input = `console.log("hello world");var x=42;` - const { patchCount, alreadyPatched } = applyAllPatches(input) - - expect(patchCount).toBe(0) - expect(alreadyPatched).toBe(false) - }) - - it('returns patchCount=0 for empty string', () => { - const { patchCount } = applyAllPatches('') - - expect(patchCount).toBe(0) - }) - - it('does not match import with single-quotes instead of double-quotes', () => { - const input = `import{spawn as Sq}from'child_process'` - const { patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(0) - }) - - it('does not match import with spaces around braces', () => { - const input = `import { spawn as Sq } from "child_process"` - const { patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(0) - }) - - it('does not produce a false alreadyPatched for content with only fork (no ipc)', () => { - const input = `import{fork as Sq}from"child_process"` - const { patchCount, alreadyPatched } = applyAllPatches(input) - - expect(patchCount).toBe(0) - // alreadyPatched requires BOTH 'import{fork as' AND '"ipc"' - expect(alreadyPatched).toBe(false) - }) - - it('does not produce a false alreadyPatched for content with only ipc (no fork import)', () => { - const input = `stdio:["pipe","pipe","pipe","ipc"]` - const { patchCount, alreadyPatched } = applyAllPatches(input) - - expect(patchCount).toBe(0) - expect(alreadyPatched).toBe(false) - }) -}) - -// --------------------------------------------------------------------------- -// Output correctness – verify exact transformed strings -// --------------------------------------------------------------------------- - -describe('Output correctness: verify exact replacement strings', () => { - it('patch 1 exact output matches expected string', () => { - const { result } = applyPatch1(`import{spawn as myAlias}from"child_process"`) - expect(result).toBe(`import{fork as myAlias}from"child_process"`) - }) - - it('patch 2 exact output: function arg and args var are preserved correctly', () => { - const { result } = applyPatch2(`spawnLocalProcess(P){let{command:C,args:A`) - expect(result).toBe(`spawnLocalProcess(P){let{args:A`) - }) - - it('patch 3 exact output: full spawn-to-fork rewrite is correct', () => { - const { result } = applyPatch3(`=fn(cmd,args,{cwd:c,stdio:["pipe","pipe",s],signal:sig,env:e,windowsHide:!0}`) - expect(result).toBe( - `=fn(args[0],args.slice(1),{cwd:c,stdio:s==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"],signal:sig,env:e}` - ) - }) - - it('full pipeline: canonical snippet transforms to expected patched form', () => { - const input = [ - `import{spawn as Sq}from"child_process"`, - `spawnLocalProcess(Q){let{command:X,args:Y,cwd:$,env:W,signal:J}=Q`, - `=Sq(X,Y,{cwd:$,stdio:["pipe","pipe",G],signal:J,env:W,windowsHide:!0}` - ].join('\n') - - const { result, patchCount } = applyAllPatches(input) - - expect(patchCount).toBe(3) - - const lines = result.split('\n') - expect(lines[0]).toBe(`import{fork as Sq}from"child_process"`) - expect(lines[1]).toBe(`spawnLocalProcess(Q){let{args:Y,cwd:$,env:W,signal:J}=Q`) - expect(lines[2]).toBe( - `=Sq(Y[0],Y.slice(1),{cwd:$,stdio:G==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"],signal:J,env:W}` - ) - }) -}) diff --git a/scripts/before-pack.js b/scripts/before-pack.js index e3f08816a6..30e9d981f4 100644 --- a/scripts/before-pack.js +++ b/scripts/before-pack.js @@ -57,6 +57,14 @@ exports.default = async function (context) { const platformName = context.packager.platform.name const platform = platformToArch[platformName] + // Download rtk binary for the target platform + try { + console.log(`Downloading rtk binary for ${platform}-${arch}...`) + execSync(`node "${path.join(__dirname, 'download-rtk-binaries.js')}" ${platform} ${arch}`, { stdio: 'inherit' }) + } catch (error) { + console.warn(`Warning: rtk binary download failed (non-fatal): ${error.message}`) + } + const downloadPackages = async () => { // Skip if target platform and architecture match current system if (platform === process.platform && arch === process.arch) { @@ -126,9 +134,16 @@ exports.default = async function (context) { }) .map((f) => '!node_modules/@anthropic-ai/claude-agent-sdk/vendor/ripgrep/' + f + '/**') + // Exclude rtk binaries for other platform-arch combinations + const currentPlatformKey = `${platform}-${arch}` + const allRtkPlatforms = ['darwin-arm64', 'darwin-x64', 'linux-x64', 'linux-arm64', 'win32-x64'] + const excludeRtkFilters = allRtkPlatforms + .filter((p) => p !== currentPlatformKey) + .map((p) => '!resources/binaries/' + p + '/**') + if (context.arch === Arch.arm64) { - await excludePackages([...arm64ExcludePackages, ...excludeRipgrepFilters]) + await excludePackages([...arm64ExcludePackages, ...excludeRipgrepFilters, ...excludeRtkFilters]) } else { - await excludePackages([...x64ExcludePackages, ...excludeRipgrepFilters]) + await excludePackages([...x64ExcludePackages, ...excludeRipgrepFilters, ...excludeRtkFilters]) } } diff --git a/scripts/download-rtk-binaries.js b/scripts/download-rtk-binaries.js new file mode 100644 index 0000000000..f6ea93d3cb --- /dev/null +++ b/scripts/download-rtk-binaries.js @@ -0,0 +1,92 @@ +/** + * Downloads rtk binary for the target platform during build. + * Called from before-pack.js to bundle the binary into resources/binaries/. + * + * Usage: + * node scripts/download-rtk-binaries.js + * e.g. node scripts/download-rtk-binaries.js darwin arm64 + */ +const fs = require('fs') +const path = require('path') +const os = require('os') +const { execFileSync } = require('child_process') + +const RTK_VERSION = '0.30.1' + +const RTK_PACKAGES = { + 'darwin-arm64': { file: 'rtk-aarch64-apple-darwin.tar.gz', binary: 'rtk' }, + 'darwin-x64': { file: 'rtk-x86_64-apple-darwin.tar.gz', binary: 'rtk' }, + 'linux-x64': { file: 'rtk-x86_64-unknown-linux-musl.tar.gz', binary: 'rtk' }, + 'linux-arm64': { file: 'rtk-aarch64-unknown-linux-gnu.tar.gz', binary: 'rtk' }, + 'win32-x64': { file: 'rtk-x86_64-pc-windows-msvc.zip', binary: 'rtk.exe' } +} + +function downloadFile(url, destPath) { + console.log(`Downloading: ${url}`) + execFileSync('curl', ['-fSL', '--retry', '3', '-o', destPath, url], { stdio: 'inherit' }) + if (!fs.existsSync(destPath)) { + throw new Error(`Download failed: ${destPath} not found`) + } +} + +function downloadRtk(platformKey, outputDir) { + const pkg = RTK_PACKAGES[platformKey] + if (!pkg) { + console.warn(`[rtk] No binary available for ${platformKey}, skipping`) + return + } + + const url = `https://github.com/rtk-ai/rtk/releases/download/v${RTK_VERSION}/${pkg.file}` + const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'rtk-')) + const tempFile = path.join(tempDir, pkg.file) + + try { + downloadFile(url, tempFile) + + if (pkg.file.endsWith('.tar.gz')) { + execFileSync('tar', ['-xzf', tempFile, '-C', tempDir], { stdio: 'inherit' }) + } else if (pkg.file.endsWith('.zip')) { + execFileSync('unzip', ['-o', tempFile, '-d', tempDir], { stdio: 'inherit' }) + } + + // rtk archives extract the binary at the root level + const srcPath = path.join(tempDir, pkg.binary) + if (!fs.existsSync(srcPath)) { + throw new Error(`rtk binary '${pkg.binary}' not found in extracted archive`) + } + + const destPath = path.join(outputDir, pkg.binary) + fs.copyFileSync(srcPath, destPath) + if (process.platform !== 'win32') { + fs.chmodSync(destPath, 0o755) + } + console.log(`[rtk] Installed ${pkg.binary} to ${destPath}`) + } finally { + fs.rmSync(tempDir, { recursive: true, force: true }) + } +} + +function main() { + const platform = process.argv[2] || process.platform + const arch = process.argv[3] || process.arch + const platformKey = `${platform}-${arch}` + + console.log(`Downloading rtk binary for ${platformKey}...`) + + const outputDir = path.join(__dirname, '..', 'resources', 'binaries', platformKey) + fs.mkdirSync(outputDir, { recursive: true }) + + downloadRtk(platformKey, outputDir) + + // Write version file for upgrade detection at runtime + fs.writeFileSync(path.join(outputDir, '.rtk-version'), RTK_VERSION, 'utf8') + + console.log(`All binaries downloaded to ${outputDir}`) +} + +try { + main() +} catch (error) { + console.error('Failed to download binaries:', error.message) + // Non-fatal: don't block the build if binary download fails +} diff --git a/scripts/patch-claude-agent-sdk.ts b/scripts/patch-claude-agent-sdk.ts deleted file mode 100644 index a7a4439a30..0000000000 --- a/scripts/patch-claude-agent-sdk.ts +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Postinstall script to patch @anthropic-ai/claude-agent-sdk - * - * The SDK is shipped as minified/obfuscated code, so we use semantic regex - * patterns (not variable names) to apply patches. This is more resilient - * to SDK version bumps than a static .patch file. - * - * Changes: - * 1. spawn → fork (child_process import) — enables IPC channel - * 2. Remove `command` from spawnLocalProcess destructuring - * 3. Rewrite spawn call to use fork(args[0], args.slice(1), ...) with IPC stdio - */ - -import { readFileSync, writeFileSync } from 'node:fs' -import { createRequire } from 'node:module' -import path from 'node:path' - -interface PatchResult { - result: string - matched: boolean -} - -interface ApplyAllResult { - result: string - patchCount: number - alreadyPatched: boolean -} - -// 1. Replace `import{spawn as X}from"child_process"` with `import{fork as X}from"child_process"` -export function patchSpawnImport(content: string): PatchResult { - let matched = false - const result = content.replace(/import\{spawn as ([\w$]+)\}from"child_process"/, (_, alias) => { - matched = true - return `import{fork as ${alias}}from"child_process"` - }) - return { result, matched } -} - -// 2. Remove `command:X,` from spawnLocalProcess destructuring -// Before: spawnLocalProcess(Q){let{command:X,args:Y,cwd:$,env:W,signal:J}=Q -// After: spawnLocalProcess(Q){let{args:Y,cwd:$,env:W,signal:J}=Q -export function patchRemoveCommand(content: string): PatchResult { - let matched = false - const result = content.replace( - /spawnLocalProcess\(([\w$]+)\)\{let\{command:([\w$]+),args:([\w$]+)/, - (_, fnArg, _cmd, args) => { - matched = true - return `spawnLocalProcess(${fnArg}){let{args:${args}` - } - ) - return { result, matched } -} - -// 3. Rewrite the spawn/fork call: -// Before: =Sq(X,Y,{cwd:$,stdio:["pipe","pipe",G],signal:J,env:W,windowsHide:!0}) -// After: =Sq(Y[0],Y.slice(1),{cwd:$,stdio:G==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"],signal:J,env:W}) -export function patchSpawnCall(content: string): PatchResult { - let matched = false - const result = content.replace( - /([\w$]+)\(([\w$]+),([\w$]+),\{cwd:([\w$]+),stdio:\["pipe","pipe",([\w$]+)\],signal:([\w$]+),env:([\w$]+),windowsHide:!0\}/, - (_, fn, _cmd, args, cwd, stderr, signal, env) => { - matched = true - return `${fn}(${args}[0],${args}.slice(1),{cwd:${cwd},stdio:${stderr}==="pipe"?["pipe","pipe","pipe","ipc"]:["pipe","pipe","ignore","ipc"],signal:${signal},env:${env}}` - } - ) - return { result, matched } -} - -// Apply all patches and return summary -export function applyAllPatches(content: string): ApplyAllResult { - let patchCount = 0 - - const p1 = patchSpawnImport(content) - content = p1.result - if (p1.matched) patchCount++ - - const p2 = patchRemoveCommand(content) - content = p2.result - if (p2.matched) patchCount++ - - const p3 = patchSpawnCall(content) - content = p3.result - if (p3.matched) patchCount++ - - const alreadyPatched = patchCount === 0 && content.includes('import{fork as') && content.includes('"ipc"') - - return { result: content, patchCount, alreadyPatched } -} - -// --- CLI entry point (skipped when imported by tests) --- - -function main() { - const require_ = createRequire(import.meta.url) - - let sdkPath: string - try { - sdkPath = path.join(path.dirname(require_.resolve('@anthropic-ai/claude-agent-sdk')), 'sdk.mjs') - } catch { - console.log('[patch-claude-agent-sdk] Package not installed, skipping.') - process.exit(0) - } - - let fileContent: string - try { - fileContent = readFileSync(sdkPath, 'utf-8') - } catch { - console.error(`[patch-claude-agent-sdk] Failed to read ${sdkPath}`) - process.exit(1) - } - - const { result, patchCount, alreadyPatched } = applyAllPatches(fileContent) - - if (patchCount === 0) { - if (alreadyPatched) { - console.log('[patch-claude-agent-sdk] Already patched, skipping.') - process.exit(0) - } - console.error('[patch-claude-agent-sdk] No patterns matched! The SDK structure may have changed.') - process.exit(1) - } - - if (patchCount < 3) { - console.warn(`[patch-claude-agent-sdk] Warning: only ${patchCount}/3 patches applied.`) - } - - writeFileSync(sdkPath, result, 'utf-8') - console.log(`[patch-claude-agent-sdk] Successfully applied ${patchCount}/3 patches to sdk.mjs`) -} - -if (!process.env.VITEST) { - main() -} diff --git a/src/main/index.ts b/src/main/index.ts index 7ff4f27fc4..5b55c9c55d 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -50,6 +50,7 @@ import { unregisterMigrationIpcHandlers } from '@data/migration/v2' import { application, serviceList } from './core/application' +import { extractRtkBinaries } from './utils/rtk' const logger = loggerService.withContext('MainEntry') @@ -204,6 +205,14 @@ if (!app.requestSingleInstanceLock()) { const { BackupManager } = await import('./services/BackupManager') await BackupManager.handleStartupRestore() + // Extract bundled rtk binary to ~/.cherrystudio/bin/ on first run + // TODO: v2 refactor to use lifecycle + extractRtkBinaries().catch((error) => { + logger.warn('Failed to extract rtk binaries (non-fatal)', { + error: error instanceof Error ? error.message : String(error) + }) + }) + // Start lifecycle (BeforeReady runs parallel with app.whenReady) application.registerAll(serviceList) const bootstrapPromise = application.bootstrap().catch((error) => { diff --git a/src/main/services/AppUpdaterService.ts b/src/main/services/AppUpdaterService.ts index 98bbea1aa5..545cf5dcbb 100644 --- a/src/main/services/AppUpdaterService.ts +++ b/src/main/services/AppUpdaterService.ts @@ -237,8 +237,12 @@ export class AppUpdaterService extends BaseService { const channelConfig = versionConfig.channels[requestedChannel] const latestChannelConfig = versionConfig.channels[UpgradeChannel.LATEST] + if (!semver.gte(currentVersion, versionConfig.minCompatibleVersion)) { + continue + } + // Check version compatibility and channel availability - if (semver.gte(currentVersion, versionConfig.minCompatibleVersion) && channelConfig !== null) { + if (channelConfig !== null) { logger.info( `Found compatible version: ${versionKey} (minCompatibleVersion: ${versionConfig.minCompatibleVersion}), version: ${channelConfig.version}` ) @@ -255,6 +259,12 @@ export class AppUpdaterService extends BaseService { } return { config: channelConfig, channel: requestedChannel } + } else if (requestedChannel !== UpgradeChannel.LATEST && latestChannelConfig !== null) { + // Fallback: requested channel (rc/beta) is null, but latest channel is available + logger.info( + `Requested channel ${requestedChannel} is null for ${versionKey}, falling back to latest channel: ${latestChannelConfig.version}` + ) + return { config: latestChannelConfig, channel: UpgradeChannel.LATEST } } } diff --git a/src/main/services/BackupManager.ts b/src/main/services/BackupManager.ts index e3e19a5230..b4d3084380 100644 --- a/src/main/services/BackupManager.ts +++ b/src/main/services/BackupManager.ts @@ -211,7 +211,7 @@ class BackupManager { const backupedFilePath = path.join(destinationPath, fileName) const output = fs.createWriteStream(backupedFilePath) const archive = archiver('zip', { - zlib: { level: 0 }, // No compression - data is already compressed by LevelDB + zlib: { level: 1 }, // Use lowest compression level for speed (same as legacy backup) zip64: true }) diff --git a/src/main/services/__tests__/AppUpdaterService.test.ts b/src/main/services/__tests__/AppUpdaterService.test.ts index 0537e70f48..663cd1d008 100644 --- a/src/main/services/__tests__/AppUpdaterService.test.ts +++ b/src/main/services/__tests__/AppUpdaterService.test.ts @@ -727,7 +727,7 @@ describe('AppUpdaterService', () => { }) }) - it('should return null when no version has the requested channel', () => { + it('should fallback to latest channel when requested channel is null', () => { const configWithoutRc = { lastUpdated: '2025-01-05T00:00:00Z', versions: { @@ -753,6 +753,30 @@ describe('AppUpdaterService', () => { const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', configWithoutRc) + expect(result).toEqual({ + config: configWithoutRc.versions['1.6.7'].channels.latest, + channel: 'latest' + }) + }) + + it('should return null when no version has the requested channel or latest channel', () => { + const configWithoutAny = { + lastUpdated: '2025-01-05T00:00:00Z', + versions: { + '1.6.7': { + minCompatibleVersion: '1.0.0', + description: 'v1.6.7', + channels: { + latest: null, + rc: null, + beta: null + } + } + } + } + + const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', configWithoutAny) + expect(result).toBeNull() }) }) diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 5a9587fccb..4e3ff62fc4 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -1,4 +1,5 @@ // src/main/services/agents/services/claudecode/index.ts +import { fork } from 'node:child_process' import { EventEmitter } from 'node:events' import { createRequire } from 'node:module' import path from 'node:path' @@ -9,7 +10,8 @@ import type { McpHttpServerConfig, Options, SDKMessage, - SdkPluginConfig + SdkPluginConfig, + SpawnedProcess } from '@anthropic-ai/claude-agent-sdk' import { query } from '@anthropic-ai/claude-agent-sdk' import { loggerService } from '@logger' @@ -19,6 +21,7 @@ import { application } from '@main/core/application' import { pluginService } from '@main/services/agents/plugins/PluginService' import { getAppLanguage } from '@main/utils/language' import { autoDiscoverGitBash } from '@main/utils/process' +import { rtkRewrite } from '@main/utils/rtk' import getLoginShellEnvironment from '@main/utils/shell-env' import { languageEnglishNameMap } from '@shared/config/languages' import { withoutTrailingApiVersion } from '@shared/utils' @@ -314,6 +317,37 @@ class ClaudeCodeService implements AgentServiceInterface { return {} } + const rtkRewriteHook: HookCallback = async (input) => { + if (input.hook_event_name !== 'PreToolUse') { + return {} + } + + // Only rewrite Bash tool commands + if (input.tool_name !== 'Bash' && input.tool_name !== 'builtin_Bash') { + return {} + } + + const toolInput = input.tool_input as Record | undefined + const command = toolInput?.command + if (typeof command !== 'string' || !command.trim()) { + return {} + } + + const rewritten = await rtkRewrite(command) + if (!rewritten) { + return {} + } + + logger.info('rtk rewrote Bash command', { original: command, rewritten }) + + return { + hookSpecificOutput: { + hookEventName: 'PreToolUse', + updatedInput: { ...toolInput, command: rewritten } + } + } + } + // Build SDK options from parameters const options: Options = { abortController, @@ -325,6 +359,20 @@ class ClaudeCodeService implements AgentServiceInterface { logger.warn('claude stderr', { chunk }) errorChunks.push(chunk) }, + spawnClaudeCodeProcess: (spawnOptions) => { + const child = fork(spawnOptions.args[0], spawnOptions.args.slice(1), { + cwd: spawnOptions.cwd, + env: spawnOptions.env as NodeJS.ProcessEnv, + stdio: ['pipe', 'pipe', 'pipe', 'ipc'], + signal: spawnOptions.signal + }) + child.stderr?.on('data', (data: Buffer) => { + const text = data.toString() + logger.warn('claude stderr', { chunk: text }) + errorChunks.push(text) + }) + return child as unknown as SpawnedProcess + }, systemPrompt: session.instructions ? { type: 'preset', @@ -346,7 +394,7 @@ class ClaudeCodeService implements AgentServiceInterface { hooks: { PreToolUse: [ { - hooks: [preToolUseHook] + hooks: [rtkRewriteHook, preToolUseHook] } ] }, diff --git a/src/main/utils/__tests__/rtk.test.ts b/src/main/utils/__tests__/rtk.test.ts new file mode 100644 index 0000000000..811035e54a --- /dev/null +++ b/src/main/utils/__tests__/rtk.test.ts @@ -0,0 +1,168 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock dependencies before importing the module +vi.mock('node:child_process', () => ({ + execFile: vi.fn() +})) + +vi.mock('node:fs', () => ({ + default: { + existsSync: vi.fn(), + mkdirSync: vi.fn(), + copyFileSync: vi.fn(), + chmodSync: vi.fn(), + statSync: vi.fn(), + readFileSync: vi.fn(), + writeFileSync: vi.fn() + } +})) + +vi.mock('node:os', () => ({ + default: { + homedir: () => '/home/testuser' + } +})) + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn() + }) + } +})) + +vi.mock('@shared/config/constant', () => ({ + HOME_CHERRY_DIR: '.cherrystudio' +})) + +vi.mock('electron', () => ({ + app: { + isPackaged: false + } +})) + +vi.mock('../../constant', () => ({ + isWin: false +})) + +vi.mock('..', () => ({ + getResourcePath: () => '/app/resources' +})) + +vi.mock('semver', () => ({ + gte: (version: string, range: string) => { + const [aMaj, aMin, aPat] = version.split('.').map(Number) + const [bMaj, bMin, bPat] = range.split('.').map(Number) + if (aMaj !== bMaj) return aMaj > bMaj + if (aMin !== bMin) return aMin > bMin + return aPat >= bPat + } +})) + +import { execFile } from 'node:child_process' +import fs from 'node:fs' + +import { extractRtkBinaries, rtkRewrite } from '../rtk' + +const mockExecFile = vi.mocked(execFile) +const mockFs = vi.mocked(fs) + +describe('rtk utils', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('extractRtkBinaries', () => { + it('should skip when bundled dir does not exist', async () => { + mockFs.existsSync.mockReturnValue(false) + + await extractRtkBinaries() + + expect(mockFs.copyFileSync).not.toHaveBeenCalled() + }) + + it('should copy binary when destination does not exist', async () => { + mockFs.existsSync.mockImplementation((p: fs.PathLike) => { + const filePath = String(p) + if (filePath.includes('resources/binaries')) return true + if (filePath.includes('rtk') && filePath.includes('.cherrystudio')) return false + if (filePath.includes('.rtk-version') && filePath.includes('resources')) return true + if (filePath.includes('.rtk-version') && filePath.includes('.cherrystudio')) return false + return true + }) + mockFs.readFileSync.mockReturnValue('0.30.1') + + await extractRtkBinaries() + + expect(mockFs.copyFileSync).toHaveBeenCalled() + expect(mockFs.chmodSync).toHaveBeenCalledWith(expect.any(String), 0o755) + }) + + it('should skip copy when version matches', async () => { + mockFs.existsSync.mockReturnValue(true) + mockFs.readFileSync.mockReturnValue('0.30.1') + + await extractRtkBinaries() + + expect(mockFs.copyFileSync).not.toHaveBeenCalled() + }) + }) + + describe('rtkRewrite', () => { + it('should return null when rtk binary is not found', async () => { + mockFs.existsSync.mockReturnValue(false) + + const result = await rtkRewrite('ls -la') + + expect(result).toBeNull() + }) + + it('should return null when rewritten command equals original', async () => { + mockFs.existsSync.mockReturnValue(true) + + // First call: version check, second call: rewrite + let callCount = 0 + mockExecFile.mockImplementation((_cmd, _args, _opts, callback?) => { + const cb = typeof _opts === 'function' ? _opts : callback + callCount++ + if (callCount === 1) { + ;(cb as (...args: unknown[]) => void)(null, 'rtk 0.30.1', '') + } else { + ;(cb as (...args: unknown[]) => void)(null, 'ls -la', '') + } + return {} as ReturnType + }) + + const result = await rtkRewrite('ls -la') + + expect(result).toBeNull() + }) + + it('should return null when rtk exits with error (no rewrite available)', async () => { + mockFs.existsSync.mockReturnValue(true) + + let callCount = 0 + mockExecFile.mockImplementation((_cmd, _args, _opts, callback?) => { + const cb = typeof _opts === 'function' ? _opts : callback + callCount++ + if (callCount === 1) { + ;(cb as (...args: unknown[]) => void)(null, 'rtk 0.30.1', '') + } else { + ;(cb as (...args: unknown[]) => void)(new Error('exit code 1'), '', '') + } + return {} as ReturnType + }) + + const result = await rtkRewrite('some-command') + + expect(result).toBeNull() + }) + }) +}) diff --git a/src/main/utils/rtk.ts b/src/main/utils/rtk.ts new file mode 100644 index 0000000000..522945b715 --- /dev/null +++ b/src/main/utils/rtk.ts @@ -0,0 +1,175 @@ +import { execFile } from 'node:child_process' +import fs from 'node:fs' +import os from 'node:os' +import path from 'node:path' +import { promisify } from 'node:util' + +import { loggerService } from '@logger' +import { HOME_CHERRY_DIR } from '@shared/config/constant' +import { app } from 'electron' +import { gte as semverGte } from 'semver' + +import { isWin } from '../constant' +import { getResourcePath } from '.' + +const execFileAsync = promisify(execFile) +const logger = loggerService.withContext('Utils:Rtk') + +const RTK_BINARY = isWin ? 'rtk.exe' : 'rtk' +const RTK_VERSION_FILE = '.rtk-version' +const RTK_MIN_VERSION = '0.23.0' +const REWRITE_TIMEOUT_MS = 3000 + +// rtk is not available for these platforms +const UNSUPPORTED_PLATFORMS = new Set(['win32-arm64']) + +let rtkPath: string | null = null +let rtkAvailable: boolean | null = null + +function getPlatformKey(): string { + return `${process.platform}-${process.arch}` +} + +function isPlatformSupported(): boolean { + return !UNSUPPORTED_PLATFORMS.has(getPlatformKey()) +} + +function getBundledBinariesDir(): string { + const dir = path.join(getResourcePath(), 'binaries', getPlatformKey()) + if (app.isPackaged) { + return dir.replace(/\.asar([\\/])/, '.asar.unpacked$1') + } + return dir +} + +function getUserBinDir(): string { + return path.join(os.homedir(), HOME_CHERRY_DIR, 'bin') +} + +/** + * Extract bundled rtk binary to ~/.cherrystudio/bin/ if not already present or outdated. + * Called once at app startup. + */ +export async function extractRtkBinaries(): Promise { + if (!isPlatformSupported()) { + logger.debug('rtk not supported on this platform', { platform: getPlatformKey() }) + return + } + + const bundledDir = getBundledBinariesDir() + if (!fs.existsSync(bundledDir)) { + logger.debug('No bundled rtk binaries found for this platform', { dir: bundledDir }) + return + } + + const userBinDir = getUserBinDir() + fs.mkdirSync(userBinDir, { recursive: true }) + + const src = path.join(bundledDir, RTK_BINARY) + const dest = path.join(userBinDir, RTK_BINARY) + + if (!fs.existsSync(src)) { + return + } + + // Use a version file to detect upgrades instead of comparing file sizes + const bundledVersionFile = path.join(bundledDir, RTK_VERSION_FILE) + const installedVersionFile = path.join(userBinDir, RTK_VERSION_FILE) + const bundledVersion = fs.existsSync(bundledVersionFile) ? fs.readFileSync(bundledVersionFile, 'utf8').trim() : '' + const installedVersion = fs.existsSync(installedVersionFile) + ? fs.readFileSync(installedVersionFile, 'utf8').trim() + : '' + + const shouldCopy = !fs.existsSync(dest) || (bundledVersion && bundledVersion !== installedVersion) + + if (shouldCopy) { + fs.copyFileSync(src, dest) + if (!isWin) { + fs.chmodSync(dest, 0o755) + } + if (bundledVersion) { + fs.writeFileSync(installedVersionFile, bundledVersion, 'utf8') + } + logger.info('Extracted rtk binary to user bin dir', { dest, version: bundledVersion || 'unknown' }) + } +} + +function resolveRtkPath(): string | null { + const userBinPath = path.join(getUserBinDir(), RTK_BINARY) + if (fs.existsSync(userBinPath)) { + return userBinPath + } + + const bundledPath = path.join(getBundledBinariesDir(), RTK_BINARY) + if (fs.existsSync(bundledPath)) { + return bundledPath + } + + return null +} + +async function checkRtkAvailable(): Promise { + if (rtkAvailable !== null) return rtkAvailable + + if (!isPlatformSupported()) { + rtkAvailable = false + return false + } + + rtkPath = resolveRtkPath() + if (!rtkPath) { + rtkAvailable = false + logger.debug('rtk binary not found') + return false + } + + try { + const { stdout } = await execFileAsync(rtkPath, ['--version'], { + timeout: REWRITE_TIMEOUT_MS + }) + const match = stdout.match(/(\d+\.\d+\.\d+)/) + if (match) { + const version = match[1] + if (!semverGte(version, RTK_MIN_VERSION)) { + logger.warn(`rtk version too old (need >= ${RTK_MIN_VERSION})`, { version }) + rtkAvailable = false + return false + } + logger.info('rtk available', { version, path: rtkPath }) + } + rtkAvailable = true + } catch (error) { + logger.warn('Failed to check rtk version', { + error: error instanceof Error ? error.message : String(error) + }) + rtkAvailable = false + } + + return rtkAvailable +} + +/** + * Rewrite a shell command using rtk for token-optimized output. + * Returns the rewritten command, or null if no rewrite is available. + */ +export async function rtkRewrite(command: string): Promise { + if (!(await checkRtkAvailable()) || !rtkPath) { + return null + } + + try { + const { stdout } = await execFileAsync(rtkPath, ['rewrite', command], { + timeout: REWRITE_TIMEOUT_MS + }) + const rewritten = stdout.trim() + + if (!rewritten || rewritten === command) { + return null + } + + return rewritten + } catch { + // rtk rewrite exits 1 when there's no rewrite — expected behavior + return null + } +} diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 2e9a764a40..9fa8705765 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -634,6 +634,7 @@ function buildGenericProviderOptions( let providerOptions: Record = {} const reasoningParams = getReasoningEffort(assistant, model) + logger.debug('reasoningParams', reasoningParams) providerOptions = { ...providerOptions, ...reasoningParams diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index 35cc2a7c5f..22b0979d8a 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -85,6 +85,21 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin return { reasoning: { enabled: false, exclude: true } } } + // nvidia: must use chat_template_kwargs + // Since limited documentation, it's hard to find what parameters should be set + // only part of mainstream oss model covered, all verified by nvidia api + if (model.provider === SystemProviderIds.nvidia) { + if (isSupportedThinkingTokenQwenModel(model)) { + return { chat_template_kwargs: { enable_thinking: false } } + } else if (isDeepSeekHybridInferenceModel(model)) { + return { chat_template_kwargs: { thinking: false } } + } else if (isSupportedThinkingTokenKimiModel(model)) { + return { chat_template_kwargs: { thinking: false } } + } else if (isSupportedThinkingTokenZhipuModel(model)) { + return { chat_template_kwargs: { enable_thinking: false } } + } + } + // providers that use enable_thinking if ( (isSupportEnableThinkingProvider(provider) && @@ -252,6 +267,27 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin budgetTokens = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min) } + // nvidia: must use chat_template_kwargs + // Since limited documentation, it's hard to find what parameters should be set + // only part of mainstream oss model covered, all verified by nvidia api + if (model.provider === SystemProviderIds.nvidia) { + if (isSupportedThinkingTokenQwenModel(model)) { + const enableThinkingConfig = isQwenAlwaysThinkModel(model) ? {} : { enable_thinking: true } + return { + chat_template_kwargs: { + ...enableThinkingConfig, + thinking_budget: budgetTokens + } + } + } else if (isDeepSeekHybridInferenceModel(model)) { + return { chat_template_kwargs: { thinking: true } } + } else if (isSupportedThinkingTokenKimiModel(model)) { + return { chat_template_kwargs: { thinking: true } } + } else if (isSupportedThinkingTokenZhipuModel(model)) { + return { chat_template_kwargs: { enable_thinking: true } } + } + } + // See https://docs.siliconflow.cn/cn/api-reference/chat-completions/chat-completions if (model.provider === SystemProviderIds.silicon) { if ( @@ -310,12 +346,6 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin enabled: true } } - case 'nvidia': - return { - chat_template_kwargs: { - thinking: true - } - } default: break } @@ -393,9 +423,9 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin } } else { return { - thinking_budget: budgetTokens, chat_template_kwargs: { - ...enableThinkingConfig + ...enableThinkingConfig, + thinking_budget: budgetTokens } } } diff --git a/src/renderer/src/config/models/__tests__/reasoning.test.ts b/src/renderer/src/config/models/__tests__/reasoning.test.ts index 8b25739cb4..c516131871 100644 --- a/src/renderer/src/config/models/__tests__/reasoning.test.ts +++ b/src/renderer/src/config/models/__tests__/reasoning.test.ts @@ -2653,6 +2653,44 @@ describe('Kimi Models', () => { }) }) +describe('isSupportedThinkingTokenZhipuModel', () => { + it('matches GLM-5 series (with or without hyphen)', () => { + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm5' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-5' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-5-plus' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'GLM-5-Pro' }))).toBe(true) + }) + + it('matches GLM-4.5 / 4.6 / 4.7 series', () => { + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.5' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.6' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.7' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.6-pro' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.5-flash' }))).toBe(true) + }) + + it('rejects GLM-4 base and GLM-Z1 models', () => { + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4-plus' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.0' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-4.3' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-z1' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'glm-z1-plus' }))).toBe(false) + }) + + it('rejects unrelated model IDs', () => { + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'gpt-4o' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(false) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'deepseek-v3' }))).toBe(false) + }) + + it('handles provider-prefixed model IDs', () => { + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'accounts/fireworks/models/glm-4p7' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'accounts/fireworks/models/glm-4p5' }))).toBe(true) + expect(isSupportedThinkingTokenZhipuModel(createModel({ id: 'zhipu/glm-4.6' }))).toBe(true) + }) +}) + describe('Fireworks provider model name normalization', () => { it('should detect DeepSeek hybrid inference models from Fireworks', () => { expect(isDeepSeekHybridInferenceModel(createModel({ id: 'accounts/fireworks/models/deepseek-v3p2' }))).toBe(true) diff --git a/src/renderer/src/config/models/reasoning.ts b/src/renderer/src/config/models/reasoning.ts index 99964f671e..c19cf748d0 100644 --- a/src/renderer/src/config/models/reasoning.ts +++ b/src/renderer/src/config/models/reasoning.ts @@ -592,9 +592,19 @@ export const isSupportedReasoningEffortPerplexityModel = (model: Model): boolean return modelId.includes('sonar-deep-research') } +/** + * Checks whether a Zhipu model supports thinking token control. + * + * Matches model IDs containing: + * - `glm5` or `glm-5` (GLM-5 series) + * - `glm-4.5`, `glm-4.6`, `glm-4.7` (GLM-4.x advanced series) + * + * Note: GLM-Z1 reasoning models are NOT included here — they are covered + * by {@link isZhipuReasoningModel} instead. + */ export const isSupportedThinkingTokenZhipuModel = (model: Model): boolean => { const modelId = getLowerBaseModelName(model.id, '/') - return ['glm-5', 'glm-4.5', 'glm-4.6', 'glm-4.7'].some((id) => modelId.includes(id)) + return /glm-?5|glm-4\.[567]/.test(modelId) } export const isSupportedThinkingTokenMiMoModel = (model: Model): boolean => { diff --git a/src/renderer/src/types/sdk.ts b/src/renderer/src/types/sdk.ts index 45266b6d23..d47eefb6c0 100644 --- a/src/renderer/src/types/sdk.ts +++ b/src/renderer/src/types/sdk.ts @@ -89,6 +89,9 @@ export type ReasoningEffortOptionalParams = { chat_template_kwargs?: { thinking?: boolean enable_thinking?: boolean + // mainstream inference backend doesn't support thinking_budget, so it may not work as expected + // https://github.com/vllm-project/vllm/issues/17887 + thinking_budget?: number } extra_body?: { google?: {