mirror of
https://github.com/larksuite/cli.git
synced 2026-07-03 22:24:31 +08:00
Compare commits
22 Commits
coderabbit
...
feat/app_r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e92620ba3b | ||
|
|
146f13e5e2 | ||
|
|
3c35f3e3f5 | ||
|
|
a2d8e21552 | ||
|
|
788c382984 | ||
|
|
984ebf97b1 | ||
|
|
d0bbc22b36 | ||
|
|
1142f26051 | ||
|
|
3b6086525d | ||
|
|
08ab54cb0f | ||
|
|
91cd101040 | ||
|
|
b4225b9382 | ||
|
|
d42a0807f0 | ||
|
|
c477911354 | ||
|
|
6b3d83224c | ||
|
|
99830f4d6c | ||
|
|
909626db8f | ||
|
|
e6c8fd546c | ||
|
|
40de8a44dc | ||
|
|
29fa49fa5f | ||
|
|
7575d72c00 | ||
|
|
41c9a30ba5 |
19
.github/workflows/release.yml
vendored
19
.github/workflows/release.yml
vendored
@@ -9,7 +9,11 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
# All platforms (incl. darwin keychain_signer) are CGO-free and cross-compiled
|
||||
# on a single ubuntu runner in one goreleaser run (one checksums.txt). The
|
||||
# darwin signer's runtime FFI is validated separately by the signer-test job.
|
||||
goreleaser:
|
||||
needs: signer-test-macos
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -34,6 +38,21 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# Validate the macOS keychain signer on real hardware. The release binaries are
|
||||
# cross-compiled on ubuntu (CGO-free purego FFI), so this is the only step that
|
||||
# needs a Mac — and it gates the release rather than producing it.
|
||||
signer-test-macos:
|
||||
runs-on: macos-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
- uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
- name: Keychain signer round-trip (CGO-free purego FFI)
|
||||
run: LARK_KEYCHAIN_IT=1 CGO_ENABLED=0 go test -tags keychain_signer -run Keychain -v ./internal/keysigner/
|
||||
|
||||
publish-npm:
|
||||
needs: goreleaser
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
30
.github/workflows/semantic-review.yml
vendored
30
.github/workflows/semantic-review.yml
vendored
@@ -47,13 +47,10 @@ jobs:
|
||||
throw new Error(`ambiguous workflow_run pull request bindings: ${runPRs.length}`);
|
||||
}
|
||||
let prNumber = Number(runPRs[0]?.number || 0);
|
||||
const eventBaseSha = runPRs[0]?.base?.sha || "";
|
||||
let eventBaseSha = runPRs[0]?.base?.sha || "";
|
||||
const eventHeadSha = runPRs[0]?.head?.sha || "";
|
||||
const targetHeadSha = run.head_sha;
|
||||
const targetHeadSha = eventHeadSha || run.head_sha;
|
||||
if (!/^[a-f0-9]{40}$/i.test(targetHeadSha)) throw new Error("invalid PR head sha");
|
||||
if (eventHeadSha && eventHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()) {
|
||||
core.notice("PR quality summary using workflow_run head_sha because workflow_run pull request head differs from the CI run head");
|
||||
}
|
||||
|
||||
const factsArtifactPattern = /^quality-gate-facts-([a-f0-9]{40})-([a-f0-9]{40})$/i;
|
||||
const { data: artifactData } = await github.rest.actions.listWorkflowRunArtifacts({
|
||||
@@ -74,11 +71,11 @@ jobs:
|
||||
if (artifactHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()) {
|
||||
artifactError = "facts artifact head sha does not match verified PR head sha";
|
||||
factsArtifactName = "";
|
||||
} else if (eventBaseSha && parsedBaseSha.toLowerCase() !== eventBaseSha.toLowerCase()) {
|
||||
artifactError = "facts artifact base sha does not match workflow_run pull request base sha";
|
||||
factsArtifactName = "";
|
||||
} else {
|
||||
artifactBaseSha = parsedBaseSha;
|
||||
if (eventBaseSha && parsedBaseSha.toLowerCase() !== eventBaseSha.toLowerCase()) {
|
||||
core.notice("PR quality summary using facts artifact base because workflow_run pull request base differs from the CI facts artifact base");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!prNumber) {
|
||||
@@ -126,7 +123,7 @@ jobs:
|
||||
core.setOutput("stale", "true");
|
||||
return;
|
||||
}
|
||||
const baseSha = artifactBaseSha || eventBaseSha || pr.base.sha;
|
||||
const baseSha = eventBaseSha || artifactBaseSha || pr.base.sha;
|
||||
if (!/^[a-f0-9]{40}$/i.test(baseSha)) throw new Error("invalid PR base sha");
|
||||
if ((eventBaseSha || artifactBaseSha) && pr.base.sha !== baseSha) {
|
||||
core.notice("PR quality summary skipped: workflow_run is stale for this PR base");
|
||||
@@ -258,13 +255,10 @@ jobs:
|
||||
throw new Error(`ambiguous workflow_run pull request bindings: ${runPRs.length}`);
|
||||
}
|
||||
let prNumber = Number(runPRs[0]?.number || 0);
|
||||
const eventBaseSha = runPRs[0]?.base?.sha || "";
|
||||
let eventBaseSha = runPRs[0]?.base?.sha || "";
|
||||
const eventHeadSha = runPRs[0]?.head?.sha || "";
|
||||
const targetHeadSha = run.head_sha;
|
||||
const targetHeadSha = eventHeadSha || run.head_sha;
|
||||
if (!/^[a-f0-9]{40}$/i.test(targetHeadSha)) throw new Error("invalid PR head sha");
|
||||
if (eventHeadSha && eventHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()) {
|
||||
core.notice("semantic review using workflow_run head_sha because workflow_run pull request head differs from the CI run head");
|
||||
}
|
||||
|
||||
const factsArtifactPattern = /^quality-gate-facts-([a-f0-9]{40})-([a-f0-9]{40})$/i;
|
||||
const { data: artifactData } = await github.rest.actions.listWorkflowRunArtifacts({
|
||||
@@ -285,11 +279,11 @@ jobs:
|
||||
if (artifactHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()) {
|
||||
artifactError = "facts artifact head sha does not match verified PR head sha";
|
||||
factsArtifactName = "";
|
||||
} else if (eventBaseSha && parsedBaseSha.toLowerCase() !== eventBaseSha.toLowerCase()) {
|
||||
artifactError = "facts artifact base sha does not match workflow_run pull request base sha";
|
||||
factsArtifactName = "";
|
||||
} else {
|
||||
artifactBaseSha = parsedBaseSha;
|
||||
if (eventBaseSha && parsedBaseSha.toLowerCase() !== eventBaseSha.toLowerCase()) {
|
||||
core.notice("semantic review using facts artifact base because workflow_run pull request base differs from the CI facts artifact base");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!prNumber) {
|
||||
@@ -337,7 +331,7 @@ jobs:
|
||||
core.setOutput("stale", "true");
|
||||
return;
|
||||
}
|
||||
const baseSha = artifactBaseSha || eventBaseSha || pr.base.sha;
|
||||
const baseSha = eventBaseSha || artifactBaseSha || pr.base.sha;
|
||||
if (!/^[a-f0-9]{40}$/i.test(baseSha)) throw new Error("invalid PR base sha");
|
||||
if ((eventBaseSha || artifactBaseSha) && pr.base.sha !== baseSha) {
|
||||
core.notice("semantic review skipped: workflow_run is stale for this PR base");
|
||||
|
||||
@@ -5,15 +5,53 @@ before:
|
||||
- python3 scripts/fetch_meta.py
|
||||
|
||||
builds:
|
||||
- binary: lark-cli
|
||||
# Linux & Windows: pure-Go TPM 2.0 signer is compiled in by default (no build
|
||||
# tag), cross-compiled with CGO disabled — the binaries ship the platform key
|
||||
# signer for private_key_jwt. windows/arm64 is the one exception: the sks
|
||||
# Windows dependency stack (go-ole) has no arm64 support, so the signer file is
|
||||
# arch-excluded there and that binary falls back to client_secret only.
|
||||
- id: linux
|
||||
binary: lark-cli
|
||||
main: .
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
flags:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w -X github.com/larksuite/cli/internal/build.Version={{ .Version }} -X github.com/larksuite/cli/internal/build.Date={{ .Date }}
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- id: windows
|
||||
binary: lark-cli
|
||||
main: .
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
flags:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w -X github.com/larksuite/cli/internal/build.Version={{ .Version }} -X github.com/larksuite/cli/internal/build.Date={{ .Date }}
|
||||
goos:
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
# macOS: the keychain signer calls Security.framework via runtime FFI (purego),
|
||||
# so it is CGO-free, compiled into every darwin build (no build tag), and
|
||||
# cross-compiles from the same ubuntu runner as linux/windows.
|
||||
- id: darwin
|
||||
binary: lark-cli
|
||||
main: .
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
flags:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w -X github.com/larksuite/cli/internal/build.Version={{ .Version }} -X github.com/larksuite/cli/internal/build.Date={{ .Date }}
|
||||
goos:
|
||||
- darwin
|
||||
- linux
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
@@ -23,7 +61,7 @@ archives:
|
||||
- name_template: "lark-cli-{{ .Version }}-{{ .Os }}-{{ .Arch }}"
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
formats: [zip]
|
||||
files:
|
||||
- README.md
|
||||
- LICENSE
|
||||
|
||||
25
CHANGELOG.md
25
CHANGELOG.md
@@ -2,30 +2,6 @@
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
## [v1.0.57] - 2026-06-23
|
||||
|
||||
### Features
|
||||
|
||||
- **slides**: Add `+screenshot` to capture slide page images (or render a single `<slide>` XML snippet), returning the local file path instead of Base64 (#1358)
|
||||
- **base**: Support record comments (#1043)
|
||||
- **search**: Surface search API notices (#1413)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- **mail**: Resolve folder/label filter once per `+triage list` call (#1512)
|
||||
- **meta**: Backfill enum value descriptions from options (#1541)
|
||||
- **cli**: Add missing CLI headers for git credential helper (#1539)
|
||||
|
||||
### Documentation
|
||||
|
||||
- **doc**: Refine rich block, path, and block ID guidance (#1508)
|
||||
- **mail**: Trim lark-mail skill context (#1527)
|
||||
- **drive**: Add permission governance workflow guidance (#1292)
|
||||
|
||||
### Build
|
||||
|
||||
- **ci**: Bind semantic review to workflow run head (#1551)
|
||||
|
||||
## [v1.0.56] - 2026-06-18
|
||||
|
||||
### Features
|
||||
@@ -1236,7 +1212,6 @@ Bundled AI agent skills for intelligent assistance:
|
||||
- Bilingual documentation (English & Chinese).
|
||||
- CI/CD pipelines: linting, testing, coverage reporting, and automated releases.
|
||||
|
||||
[v1.0.57]: https://github.com/larksuite/cli/releases/tag/v1.0.57
|
||||
[v1.0.56]: https://github.com/larksuite/cli/releases/tag/v1.0.56
|
||||
[v1.0.55]: https://github.com/larksuite/cli/releases/tag/v1.0.55
|
||||
[v1.0.54]: https://github.com/larksuite/cli/releases/tag/v1.0.54
|
||||
|
||||
6
Makefile
6
Makefile
@@ -33,7 +33,11 @@ build: fetch_meta
|
||||
go build -trimpath -ldflags "$(LDFLAGS)" -o $(BINARY) .
|
||||
|
||||
vet: fetch_meta
|
||||
go vet ./...
|
||||
# -unsafeptr=false: the macOS keychain signer dereferences dylib data-symbol
|
||||
# addresses from purego.Dlsym (uintptr->unsafe.Pointer over stable C memory) —
|
||||
# safe FFI, but go vet's unsafeptr can't prove it and has no inline suppress.
|
||||
# golangci-lint still runs full govet (honoring the //nolint:govet) in CI.
|
||||
go vet -unsafeptr=false ./...
|
||||
|
||||
# fmt-check fails when any file would be reformatted by gofmt. Keep this
|
||||
# in sync with the fast-gate "Check formatting" step in CI.
|
||||
|
||||
@@ -265,7 +265,7 @@ func authLoginRun(opts *LoginOptions) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
authResp, err := larkauth.RequestDeviceAuthorization(httpClient, config.AppID, config.AppSecret, config.Brand, finalScope, f.IOStreams.ErrOut)
|
||||
authResp, err := larkauth.RequestDeviceAuthorization(opts.Ctx, httpClient, larkauth.ClientAuthFromConfig(config), config.Brand, finalScope, f.IOStreams.ErrOut)
|
||||
if err != nil {
|
||||
return errs.NewAuthenticationError(errs.SubtypeUnknown, "device authorization failed: %v", err).WithCause(err)
|
||||
}
|
||||
@@ -325,7 +325,7 @@ func authLoginRun(opts *LoginOptions) error {
|
||||
|
||||
// Step 3: Poll for token
|
||||
log(msg.WaitingAuth)
|
||||
result := pollDeviceToken(opts.Ctx, httpClient, config.AppID, config.AppSecret, config.Brand,
|
||||
result := pollDeviceToken(opts.Ctx, httpClient, larkauth.ClientAuthFromConfig(config), config.Brand,
|
||||
authResp.DeviceCode, authResp.Interval, authResp.ExpiresIn, f.IOStreams.ErrOut)
|
||||
|
||||
if !result.OK {
|
||||
@@ -415,7 +415,7 @@ func authLoginPollDeviceCode(opts *LoginOptions, config *core.CliConfig, msg *lo
|
||||
fmt.Fprintln(f.IOStreams.ErrOut, msg.AgentTimeoutHint)
|
||||
}
|
||||
log(msg.WaitingAuth)
|
||||
result := pollDeviceToken(opts.Ctx, httpClient, config.AppID, config.AppSecret, config.Brand,
|
||||
result := pollDeviceToken(opts.Ctx, httpClient, larkauth.ClientAuthFromConfig(config), config.Brand,
|
||||
opts.DeviceCode, 5, 600, f.IOStreams.ErrOut)
|
||||
|
||||
if !result.OK {
|
||||
|
||||
@@ -847,7 +847,7 @@ func TestAuthLoginRun_DeviceCodeTokenNilCleansScopeCache(t *testing.T) {
|
||||
|
||||
original := pollDeviceToken
|
||||
t.Cleanup(func() { pollDeviceToken = original })
|
||||
pollDeviceToken = func(ctx context.Context, httpClient *http.Client, appId, appSecret string, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *larkauth.DeviceFlowResult {
|
||||
pollDeviceToken = func(ctx context.Context, httpClient *http.Client, ca larkauth.ClientAuth, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *larkauth.DeviceFlowResult {
|
||||
return &larkauth.DeviceFlowResult{OK: true, Token: nil}
|
||||
}
|
||||
|
||||
@@ -886,7 +886,7 @@ func TestAuthLoginRun_JSONAbort_StdoutEventOnly_StderrEmpty(t *testing.T) {
|
||||
|
||||
original := pollDeviceToken
|
||||
t.Cleanup(func() { pollDeviceToken = original })
|
||||
pollDeviceToken = func(ctx context.Context, httpClient *http.Client, appId, appSecret string, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *larkauth.DeviceFlowResult {
|
||||
pollDeviceToken = func(ctx context.Context, httpClient *http.Client, ca larkauth.ClientAuth, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *larkauth.DeviceFlowResult {
|
||||
return &larkauth.DeviceFlowResult{OK: false, Message: "user denied"}
|
||||
}
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ func TestSaveInitConfig_OmitLangPreservesPrior(t *testing.T) {
|
||||
t.Fatalf("seed config: %v", err)
|
||||
}
|
||||
|
||||
if err := saveInitConfig("", existing, f, "cli_x", core.PlainSecret("s2"), core.BrandFeishu, ""); err != nil {
|
||||
if err := saveInitConfig("", existing, f, "cli_x", core.PlainSecret("s2"), core.BrandFeishu, "", "", nil); err != nil {
|
||||
t.Fatalf("saveInitConfig (no --lang): %v", err)
|
||||
}
|
||||
|
||||
@@ -206,6 +206,88 @@ func TestSaveInitConfig_OmitLangPreservesPrior(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRefFromResult_PrivateKeyJWT(t *testing.T) {
|
||||
ref := keyRefFromResult(&configInitResult{
|
||||
AuthMethod: core.AuthMethodPrivateKeyJWT,
|
||||
KeyLabel: "lark-cli-default",
|
||||
})
|
||||
if ref == nil {
|
||||
t.Fatal("keyRefFromResult returned nil")
|
||||
}
|
||||
if ref.Source != "tee" || ref.ID != "lark-cli-default" {
|
||||
t.Fatalf("key ref = %#v, want tee/lark-cli-default", ref)
|
||||
}
|
||||
|
||||
if ref := keyRefFromResult(&configInitResult{AuthMethod: core.AuthMethodPrivateKeyJWT}); ref != nil {
|
||||
t.Fatalf("missing key label should not persist key ref, got %#v", ref)
|
||||
}
|
||||
if ref := keyRefFromResult(&configInitResult{AuthMethod: core.AuthMethodClientSecret, KeyLabel: "ignored"}); ref != nil {
|
||||
t.Fatalf("client_secret should not persist key ref, got %#v", ref)
|
||||
}
|
||||
if ref := keyRefFromResult(nil); ref != nil {
|
||||
t.Fatalf("nil result should not persist key ref, got %#v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveInitConfig_PrivateKeyJWTSingleAppPersistsSecretlessAuth(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
f, _, _, _ := cmdutil.TestFactory(t, nil)
|
||||
|
||||
keyRef := &core.SecretRef{Source: "tee", ID: "lark-cli-default"}
|
||||
if err := saveInitConfig("", nil, f, "cli_pkjwt", core.SecretInput{}, core.BrandFeishu, "en_us", core.AuthMethodPrivateKeyJWT, keyRef); err != nil {
|
||||
t.Fatalf("saveInitConfig private_key_jwt single app: %v", err)
|
||||
}
|
||||
|
||||
got, err := core.LoadMultiAppConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadMultiAppConfig: %v", err)
|
||||
}
|
||||
if len(got.Apps) != 1 {
|
||||
t.Fatalf("apps len = %d, want 1", len(got.Apps))
|
||||
}
|
||||
app := got.Apps[0]
|
||||
if app.AppId != "cli_pkjwt" {
|
||||
t.Fatalf("AppId = %q, want cli_pkjwt", app.AppId)
|
||||
}
|
||||
if app.AuthMethod != core.AuthMethodPrivateKeyJWT {
|
||||
t.Fatalf("AuthMethod = %q, want private_key_jwt", app.AuthMethod)
|
||||
}
|
||||
if app.KeyRef == nil || app.KeyRef.Source != "tee" || app.KeyRef.ID != "lark-cli-default" {
|
||||
t.Fatalf("KeyRef = %#v, want tee/lark-cli-default", app.KeyRef)
|
||||
}
|
||||
if app.AppSecret.Ref != nil || app.AppSecret.Plain != "" {
|
||||
t.Fatalf("private_key_jwt config must stay secretless, AppSecret=%#v", app.AppSecret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveInitConfig_PrivateKeyJWTProfilePersistsSecretlessAuth(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
f, _, _, _ := cmdutil.TestFactory(t, nil)
|
||||
|
||||
keyRef := &core.SecretRef{Source: "tee", ID: "lark-cli-default"}
|
||||
if err := saveInitConfig("prod", &core.MultiAppConfig{}, f, "cli_pkjwt", core.SecretInput{}, core.BrandLark, "en_us", core.AuthMethodPrivateKeyJWT, keyRef); err != nil {
|
||||
t.Fatalf("saveInitConfig private_key_jwt profile: %v", err)
|
||||
}
|
||||
|
||||
got, err := core.LoadMultiAppConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadMultiAppConfig: %v", err)
|
||||
}
|
||||
app := got.FindApp("prod")
|
||||
if app == nil {
|
||||
t.Fatalf("profile prod not saved: %#v", got.Apps)
|
||||
}
|
||||
if app.AuthMethod != core.AuthMethodPrivateKeyJWT {
|
||||
t.Fatalf("AuthMethod = %q, want private_key_jwt", app.AuthMethod)
|
||||
}
|
||||
if app.KeyRef == nil || app.KeyRef.Source != "tee" || app.KeyRef.ID != "lark-cli-default" {
|
||||
t.Fatalf("KeyRef = %#v, want tee/lark-cli-default", app.KeyRef)
|
||||
}
|
||||
if app.AppSecret.Ref != nil || app.AppSecret.Plain != "" {
|
||||
t.Fatalf("private_key_jwt profile must stay secretless, AppSecret=%#v", app.AppSecret)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigInitCmd_InvalidLang verifies a non-empty --lang on config init is
|
||||
// strictly validated the same way bind validates: wrong-case / typo / removed
|
||||
// codes / hyphen form all exit with ExitValidation. (Empty is a no-op.)
|
||||
@@ -388,7 +470,7 @@ func TestSaveAsProfile_RejectsProfileNameCollisionWithExistingAppID(t *testing.T
|
||||
},
|
||||
}
|
||||
|
||||
err := saveAsProfile(existing, keychain.KeychainAccess(&noopConfigKeychain{}), "cli_prod", "app-new", core.PlainSecret("new-secret"), core.BrandLark, "en")
|
||||
err := saveAsProfile(existing, keychain.KeychainAccess(&noopConfigKeychain{}), "cli_prod", "app-new", core.PlainSecret("new-secret"), core.BrandLark, "en", "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected conflict error")
|
||||
}
|
||||
@@ -427,6 +509,46 @@ func TestWrapSaveConfigError_PassesTypedValidationThrough(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAsProfile_UpdatePersistsPrivateKeyJWT(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
|
||||
existing := &core.MultiAppConfig{
|
||||
Apps: []core.AppConfig{{
|
||||
Name: "prod",
|
||||
AppId: "cli_prod",
|
||||
AppSecret: core.PlainSecret("old-secret"),
|
||||
Brand: core.BrandFeishu,
|
||||
Users: []core.AppUser{{UserOpenId: "ou_1", UserName: "User"}},
|
||||
}},
|
||||
}
|
||||
keyRef := &core.SecretRef{Source: "tee", ID: "lark-cli-default"}
|
||||
|
||||
if err := saveAsProfile(existing, keychain.KeychainAccess(&noopConfigKeychain{}), "prod", "cli_prod", core.SecretInput{}, core.BrandLark, "en_us", core.AuthMethodPrivateKeyJWT, keyRef); err != nil {
|
||||
t.Fatalf("saveAsProfile update private_key_jwt: %v", err)
|
||||
}
|
||||
|
||||
got, err := core.LoadMultiAppConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadMultiAppConfig: %v", err)
|
||||
}
|
||||
app := got.FindApp("prod")
|
||||
if app == nil {
|
||||
t.Fatalf("profile prod not saved: %#v", got.Apps)
|
||||
}
|
||||
if app.AuthMethod != core.AuthMethodPrivateKeyJWT {
|
||||
t.Fatalf("AuthMethod = %q, want private_key_jwt", app.AuthMethod)
|
||||
}
|
||||
if app.KeyRef == nil || app.KeyRef.Source != "tee" || app.KeyRef.ID != "lark-cli-default" {
|
||||
t.Fatalf("KeyRef = %#v, want tee/lark-cli-default", app.KeyRef)
|
||||
}
|
||||
if app.AppSecret.Ref != nil || app.AppSecret.Plain != "" {
|
||||
t.Fatalf("private_key_jwt update must stay secretless, AppSecret=%#v", app.AppSecret)
|
||||
}
|
||||
if len(app.Users) != 1 || app.Users[0].UserOpenId != "ou_1" {
|
||||
t.Fatalf("same-app update should preserve users, Users=%#v", app.Users)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateExistingProfileWithoutSecret_RejectsAppIDChange(t *testing.T) {
|
||||
multi := &core.MultiAppConfig{
|
||||
CurrentApp: "prod",
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/i18n"
|
||||
"github.com/larksuite/cli/internal/keychain"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
)
|
||||
|
||||
@@ -31,6 +32,7 @@ type ConfigInitOptions struct {
|
||||
AppSecretStdin bool // read app-secret from stdin (avoids process list exposure)
|
||||
Brand string
|
||||
New bool
|
||||
AuthMethod string // --auth-method for --new: "" (default client_secret) | private_key_jwt
|
||||
|
||||
Lang string // raw --lang (string for cobra); normalized to canonical/"" in validateInitLang
|
||||
langExplicit bool // true when --lang was explicitly passed
|
||||
@@ -39,6 +41,8 @@ type ConfigInitOptions struct {
|
||||
|
||||
ProfileName string // when set, create/update a named profile instead of replacing Apps[0]
|
||||
|
||||
Restore bool // Restore re-registers the app already in config to recover a lost credential
|
||||
|
||||
// ForceInit overrides the agent-workspace guard. Without it, running
|
||||
// init under OPENCLAW_HOME / HERMES_HOME refuses and points the caller
|
||||
// at config bind — which is what AI agents almost always want. Manual
|
||||
@@ -81,11 +85,13 @@ if the user explicitly wants a separate app inside the Agent workspace.`,
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVar(&opts.New, "new", false, "create a new app directly (skip mode selection)")
|
||||
cmd.Flags().StringVar(&opts.AuthMethod, "auth-method", "", "auth method for --new: client_secret (default) or private_key_jwt (signed by a platform key, no app secret)")
|
||||
cmd.Flags().StringVar(&opts.AppID, "app-id", "", "App ID (non-interactive)")
|
||||
cmd.Flags().BoolVar(&opts.AppSecretStdin, "app-secret-stdin", false, "Read App Secret from stdin to avoid process list exposure")
|
||||
cmd.Flags().StringVar(&opts.Brand, "brand", "feishu", "feishu or lark (non-interactive, default feishu)")
|
||||
cmd.Flags().StringVar(&opts.Lang, "lang", "", "language preference (e.g. zh or zh_cn)")
|
||||
cmd.Flags().StringVar(&opts.ProfileName, "name", "", "create or update a named profile (append instead of replace)")
|
||||
cmd.Flags().BoolVar(&opts.Restore, "restore", false, "re-register the app already in config to recover a lost credential (keychain key / app secret); reuses the stored app ID and auth method")
|
||||
cmd.Flags().BoolVar(&opts.ForceInit, "force-init", false, "allow init inside an Agent workspace (OPENCLAW_HOME / HERMES_HOME); use config bind instead unless you really want a separate app")
|
||||
cmdutil.SetRisk(cmd, "write")
|
||||
|
||||
@@ -132,7 +138,7 @@ func guardAgentWorkspace(opts *ConfigInitOptions) error {
|
||||
|
||||
// hasAnyNonInteractiveFlag returns true if any non-interactive flag is set.
|
||||
func (o *ConfigInitOptions) hasAnyNonInteractiveFlag() bool {
|
||||
return o.New || o.AppID != "" || o.AppSecretStdin
|
||||
return o.New || o.Restore || o.AppID != "" || o.AppSecretStdin
|
||||
}
|
||||
|
||||
// cleanupOldConfig clears keychain entries (AppSecret + UAT) for all apps in existing config except the app whose AppId equals skipAppID.
|
||||
@@ -151,11 +157,44 @@ func cleanupOldConfig(existing *core.MultiAppConfig, f *cmdutil.Factory, skipApp
|
||||
}
|
||||
}
|
||||
|
||||
// removeStaleSecretForPKJWT clears a secret left in the keychain when the SAME
|
||||
// appId is migrated from client_secret to private_key_jwt. cleanupOldConfig
|
||||
// explicitly skips a matching appId, and saveAsProfile only cleans up on an
|
||||
// appId change, so a same-appId migration would orphan the old secret. This
|
||||
// fills that gap. RemoveSecretStore only deletes Source=="keychain" entries, so
|
||||
// the new pkjwt tee key handle is never touched.
|
||||
func removeStaleSecretForPKJWT(existing *core.MultiAppConfig, profileName, appID string, kc keychain.KeychainAccess) {
|
||||
if existing == nil {
|
||||
return
|
||||
}
|
||||
var prior *core.AppConfig
|
||||
if profileName != "" {
|
||||
if idx := findProfileIndexByName(existing, profileName); idx >= 0 {
|
||||
prior = &existing.Apps[idx]
|
||||
}
|
||||
} else {
|
||||
prior = existing.CurrentAppConfig("")
|
||||
}
|
||||
if prior != nil && prior.AppId == appID && !prior.AppSecret.IsZero() {
|
||||
core.RemoveSecretStore(prior.AppSecret, kc)
|
||||
}
|
||||
}
|
||||
|
||||
// keyRefFromResult builds the TEE key reference to persist for a private_key_jwt
|
||||
// registration result, or nil for client_secret.
|
||||
func keyRefFromResult(r *configInitResult) *core.SecretRef {
|
||||
if r != nil && r.AuthMethod == core.AuthMethodPrivateKeyJWT && r.KeyLabel != "" {
|
||||
return &core.SecretRef{Source: "tee", ID: r.KeyLabel}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveAsOnlyApp overwrites config.json with a single-app config.
|
||||
func saveAsOnlyApp(appId string, secret core.SecretInput, brand core.LarkBrand, lang string) error {
|
||||
func saveAsOnlyApp(appId string, secret core.SecretInput, brand core.LarkBrand, lang, authMethod string, keyRef *core.SecretRef) error {
|
||||
config := &core.MultiAppConfig{
|
||||
Apps: []core.AppConfig{{
|
||||
AppId: appId, AppSecret: secret, Brand: brand, Lang: i18n.Lang(lang), Users: []core.AppUser{},
|
||||
AuthMethod: authMethod, KeyRef: keyRef,
|
||||
}},
|
||||
}
|
||||
return core.SaveMultiAppConfig(config)
|
||||
@@ -164,9 +203,11 @@ func saveAsOnlyApp(appId string, secret core.SecretInput, brand core.LarkBrand,
|
||||
// saveInitConfig saves a new/updated app config, respecting --profile mode.
|
||||
// With profileName: appends or updates the named profile (preserves other profiles).
|
||||
// Without profileName: cleans up old config and saves as the only app.
|
||||
func saveInitConfig(profileName string, existing *core.MultiAppConfig, f *cmdutil.Factory, appId string, secret core.SecretInput, brand core.LarkBrand, lang string) error {
|
||||
// authMethod/keyRef carry the credential type: ("", nil) for client_secret,
|
||||
// (private_key_jwt, &{tee,label}) for the secretless TEE flow.
|
||||
func saveInitConfig(profileName string, existing *core.MultiAppConfig, f *cmdutil.Factory, appId string, secret core.SecretInput, brand core.LarkBrand, lang, authMethod string, keyRef *core.SecretRef) error {
|
||||
if profileName != "" {
|
||||
return saveAsProfile(existing, f.Keychain, profileName, appId, secret, brand, lang)
|
||||
return saveAsProfile(existing, f.Keychain, profileName, appId, secret, brand, lang, authMethod, keyRef)
|
||||
}
|
||||
cleanupOldConfig(existing, f, appId)
|
||||
var prior i18n.Lang
|
||||
@@ -175,7 +216,7 @@ func saveInitConfig(profileName string, existing *core.MultiAppConfig, f *cmduti
|
||||
prior = app.Lang
|
||||
}
|
||||
}
|
||||
return saveAsOnlyApp(appId, secret, brand, string(preferredLang(i18n.Lang(lang), prior)))
|
||||
return saveAsOnlyApp(appId, secret, brand, string(preferredLang(i18n.Lang(lang), prior)), authMethod, keyRef)
|
||||
}
|
||||
|
||||
// wrapSaveConfigError passes an already-typed error (e.g. the --name conflict
|
||||
@@ -195,7 +236,7 @@ func wrapSaveConfigError(err error) error {
|
||||
// saveAsProfile appends or updates a named profile in the config.
|
||||
// If a profile with the same name exists, it updates it; otherwise appends.
|
||||
// When updating, cleans up old keychain secrets if AppId changed.
|
||||
func saveAsProfile(existing *core.MultiAppConfig, kc keychain.KeychainAccess, profileName, appId string, secret core.SecretInput, brand core.LarkBrand, lang string) error {
|
||||
func saveAsProfile(existing *core.MultiAppConfig, kc keychain.KeychainAccess, profileName, appId string, secret core.SecretInput, brand core.LarkBrand, lang, authMethod string, keyRef *core.SecretRef) error {
|
||||
multi := existing
|
||||
if multi == nil {
|
||||
multi = &core.MultiAppConfig{}
|
||||
@@ -214,6 +255,8 @@ func saveAsProfile(existing *core.MultiAppConfig, kc keychain.KeychainAccess, pr
|
||||
multi.Apps[idx].AppSecret = secret
|
||||
multi.Apps[idx].Brand = brand
|
||||
multi.Apps[idx].Lang = preferredLang(i18n.Lang(lang), multi.Apps[idx].Lang)
|
||||
multi.Apps[idx].AuthMethod = authMethod
|
||||
multi.Apps[idx].KeyRef = keyRef
|
||||
} else {
|
||||
if findAppIndexByAppID(multi, profileName) >= 0 {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument,
|
||||
@@ -222,12 +265,14 @@ func saveAsProfile(existing *core.MultiAppConfig, kc keychain.KeychainAccess, pr
|
||||
}
|
||||
// Append new profile
|
||||
multi.Apps = append(multi.Apps, core.AppConfig{
|
||||
Name: profileName,
|
||||
AppId: appId,
|
||||
AppSecret: secret,
|
||||
Brand: brand,
|
||||
Lang: i18n.Lang(lang),
|
||||
Users: []core.AppUser{},
|
||||
Name: profileName,
|
||||
AppId: appId,
|
||||
AppSecret: secret,
|
||||
Brand: brand,
|
||||
Lang: i18n.Lang(lang),
|
||||
Users: []core.AppUser{},
|
||||
AuthMethod: authMethod,
|
||||
KeyRef: keyRef,
|
||||
})
|
||||
}
|
||||
return core.SaveMultiAppConfig(multi)
|
||||
@@ -305,6 +350,94 @@ func updateExistingProfileWithoutSecret(existing *core.MultiAppConfig, profileNa
|
||||
return core.SaveMultiAppConfig(existing)
|
||||
}
|
||||
|
||||
// persistAndProbeResult saves a registration/restore result into profileName and
|
||||
// runs the post-registration probe. profileName == "" replaces the single app
|
||||
// (legacy); a named profile is updated in place. Shared by --new and --restore.
|
||||
func persistAndProbeResult(opts *ConfigInitOptions, f *cmdutil.Factory, profileName string, result *configInitResult) error {
|
||||
existing, _ := core.LoadMultiAppConfig()
|
||||
|
||||
// private_key_jwt apps have no secret: persist auth method + TEE key ref.
|
||||
// Registration success already validated the key (server bound the public
|
||||
// key), so the app_secret probe is skipped.
|
||||
if result.AuthMethod == core.AuthMethodPrivateKeyJWT {
|
||||
if err := saveInitConfig(profileName, existing, f, result.AppID, core.SecretInput{}, result.Brand, opts.Lang, result.AuthMethod, keyRefFromResult(result)); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
removeStaleSecretForPKJWT(existing, profileName, result.AppID, f.Keychain)
|
||||
printLangPreferenceConfirmation(opts)
|
||||
output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": result.AppID, "authMethod": result.AuthMethod, "brand": result.Brand})
|
||||
return runProbePKJWT(opts.Ctx, f, result.Brand, result.AppID, keysigner.Active(), result.KeyLabel)
|
||||
}
|
||||
|
||||
secret, err := core.ForStorage(result.AppID, core.PlainSecret(result.AppSecret), f.Keychain)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "%v", err).WithCause(err)
|
||||
}
|
||||
if err := saveInitConfig(profileName, existing, f, result.AppID, secret, result.Brand, opts.Lang, "", nil); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
printLangPreferenceConfirmation(opts)
|
||||
output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": result.AppID, "appSecret": "****", "brand": result.Brand})
|
||||
return runProbe(opts.Ctx, f, result.AppID, result.AppSecret, result.Brand)
|
||||
}
|
||||
|
||||
// runRestoreFlow re-registers the app already in config to recover a lost
|
||||
// credential (deleted keychain key / lost app secret). It reads the existing
|
||||
// app id + auth method + brand from config (no secret needed — that's the lost
|
||||
// part) and re-runs the device-flow registration with the app id sent on begin,
|
||||
// so the server re-registers that app instead of creating a new one. The
|
||||
// re-issued credential is written back to the same profile.
|
||||
func runRestoreFlow(opts *ConfigInitOptions, existing *core.MultiAppConfig, f *cmdutil.Factory, msg *initMsg) error {
|
||||
if existing == nil {
|
||||
return errs.NewConfigError(errs.SubtypeNotConfigured, "nothing to restore: no config found").
|
||||
WithHint("run: lark-cli config init")
|
||||
}
|
||||
app := existing.CurrentAppConfig(opts.ProfileName)
|
||||
if app == nil || app.AppId == "" {
|
||||
return errs.NewConfigError(errs.SubtypeNotConfigured, "nothing to restore: no app id in config%s", profileSuffix(opts.ProfileName)).
|
||||
WithHint("run: lark-cli config init")
|
||||
}
|
||||
|
||||
restoreAppID := app.AppId
|
||||
// Reuse the stored auth method authoritatively — never prompt. Empty on disk
|
||||
// means client_secret (omitempty back-compat); pass it explicitly so
|
||||
// resolveRegisterAuthMethod doesn't fall through to the interactive picker.
|
||||
authMethod := app.AuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = core.AuthMethodClientSecret
|
||||
}
|
||||
result, err := runCreateAppFlow(opts.Ctx, f, app.Brand, authMethod, msg, restoreAppID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result == nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "app restore returned no result")
|
||||
}
|
||||
|
||||
// Safety: if the server did not honor app_id (e.g. not yet supported), it may
|
||||
// have created a NEW app instead of restoring. Warn so the user is not silently
|
||||
// switched to a different app id.
|
||||
if result.AppID != restoreAppID {
|
||||
fmt.Fprintf(f.IOStreams.ErrOut, "[lark-cli] [WARN] restore: server returned app %s, expected %s — it may have created a new app instead of restoring\n", result.AppID, restoreAppID)
|
||||
}
|
||||
|
||||
// Write back to the profile we restored: an explicit --name, else the resolved
|
||||
// app's own name. Empty name => legacy single-app replace.
|
||||
saveProfile := opts.ProfileName
|
||||
if saveProfile == "" {
|
||||
saveProfile = app.Name
|
||||
}
|
||||
return persistAndProbeResult(opts, f, saveProfile, result)
|
||||
}
|
||||
|
||||
// profileSuffix renders " (profile %q)" for error messages, or "" when unnamed.
|
||||
func profileSuffix(profileName string) string {
|
||||
if profileName == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(" (profile %q)", profileName)
|
||||
}
|
||||
|
||||
func configInitRun(opts *ConfigInitOptions) error {
|
||||
f := opts.Factory
|
||||
|
||||
@@ -335,6 +468,17 @@ func configInitRun(opts *ConfigInitOptions) error {
|
||||
}
|
||||
}
|
||||
|
||||
// --restore recovers an existing app; it is incompatible with creating a new
|
||||
// app (--new) or importing one non-interactively (--app-id / stdin secret).
|
||||
if opts.Restore {
|
||||
if opts.New {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--restore cannot be combined with --new").WithParam("--restore")
|
||||
}
|
||||
if opts.AppID != "" || opts.AppSecretStdin {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--restore cannot be combined with --app-id / --app-secret-stdin").WithParam("--restore")
|
||||
}
|
||||
}
|
||||
|
||||
// Mode 1: Non-interactive
|
||||
if opts.AppID != "" && opts.appSecret != "" {
|
||||
brand := parseBrand(opts.Brand)
|
||||
@@ -342,7 +486,7 @@ func configInitRun(opts *ConfigInitOptions) error {
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "%v", err).WithCause(err)
|
||||
}
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, opts.AppID, secret, brand, opts.Lang); err != nil {
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, opts.AppID, secret, brand, opts.Lang, "", nil); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath()))
|
||||
@@ -368,34 +512,26 @@ func configInitRun(opts *ConfigInitOptions) error {
|
||||
|
||||
msg := getInitMsg(opts.UILang)
|
||||
|
||||
// Mode: Restore (--restore) — re-register the app already in config.
|
||||
if opts.Restore {
|
||||
return runRestoreFlow(opts, existing, f, msg)
|
||||
}
|
||||
|
||||
// Mode 3: Create new app directly (--new)
|
||||
if opts.New {
|
||||
result, err := runCreateAppFlow(opts.Ctx, f, parseBrand(opts.Brand), msg)
|
||||
result, err := runCreateAppFlow(opts.Ctx, f, parseBrand(opts.Brand), opts.AuthMethod, msg, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result == nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "app creation returned no result")
|
||||
}
|
||||
existing, _ := core.LoadMultiAppConfig()
|
||||
secret, err := core.ForStorage(result.AppID, core.PlainSecret(result.AppSecret), f.Keychain)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "%v", err).WithCause(err)
|
||||
}
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, result.AppID, secret, result.Brand, opts.Lang); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
printLangPreferenceConfirmation(opts)
|
||||
output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": result.AppID, "appSecret": "****", "brand": result.Brand})
|
||||
if err := runProbe(opts.Ctx, f, result.AppID, result.AppSecret, result.Brand); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return persistAndProbeResult(opts, f, opts.ProfileName, result)
|
||||
}
|
||||
|
||||
// Mode 4: Interactive TUI (terminal)
|
||||
if !opts.hasAnyNonInteractiveFlag() && f.IOStreams.IsTerminal {
|
||||
result, err := runInteractiveConfigInit(opts.Ctx, f, msg)
|
||||
result, err := runInteractiveConfigInit(opts.Ctx, f, opts.AuthMethod, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -406,13 +542,22 @@ func configInitRun(opts *ConfigInitOptions) error {
|
||||
|
||||
existing, _ := core.LoadMultiAppConfig()
|
||||
|
||||
if result.AppSecret != "" {
|
||||
if result.AuthMethod == core.AuthMethodPrivateKeyJWT {
|
||||
// Secretless create: persist auth method + TEE key ref, no secret.
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, result.AppID, core.SecretInput{}, result.Brand, opts.Lang, result.AuthMethod, keyRefFromResult(result)); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
removeStaleSecretForPKJWT(existing, opts.ProfileName, result.AppID, f.Keychain)
|
||||
if err := runProbePKJWT(opts.Ctx, f, result.Brand, result.AppID, keysigner.Active(), result.KeyLabel); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if result.AppSecret != "" {
|
||||
// New secret provided (either from "create" or "existing" with input)
|
||||
secret, err := core.ForStorage(result.AppID, core.PlainSecret(result.AppSecret), f.Keychain)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "%v", err).WithCause(err)
|
||||
}
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, result.AppID, secret, result.Brand, opts.Lang); err != nil {
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, result.AppID, secret, result.Brand, opts.Lang, "", nil); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
} else if result.Mode == "existing" && result.AppID != "" {
|
||||
@@ -517,7 +662,7 @@ func configInitRun(opts *ConfigInitOptions) error {
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.SubtypeSDKError, "%v", err).WithCause(err)
|
||||
}
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, resolvedAppId, storedSecret, parseBrand(resolvedBrand), opts.Lang); err != nil {
|
||||
if err := saveInitConfig(opts.ProfileName, existing, f, resolvedAppId, storedSecret, parseBrand(resolvedBrand), opts.Lang, "", nil); err != nil {
|
||||
return wrapSaveConfigError(err)
|
||||
}
|
||||
output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath()))
|
||||
|
||||
102
cmd/config/init_auth_method_test.go
Normal file
102
cmd/config/init_auth_method_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
type authMethodTestSigner struct{}
|
||||
|
||||
func (authMethodTestSigner) EnsureKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (authMethodTestSigner) PublicKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (authMethodTestSigner) Sign(context.Context, keysigner.KeyRef, []byte) ([]byte, string, error) {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// TestResolveRegisterAuthMethod covers the non-interactive gating paths. The
|
||||
// darwin keychain signer is compiled into every build, so the test cannot rely
|
||||
// on the binary lacking a signer — it forces a known no-signer state for the
|
||||
// rejection cases, then registers a stub for the success case.
|
||||
func TestResolveRegisterAuthMethod(t *testing.T) {
|
||||
f := &cmdutil.Factory{}
|
||||
|
||||
prevSigner := keysigner.Active()
|
||||
t.Cleanup(func() { keysigner.Register(prevSigner) })
|
||||
keysigner.Register(nil)
|
||||
|
||||
if m, err := resolveRegisterAuthMethod(f, core.AuthMethodClientSecret); err != nil || m != core.AuthMethodClientSecret {
|
||||
t.Errorf("client_secret: got (%q, %v), want (client_secret, nil)", m, err)
|
||||
}
|
||||
|
||||
if m, err := resolveRegisterAuthMethod(f, ""); err != nil || m != core.AuthMethodClientSecret {
|
||||
t.Errorf("default: got (%q, %v), want (client_secret, nil)", m, err)
|
||||
}
|
||||
|
||||
if _, err := resolveRegisterAuthMethod(f, "bogus"); err == nil {
|
||||
t.Error("bogus auth-method: expected error")
|
||||
}
|
||||
|
||||
if _, err := resolveRegisterAuthMethod(f, core.AuthMethodPrivateKeyJWT); err == nil {
|
||||
t.Error("private_key_jwt without a signer: expected error")
|
||||
}
|
||||
|
||||
keysigner.Register(authMethodTestSigner{})
|
||||
|
||||
if m, err := resolveRegisterAuthMethod(f, core.AuthMethodPrivateKeyJWT); err != nil || m != core.AuthMethodPrivateKeyJWT {
|
||||
t.Errorf("private_key_jwt with signer: got (%q, %v), want (private_key_jwt, nil)", m, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidatePKJWTKeyBinding covers the guard that rejects a registration
|
||||
// resolving to private_key_jwt with no signing key bound (e.g. an existing
|
||||
// secret-based app was selected on the confirm page).
|
||||
func TestValidatePKJWTKeyBinding(t *testing.T) {
|
||||
if err := validatePKJWTKeyBinding(core.AuthMethodPrivateKeyJWT, ""); err == nil {
|
||||
t.Error("pkjwt with empty keyLabel: expected error")
|
||||
}
|
||||
if err := validatePKJWTKeyBinding(core.AuthMethodPrivateKeyJWT, "agent-key"); err != nil {
|
||||
t.Errorf("pkjwt with keyLabel: expected nil, got %v", err)
|
||||
}
|
||||
if err := validatePKJWTKeyBinding(core.AuthMethodClientSecret, ""); err != nil {
|
||||
t.Errorf("client_secret: expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveFinalAuthMethod locks the authoritative-method logic. The 2nd case
|
||||
// is the real bug: we requested private_key_jwt but the server resolved to an
|
||||
// existing client_secret app — we must persist client_secret, not pkjwt.
|
||||
func TestResolveFinalAuthMethod(t *testing.T) {
|
||||
if m := resolveFinalAuthMethod([]string{"client_secret", "private_key_jwt"}, core.AuthMethodClientSecret); m != core.AuthMethodPrivateKeyJWT {
|
||||
t.Errorf("prefers private_key_jwt: got %q", m)
|
||||
}
|
||||
if m := resolveFinalAuthMethod([]string{"client_secret"}, core.AuthMethodPrivateKeyJWT); m != core.AuthMethodClientSecret {
|
||||
t.Errorf("server client_secret must override requested pkjwt: got %q", m)
|
||||
}
|
||||
if m := resolveFinalAuthMethod(nil, core.AuthMethodPrivateKeyJWT); m != core.AuthMethodPrivateKeyJWT {
|
||||
t.Errorf("fallback to requested when server is silent: got %q", m)
|
||||
}
|
||||
// Explicit empty slice (not just nil) also falls back to requested — the same
|
||||
// len()==0 back-compat allowance the init guard relies on to let private_key_jwt
|
||||
// proceed against an older server (see internal/auth
|
||||
// TestRequestAppRegistrationInit_EmptySupportedAuthMethods).
|
||||
if m := resolveFinalAuthMethod([]string{}, core.AuthMethodPrivateKeyJWT); m != core.AuthMethodPrivateKeyJWT {
|
||||
t.Errorf("empty []string should fall back to requested private_key_jwt: got %q", m)
|
||||
}
|
||||
if m := resolveFinalAuthMethod(nil, ""); m != core.AuthMethodClientSecret {
|
||||
t.Errorf("default to client_secret: got %q", m)
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,11 @@ package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/huh"
|
||||
"github.com/larksuite/cli/internal/build"
|
||||
@@ -13,22 +17,26 @@ import (
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
larkauth "github.com/larksuite/cli/internal/auth"
|
||||
"github.com/larksuite/cli/internal/auth/jwt"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/transport"
|
||||
)
|
||||
|
||||
// configInitResult holds the result of the interactive config init flow.
|
||||
type configInitResult struct {
|
||||
Mode string // "create" or "existing"
|
||||
Brand core.LarkBrand
|
||||
AppID string
|
||||
AppSecret string
|
||||
Mode string // "create" or "existing"
|
||||
Brand core.LarkBrand
|
||||
AppID string
|
||||
AppSecret string
|
||||
AuthMethod string // "" == client_secret; core.AuthMethodPrivateKeyJWT
|
||||
KeyLabel string // TEE key handle when AuthMethod == private_key_jwt
|
||||
}
|
||||
|
||||
// runInteractiveConfigInit shows an interactive TUI for config init.
|
||||
func runInteractiveConfigInit(ctx context.Context, f *cmdutil.Factory, msg *initMsg) (*configInitResult, error) {
|
||||
func runInteractiveConfigInit(ctx context.Context, f *cmdutil.Factory, authMethodFlag string, msg *initMsg) (*configInitResult, error) {
|
||||
// Phase 1: Choose mode
|
||||
var mode string
|
||||
form1 := huh.NewForm(
|
||||
@@ -54,7 +62,7 @@ func runInteractiveConfigInit(ctx context.Context, f *cmdutil.Factory, msg *init
|
||||
return runExistingAppForm(f, msg)
|
||||
}
|
||||
|
||||
return runCreateAppFlow(ctx, f, "", msg)
|
||||
return runCreateAppFlow(ctx, f, "", authMethodFlag, msg, "")
|
||||
}
|
||||
|
||||
// runExistingAppForm shows a huh form for manually entering App ID / App Secret / Brand.
|
||||
@@ -146,9 +154,59 @@ func runExistingAppForm(f *cmdutil.Factory, msg *initMsg) (*configInitResult, er
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveRegisterAuthMethod decides the auth method for a new-app registration.
|
||||
// An explicit --auth-method flag wins; otherwise, on an interactive terminal with
|
||||
// a TEE signer available, the user is prompted; the default is client_secret.
|
||||
func resolveRegisterAuthMethod(f *cmdutil.Factory, flag string) (string, error) {
|
||||
signerAvailable := keysigner.Active() != nil
|
||||
switch flag {
|
||||
case core.AuthMethodPrivateKeyJWT:
|
||||
if !signerAvailable {
|
||||
return "", errs.NewConfigError(errs.SubtypeInvalidClient,
|
||||
"--auth-method private_key_jwt requires a platform key signer, which is unavailable on this device/build").
|
||||
WithHint("omit --auth-method (or pass --auth-method client_secret) to register with an app secret")
|
||||
}
|
||||
return core.AuthMethodPrivateKeyJWT, nil
|
||||
case core.AuthMethodClientSecret:
|
||||
return core.AuthMethodClientSecret, nil
|
||||
case "":
|
||||
// fall through to interactive / default
|
||||
default:
|
||||
return "", errs.NewValidationError(errs.SubtypeInvalidArgument,
|
||||
"unknown --auth-method %q (use client_secret or private_key_jwt)", flag)
|
||||
}
|
||||
|
||||
if signerAvailable && f.IOStreams.IsTerminal {
|
||||
var choice string
|
||||
form := huh.NewForm(
|
||||
huh.NewGroup(
|
||||
huh.NewSelect[string]().
|
||||
Title("Authentication method").
|
||||
Options(
|
||||
huh.NewOption("App Secret (client_secret)", core.AuthMethodClientSecret),
|
||||
huh.NewOption("Secure key signer, no secret (private_key_jwt)", core.AuthMethodPrivateKeyJWT),
|
||||
).
|
||||
Value(&choice),
|
||||
),
|
||||
).WithTheme(cmdutil.ThemeFeishu())
|
||||
if err := form.Run(); err != nil {
|
||||
if errors.Is(err, huh.ErrUserAborted) {
|
||||
return "", output.ErrBare(1)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return choice, nil
|
||||
}
|
||||
return core.AuthMethodClientSecret, nil
|
||||
}
|
||||
|
||||
// runCreateAppFlow runs the "create new app" flow via OpenClaw device flow.
|
||||
// If brandOverride is non-empty, skip the interactive brand selection.
|
||||
func runCreateAppFlow(ctx context.Context, f *cmdutil.Factory, brandOverride core.LarkBrand, msg *initMsg) (*configInitResult, error) {
|
||||
// authMethodFlag is the raw --auth-method value ("" when unset).
|
||||
// restoreAppID, when non-empty, is sent on the registration begin request so the
|
||||
// server re-registers that existing app (credential recovery) instead of creating
|
||||
// a new one. Empty preserves the normal new-app flow.
|
||||
func runCreateAppFlow(ctx context.Context, f *cmdutil.Factory, brandOverride core.LarkBrand, authMethodFlag string, msg *initMsg, restoreAppID string) (*configInitResult, error) {
|
||||
var larkBrand core.LarkBrand
|
||||
if brandOverride != "" {
|
||||
larkBrand = brandOverride
|
||||
@@ -176,11 +234,51 @@ func runCreateAppFlow(ctx context.Context, f *cmdutil.Factory, brandOverride cor
|
||||
larkBrand = parseBrand(brand)
|
||||
}
|
||||
|
||||
// Step 1: Request app registration (begin)
|
||||
authMethod, err := resolveRegisterAuthMethod(f, authMethodFlag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Step 1: Request app registration (begin).
|
||||
// Use the shared proxy-plugin-aware transport so registration traffic is not
|
||||
// a bypass of proxy plugin mode.
|
||||
httpClient := transport.NewHTTPClient(0)
|
||||
authResp, err := larkauth.RequestAppRegistration(httpClient, larkBrand, f.IOStreams.ErrOut)
|
||||
|
||||
// For private_key_jwt: init to obtain a nonce, then sign a TEE attestation
|
||||
// (carrying the public key in its jwk header) to send with begin.
|
||||
beginOpts := larkauth.AppRegistrationBeginOptions{}
|
||||
keyLabel := ""
|
||||
if authMethod == core.AuthMethodPrivateKeyJWT {
|
||||
signer := keysigner.Active() // non-nil, guaranteed by resolveRegisterAuthMethod
|
||||
initResp, initErr := larkauth.RequestAppRegistrationInit(httpClient)
|
||||
if initErr != nil {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient, "app registration init failed: %v", initErr).WithCause(initErr)
|
||||
}
|
||||
// An empty SupportedAuthMethods is intentionally treated as "older server /
|
||||
// unknown": len()==0 makes this guard false, so the requested
|
||||
// private_key_jwt proceeds. This mirrors resolveFinalAuthMethod's
|
||||
// back-compat fallback to the requested method. Only an explicit list that
|
||||
// omits private_key_jwt rejects here.
|
||||
if len(initResp.SupportedAuthMethods) > 0 && !slices.Contains(initResp.SupportedAuthMethods, core.AuthMethodPrivateKeyJWT) {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient,
|
||||
"server does not support private_key_jwt for this app type (supported: %s)", strings.Join(initResp.SupportedAuthMethods, ", ")).
|
||||
WithHint("register with --auth-method client_secret instead")
|
||||
}
|
||||
keyLabel = keysigner.DefaultKeyLabel
|
||||
attestation, signErr := jwt.SignAttestation(ctx, signer, keysigner.KeyRef{Label: keyLabel}, initResp.Nonce, time.Now())
|
||||
if signErr != nil {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient, "failed to sign registration attestation: %v", signErr).WithCause(signErr)
|
||||
}
|
||||
beginOpts = larkauth.AppRegistrationBeginOptions{
|
||||
AuthMethod: core.AuthMethodPrivateKeyJWT,
|
||||
AuthAttestation: attestation,
|
||||
}
|
||||
}
|
||||
|
||||
// Restore flow: re-register the existing app instead of creating a new one.
|
||||
beginOpts.RestoreAppID = restoreAppID
|
||||
|
||||
authResp, err := larkauth.RequestAppRegistration(httpClient, larkBrand, beginOpts, f.IOStreams.ErrOut)
|
||||
if err != nil {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient, "app registration failed: %v", err).WithCause(err)
|
||||
}
|
||||
@@ -213,18 +311,28 @@ func runCreateAppFlow(ctx context.Context, f *cmdutil.Factory, brandOverride cor
|
||||
return nil, errs.NewAuthenticationError(errs.SubtypeUnknown, "%v", err).WithCause(err)
|
||||
}
|
||||
|
||||
// Step 4: Handle Lark brand special case
|
||||
// If tenant_brand=lark and no client_secret, retry with lark brand endpoint
|
||||
if result.ClientSecret == "" && result.UserInfo != nil && result.UserInfo.TenantBrand == "lark" {
|
||||
// fmt.Fprintf(f.IOStreams.ErrOut, "%s\n", msg.DetectedLarkTenant)
|
||||
// The final auth method is decided by the user/admin at confirmation and
|
||||
// returned by poll — NOT necessarily what we requested. Selecting an existing
|
||||
// client_secret app, for example, yields client_secret even though we sent
|
||||
// private_key_jwt. Trust the result so we persist the truth.
|
||||
finalMethod := resolveFinalAuthMethod(result.AuthMethods, authMethod)
|
||||
|
||||
// Lark brand special case (client_secret only): a lark-tenant app returns its
|
||||
// secret only from the lark endpoint. private_key_jwt returns no secret, so
|
||||
// this retry does not apply.
|
||||
if finalMethod != core.AuthMethodPrivateKeyJWT && result.ClientSecret == "" && result.UserInfo != nil && result.UserInfo.TenantBrand == "lark" {
|
||||
result, err = larkauth.PollAppRegistration(ctx, httpClient, core.BrandLark, authResp.DeviceCode, authResp.Interval, authResp.ExpiresIn, f.IOStreams.ErrOut)
|
||||
if err != nil {
|
||||
return nil, errs.NewNetworkError(errs.SubtypeNetworkTransport, "lark endpoint retry failed: %v", err).WithCause(err)
|
||||
}
|
||||
finalMethod = resolveFinalAuthMethod(result.AuthMethods, authMethod)
|
||||
}
|
||||
|
||||
if result.ClientID == "" || result.ClientSecret == "" {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient, "app registration succeeded but missing client_id or client_secret")
|
||||
if result.ClientID == "" {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient, "app registration succeeded but missing client_id")
|
||||
}
|
||||
if finalMethod != core.AuthMethodPrivateKeyJWT && result.ClientSecret == "" {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient, "app registration succeeded but missing client_secret")
|
||||
}
|
||||
|
||||
// Determine final brand from response
|
||||
@@ -235,13 +343,67 @@ func runCreateAppFlow(ctx context.Context, f *cmdutil.Factory, brandOverride cor
|
||||
finalBrand = core.BrandFeishu
|
||||
}
|
||||
|
||||
// Surface a downgrade: requested private_key_jwt but the app resolved to a
|
||||
// secret-based method (e.g. an existing app was selected). The key was NOT
|
||||
// bound, so we must store the secret method, not private_key_jwt.
|
||||
if authMethod == core.AuthMethodPrivateKeyJWT && finalMethod != core.AuthMethodPrivateKeyJWT {
|
||||
fmt.Fprintf(f.IOStreams.ErrOut, "[lark-cli] note: requested private_key_jwt, but the app uses %q (e.g. an existing app was selected); storing %q.\n", finalMethod, finalMethod)
|
||||
}
|
||||
|
||||
fmt.Fprintln(f.IOStreams.ErrOut)
|
||||
output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf(msg.AppCreated, result.ClientID))
|
||||
|
||||
keyToStore := ""
|
||||
if finalMethod == core.AuthMethodPrivateKeyJWT {
|
||||
keyToStore = keyLabel
|
||||
}
|
||||
if err := validatePKJWTKeyBinding(finalMethod, keyToStore); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &configInitResult{
|
||||
Mode: "create",
|
||||
Brand: finalBrand,
|
||||
AppID: result.ClientID,
|
||||
AppSecret: result.ClientSecret,
|
||||
Mode: "create",
|
||||
Brand: finalBrand,
|
||||
AppID: result.ClientID,
|
||||
AppSecret: result.ClientSecret, // empty for private_key_jwt; real secret otherwise
|
||||
AuthMethod: finalMethod,
|
||||
KeyLabel: keyToStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validatePKJWTKeyBinding rejects a registration that resolved to
|
||||
// private_key_jwt without a signing key bound to it. keyLabel is non-empty only
|
||||
// when the local flow chose private_key_jwt and signed a TEE attestation; a
|
||||
// resolved method of private_key_jwt with no key handle would save an unusable
|
||||
// config (rejected later at config load, surfacing as "saved OK, fails on first
|
||||
// use"), so it is caught here at registration time instead.
|
||||
func validatePKJWTKeyBinding(finalMethod, keyLabel string) error {
|
||||
if finalMethod == core.AuthMethodPrivateKeyJWT && keyLabel == "" {
|
||||
return errs.NewConfigError(errs.SubtypeInvalidClient,
|
||||
"registration resolved to private_key_jwt but no signing key was bound to this app (an existing secret-based app may have been selected)").
|
||||
WithHint("re-register with: lark-cli config init --new --auth-method private_key_jwt")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveFinalAuthMethod picks the authoritative method from the poll result,
|
||||
// preferring private_key_jwt, then client_secret. It falls back to the requested
|
||||
// method when the server returns nothing (older servers).
|
||||
func resolveFinalAuthMethod(serverMethods []string, requested string) string {
|
||||
if len(serverMethods) == 0 {
|
||||
if requested == "" {
|
||||
return core.AuthMethodClientSecret
|
||||
}
|
||||
return requested
|
||||
}
|
||||
for _, m := range serverMethods {
|
||||
if m == core.AuthMethodPrivateKeyJWT {
|
||||
return core.AuthMethodPrivateKeyJWT
|
||||
}
|
||||
}
|
||||
for _, m := range serverMethods {
|
||||
if m == core.AuthMethodClientSecret {
|
||||
return core.AuthMethodClientSecret
|
||||
}
|
||||
}
|
||||
return serverMethods[0]
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/credential"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// probeTimeout is the total wall-clock budget for the credential probe step
|
||||
@@ -90,3 +91,32 @@ func runProbe(parent context.Context, factory *cmdutil.Factory, appID, appSecret
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// runProbePKJWT does a best-effort key-binding validation after a private_key_jwt
|
||||
// config is saved: it signs a client_assertion with the local platform key and
|
||||
// mints a token. A typed error (a deterministic server rejection — e.g. the key
|
||||
// is not bound to this app) is propagated so `config init` exits non-zero with
|
||||
// the canonical envelope; untyped errors (transport / HTTP / parse / timeout)
|
||||
// are swallowed (return nil). The mint itself is the probe — no second call.
|
||||
func runProbePKJWT(parent context.Context, factory *cmdutil.Factory, brand core.LarkBrand, clientID string, signer keysigner.Signer, keyLabel string) error {
|
||||
if factory == nil || signer == nil {
|
||||
return nil
|
||||
}
|
||||
httpClient, err := factory.HttpClient()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(parent, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
if _, err := credential.FetchTATWithAssertion(ctx, httpClient, brand, clientID, signer, keyLabel); err != nil {
|
||||
// Typed = deterministic credential rejection → propagate. Untyped
|
||||
// (transport / HTTP / parse / timeout) is ambiguous → stay silent.
|
||||
if errs.IsTyped(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,11 @@ package config
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
crand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -17,14 +22,17 @@ import (
|
||||
"github.com/larksuite/cli/internal/build"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// fakeRT routes requests to per-path handlers and records what it saw.
|
||||
type fakeRT struct {
|
||||
tatHandler func(req *http.Request) (*http.Response, error)
|
||||
probeHandler func(req *http.Request) (*http.Response, error)
|
||||
oauthHandler func(req *http.Request) (*http.Response, error)
|
||||
tatCalls int
|
||||
probeCalls int
|
||||
oauthCalls int
|
||||
probeReq *http.Request
|
||||
probeBody string
|
||||
}
|
||||
@@ -48,10 +56,50 @@ func (f *fakeRT) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return jsonResp(200, `{"code":0,"data":{},"msg":"success"}`), nil
|
||||
}
|
||||
return f.probeHandler(req)
|
||||
case strings.HasSuffix(req.URL.Path, "/authen/v2/oauth/token"):
|
||||
f.oauthCalls++
|
||||
if f.oauthHandler == nil {
|
||||
return jsonResp(200, `{"access_token":"t-jwt"}`), nil
|
||||
}
|
||||
return f.oauthHandler(req)
|
||||
}
|
||||
return nil, errors.New("unexpected URL: " + req.URL.String())
|
||||
}
|
||||
|
||||
// probeTestSigner is an in-memory real ECDSA P-256 signer used to sign the
|
||||
// client_assertion in runProbePKJWT tests (authMethodTestSigner returns a nil
|
||||
// key and cannot sign).
|
||||
type probeTestSigner struct{ key *ecdsa.PrivateKey }
|
||||
|
||||
func newProbeTestSigner(t *testing.T) *probeTestSigner {
|
||||
t.Helper()
|
||||
k, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &probeTestSigner{key: k}
|
||||
}
|
||||
|
||||
func (p *probeTestSigner) EnsureKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return p.key.Public(), nil
|
||||
}
|
||||
|
||||
func (p *probeTestSigner) PublicKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return p.key.Public(), nil
|
||||
}
|
||||
|
||||
func (p *probeTestSigner) Sign(_ context.Context, _ keysigner.KeyRef, in []byte) ([]byte, string, error) {
|
||||
h := sha256.Sum256(in)
|
||||
r, s, err := ecdsa.Sign(crand.Reader, p.key, h[:])
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
sig := make([]byte, 64)
|
||||
r.FillBytes(sig[:32])
|
||||
s.FillBytes(sig[32:])
|
||||
return sig, keysigner.AlgES256, nil
|
||||
}
|
||||
|
||||
func jsonResp(code int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: code,
|
||||
@@ -285,3 +333,42 @@ func TestRunProbe_TimeoutHonored(t *testing.T) {
|
||||
// must stay silent and not block.
|
||||
assertSilent(t, err, errBuf)
|
||||
}
|
||||
|
||||
// runProbePKJWT: a deterministic server rejection (invalid_client) is propagated
|
||||
// as a typed ConfigError so config init exits non-zero.
|
||||
func TestRunProbePKJWT_DeterministicReject_Propagates(t *testing.T) {
|
||||
rt := &fakeRT{oauthHandler: func(*http.Request) (*http.Response, error) {
|
||||
return jsonResp(401, `{"error":"invalid_client","error_description":"unknown key"}`), nil
|
||||
}}
|
||||
f, errBuf := fakeFactory(t, rt)
|
||||
err := runProbePKJWT(context.Background(), f, core.BrandFeishu, "cli_x", newProbeTestSigner(t), "agent-key")
|
||||
if err == nil || !errs.IsTyped(err) {
|
||||
t.Fatalf("expected propagated typed error, got %T %v", err, err)
|
||||
}
|
||||
if errBuf.Len() != 0 {
|
||||
t.Errorf("runProbePKJWT must not write stderr, got %q", errBuf.String())
|
||||
}
|
||||
}
|
||||
|
||||
// runProbePKJWT: ambiguous upstream noise (HTTP 503) is swallowed — silent, exit 0.
|
||||
func TestRunProbePKJWT_Ambiguous_Silent(t *testing.T) {
|
||||
rt := &fakeRT{oauthHandler: func(*http.Request) (*http.Response, error) {
|
||||
return jsonResp(503, `unavailable`), nil
|
||||
}}
|
||||
f, errBuf := fakeFactory(t, rt)
|
||||
assertSilent(t, runProbePKJWT(context.Background(), f, core.BrandFeishu, "cli_x", newProbeTestSigner(t), "agent-key"), errBuf)
|
||||
}
|
||||
|
||||
// runProbePKJWT: a successful mint returns nil.
|
||||
func TestRunProbePKJWT_Success_Silent(t *testing.T) {
|
||||
rt := &fakeRT{} // default oauth handler returns 200 + access_token
|
||||
f, errBuf := fakeFactory(t, rt)
|
||||
assertSilent(t, runProbePKJWT(context.Background(), f, core.BrandFeishu, "cli_x", newProbeTestSigner(t), "agent-key"), errBuf)
|
||||
}
|
||||
|
||||
// runProbePKJWT: a nil signer is a defensive no-op (should not be reached, must
|
||||
// not panic).
|
||||
func TestRunProbePKJWT_NilSigner_Silent(t *testing.T) {
|
||||
f, errBuf := fakeFactory(t, &fakeRT{})
|
||||
assertSilent(t, runProbePKJWT(context.Background(), f, core.BrandFeishu, "cli_x", nil, "k"), errBuf)
|
||||
}
|
||||
|
||||
@@ -10,9 +10,25 @@ import (
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keychain"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
)
|
||||
|
||||
// TestRunRestoreFlow_NothingToRestore covers the early guards that return before
|
||||
// any network/registration call: no config at all, and a config whose resolved
|
||||
// app has no app id (nothing to send on begin).
|
||||
func TestRunRestoreFlow_NothingToRestore(t *testing.T) {
|
||||
// No config on disk.
|
||||
if err := runRestoreFlow(&ConfigInitOptions{}, nil, nil, nil); err == nil {
|
||||
t.Fatal("expected error when there is no config to restore")
|
||||
}
|
||||
// Config present but the resolved app has no app id.
|
||||
existing := &core.MultiAppConfig{Apps: []core.AppConfig{{AppId: ""}}}
|
||||
if err := runRestoreFlow(&ConfigInitOptions{}, existing, nil, nil); err == nil {
|
||||
t.Fatal("expected error when the resolved app has no app id")
|
||||
}
|
||||
}
|
||||
|
||||
// updateExistingProfileWithoutSecret guards four blank-input scenarios. Each
|
||||
// must surface as *ValidationError(SubtypeInvalidArgument) per RFC 6749 §5.2:
|
||||
// SubtypeInvalidClient is reserved for IAM rejection of malformed credentials,
|
||||
@@ -119,3 +135,62 @@ func assertValidationParam(t *testing.T, err error, wantParam string) {
|
||||
t.Errorf("Param = %q, want %q", valErr.Param, wantParam)
|
||||
}
|
||||
}
|
||||
|
||||
// countingKeychain is an in-memory KeychainAccess that records whether Remove
|
||||
// was invoked, so the stale-secret cleanup can be asserted without a real OS
|
||||
// keychain.
|
||||
type countingKeychain struct {
|
||||
store map[string]string
|
||||
removeCalled bool
|
||||
}
|
||||
|
||||
func newCountingKeychain() *countingKeychain {
|
||||
return &countingKeychain{store: map[string]string{}}
|
||||
}
|
||||
|
||||
func (k *countingKeychain) Get(service, account string) (string, error) {
|
||||
v, ok := k.store[service+"/"+account]
|
||||
if !ok {
|
||||
return "", keychain.ErrNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (k *countingKeychain) Set(service, account, value string) error {
|
||||
k.store[service+"/"+account] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *countingKeychain) Remove(service, account string) error {
|
||||
k.removeCalled = true
|
||||
delete(k.store, service+"/"+account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRemoveStaleSecretForPKJWT_SameAppID(t *testing.T) {
|
||||
kc := newCountingKeychain()
|
||||
ref, err := core.ForStorage("cli_same", core.PlainSecret("old-secret"), kc) // → Source:"keychain"
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
existing := &core.MultiAppConfig{Apps: []core.AppConfig{{AppId: "cli_same", AppSecret: ref}}}
|
||||
removeStaleSecretForPKJWT(existing, "", "cli_same", kc)
|
||||
if !kc.removeCalled {
|
||||
t.Error("same appId with keychain secret: expected kc.Remove to be invoked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveStaleSecretForPKJWT_DifferentAppID(t *testing.T) {
|
||||
kc := newCountingKeychain()
|
||||
ref, _ := core.ForStorage("cli_old", core.PlainSecret("old-secret"), kc)
|
||||
kc.removeCalled = false // ForStorage does not call Remove, but reset to be safe
|
||||
existing := &core.MultiAppConfig{Apps: []core.AppConfig{{AppId: "cli_old", AppSecret: ref}}}
|
||||
removeStaleSecretForPKJWT(existing, "", "cli_new", kc)
|
||||
if kc.removeCalled {
|
||||
t.Error("different appId: must NOT remove")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveStaleSecretForPKJWT_NilExisting(t *testing.T) {
|
||||
removeStaleSecretForPKJWT(nil, "", "cli_x", newCountingKeychain()) // must not panic
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/identitydiag"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/transport"
|
||||
"github.com/larksuite/cli/internal/update"
|
||||
@@ -132,6 +134,9 @@ func doctorRun(opts *DoctorOptions) error {
|
||||
checks = append(checks, fail("identity_ready", "no usable bot or user identity is available", "run: lark-cli auth status --verify"))
|
||||
}
|
||||
|
||||
// ── 3b. private_key_jwt / TEE signer (local; runs even with --offline) ──
|
||||
checks = append(checks, teeSignerCheck(opts.Ctx, cfg))
|
||||
|
||||
// ── 4 & 5. Endpoint reachability ──
|
||||
checks = append(checks, networkChecks(opts.Ctx, opts, ep)...)
|
||||
|
||||
@@ -145,6 +150,54 @@ func identityCheck(name string, id identitydiag.Identity) checkResult {
|
||||
return warn(name, id.Message, id.Hint)
|
||||
}
|
||||
|
||||
const teeUnavailableHint = "ensure the device secure hardware is accessible (Linux TPM: add your user to the 'tss' group or run with sufficient privileges)"
|
||||
|
||||
// teeSignerCheck reports the private_key_jwt signing backend (TEE/TPM) status.
|
||||
// The probe is local hardware only (no network), so it runs even with --offline;
|
||||
// in a build without a TEE signer it short-circuits without touching any
|
||||
// hardware. It is a hard requirement for private_key_jwt apps and purely
|
||||
// informational for client_secret apps.
|
||||
func teeSignerCheck(ctx context.Context, cfg *core.CliConfig) checkResult {
|
||||
usesPKJWT := cfg != nil && cfg.AuthMethod == core.AuthMethodPrivateKeyJWT
|
||||
info, ok, err := keysigner.ProbeActiveHardware(ctx)
|
||||
return teeCheckResult(info, ok, err, usesPKJWT)
|
||||
}
|
||||
|
||||
// teeCheckResult maps a hardware probe to a doctor check. Split out from
|
||||
// teeSignerCheck so the full matrix is unit-testable without a TPM.
|
||||
func teeCheckResult(info keysigner.HardwareInfo, ok bool, probeErr error, usesPKJWT bool) checkResult {
|
||||
const name = "tee_signer"
|
||||
|
||||
// No signer registered → private_key_jwt is unsupported on this build.
|
||||
if !ok {
|
||||
if usesPKJWT {
|
||||
return fail(name,
|
||||
"app uses private_key_jwt but this build has no TEE key signer",
|
||||
"the platform key signer ships by default on macOS, Linux, and Windows/amd64; this platform (e.g. Windows/arm64) has none — use a supported platform or re-register with --auth-method client_secret")
|
||||
}
|
||||
return skip(name, "no TEE signer in this build (only private_key_jwt is affected; client_secret is unaffected)")
|
||||
}
|
||||
|
||||
backend := info.Backend
|
||||
if backend == "" {
|
||||
backend = "tee"
|
||||
}
|
||||
|
||||
switch {
|
||||
case probeErr != nil:
|
||||
return warn(name, fmt.Sprintf("%s signer present but probe errored: %s", backend, probeErr), "")
|
||||
case info.Available:
|
||||
if info.VendorName != "" {
|
||||
return pass(name, fmt.Sprintf("%s TEE available (%s)", backend, info.VendorName))
|
||||
}
|
||||
return pass(name, fmt.Sprintf("%s TEE available", backend))
|
||||
case usesPKJWT:
|
||||
return fail(name, fmt.Sprintf("%s signer present but TEE unavailable: %s", backend, info.Reason), teeUnavailableHint)
|
||||
default:
|
||||
return warn(name, fmt.Sprintf("%s signer present but TEE unavailable: %s", backend, info.Reason), teeUnavailableHint)
|
||||
}
|
||||
}
|
||||
|
||||
// networkChecks probes Open API and MCP endpoints concurrently.
|
||||
func networkChecks(ctx context.Context, opts *DoctorOptions, ep core.Endpoints) []checkResult {
|
||||
if opts.Offline {
|
||||
@@ -234,14 +287,90 @@ func finishDoctor(f *cmdutil.Factory, checks []checkResult) error {
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"ok": allOK,
|
||||
"workspace": core.CurrentWorkspace().Display(),
|
||||
"checks": checks,
|
||||
workspace := core.CurrentWorkspace().Display()
|
||||
// A terminal on STDOUT gets a readable report; pipes, redirects, scripts and
|
||||
// tests keep the stable JSON contract (NO_COLOR disables ANSI styling).
|
||||
// StdoutIsTerminal checks stdout specifically — IOStreams.IsTerminal reflects
|
||||
// stdin, which would wrongly send the human report into `doctor | jq`.
|
||||
if f.IOStreams.StdoutIsTerminal() {
|
||||
renderDoctorHuman(f.IOStreams.Out, workspace, checks, allOK, os.Getenv("NO_COLOR") == "")
|
||||
} else {
|
||||
output.PrintJson(f.IOStreams.Out, map[string]interface{}{
|
||||
"ok": allOK,
|
||||
"workspace": workspace,
|
||||
"checks": checks,
|
||||
})
|
||||
}
|
||||
output.PrintJson(f.IOStreams.Out, result)
|
||||
if !allOK {
|
||||
return output.ErrBare(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// renderDoctorHuman writes a readable health report: one aligned line per check
|
||||
// with a colored status tag, an indented hint when present, and a summary line.
|
||||
func renderDoctorHuman(w io.Writer, workspace string, checks []checkResult, allOK, color bool) {
|
||||
const (
|
||||
green = "\033[32m"
|
||||
yellow = "\033[33m"
|
||||
red = "\033[31m"
|
||||
gray = "\033[90m"
|
||||
bold = "\033[1m"
|
||||
reset = "\033[0m"
|
||||
)
|
||||
colorOf := map[string]string{"pass": green, "warn": yellow, "fail": red, "skip": gray}
|
||||
tagOf := map[string]string{"pass": "PASS", "warn": "WARN", "fail": "FAIL", "skip": "SKIP"}
|
||||
paint := func(code, s string) string {
|
||||
if !color || code == "" {
|
||||
return s
|
||||
}
|
||||
return code + s + reset
|
||||
}
|
||||
|
||||
nameW := 0
|
||||
for _, c := range checks {
|
||||
if len(c.Name) > nameW {
|
||||
nameW = len(c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\n%s (workspace: %s)\n\n", paint(bold, "lark-cli doctor"), workspace)
|
||||
|
||||
var passN, warnN, failN, skipN int
|
||||
for _, c := range checks {
|
||||
tag := tagOf[c.Status]
|
||||
if tag == "" {
|
||||
tag = "????"
|
||||
}
|
||||
fmt.Fprintf(w, " %s %-*s %s\n", paint(colorOf[c.Status], "["+tag+"]"), nameW, c.Name, c.Message)
|
||||
if c.Hint != "" {
|
||||
fmt.Fprintf(w, " %-*s %s\n", nameW, "", paint(gray, "↳ "+c.Hint))
|
||||
}
|
||||
switch c.Status {
|
||||
case "pass":
|
||||
passN++
|
||||
case "warn":
|
||||
warnN++
|
||||
case "fail":
|
||||
failN++
|
||||
case "skip":
|
||||
skipN++
|
||||
}
|
||||
}
|
||||
|
||||
headline := paint(green, "healthy")
|
||||
if !allOK {
|
||||
headline = paint(red, "problems found")
|
||||
}
|
||||
fmt.Fprintf(w, "\n %s — %d passed", headline, passN)
|
||||
if warnN > 0 {
|
||||
fmt.Fprintf(w, ", %d warning(s)", warnN)
|
||||
}
|
||||
if failN > 0 {
|
||||
fmt.Fprintf(w, ", %d failed", failN)
|
||||
}
|
||||
if skipN > 0 {
|
||||
fmt.Fprintf(w, ", %d skipped", skipN)
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
@@ -4,14 +4,18 @@
|
||||
package doctor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
func TestNewCmdDoctor_FlagParsing(t *testing.T) {
|
||||
@@ -139,6 +143,107 @@ func TestDoctorRun_SplitsBotAndMissingUserIdentity(t *testing.T) {
|
||||
assertCheck(t, got.Checks, "identity_ready", "pass")
|
||||
}
|
||||
|
||||
func TestTeeCheckResult(t *testing.T) {
|
||||
avail := keysigner.HardwareInfo{Backend: "tpm2", Available: true, VendorName: "ACME"}
|
||||
unavail := keysigner.HardwareInfo{Backend: "tpm2", Reason: "open /dev/tpmrm0: permission denied"}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
info keysigner.HardwareInfo
|
||||
ok bool
|
||||
probeErr error
|
||||
pkjwt bool
|
||||
want string
|
||||
}{
|
||||
{"no signer + private_key_jwt → fail", keysigner.HardwareInfo{}, false, nil, true, "fail"},
|
||||
{"no signer + client_secret → skip", keysigner.HardwareInfo{}, false, nil, false, "skip"},
|
||||
{"available + private_key_jwt → pass", avail, true, nil, true, "pass"},
|
||||
{"available + client_secret → pass", avail, true, nil, false, "pass"},
|
||||
{"unavailable + private_key_jwt → fail", unavail, true, nil, true, "fail"},
|
||||
{"unavailable + client_secret → warn", unavail, true, nil, false, "warn"},
|
||||
{"probe error → warn", keysigner.HardwareInfo{Backend: "tpm2"}, true, errors.New("boom"), true, "warn"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := teeCheckResult(tc.info, tc.ok, tc.probeErr, tc.pkjwt)
|
||||
if got.Name != "tee_signer" {
|
||||
t.Errorf("name = %q, want tee_signer", got.Name)
|
||||
}
|
||||
if got.Status != tc.want {
|
||||
t.Errorf("status = %q, want %q (msg=%q)", got.Status, tc.want, got.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoctorRun_TeeSignerWired proves the tee_signer check is part of doctorRun.
|
||||
// It asserts the build-independent invariant (a client_secret app must never
|
||||
// FAIL on TEE) so the test passes whether or not a signer is compiled in.
|
||||
func TestDoctorRun_TeeSignerWired(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
if err := core.SaveMultiAppConfig(&core.MultiAppConfig{
|
||||
CurrentApp: "default",
|
||||
Apps: []core.AppConfig{{
|
||||
Name: "default", AppId: "test-app",
|
||||
AppSecret: core.PlainSecret("secret"), Brand: core.BrandFeishu,
|
||||
}},
|
||||
}); err != nil {
|
||||
t.Fatalf("SaveMultiAppConfig() error = %v", err)
|
||||
}
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, &core.CliConfig{
|
||||
AppID: "test-app", AppSecret: "secret", Brand: core.BrandFeishu,
|
||||
})
|
||||
if err := doctorRun(&DoctorOptions{Factory: f, Ctx: context.Background(), Offline: true}); err != nil {
|
||||
t.Fatalf("doctorRun() error = %v", err)
|
||||
}
|
||||
var got struct {
|
||||
Checks []checkResult `json:"checks"`
|
||||
}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
var c *checkResult
|
||||
for i := range got.Checks {
|
||||
if got.Checks[i].Name == "tee_signer" {
|
||||
c = &got.Checks[i]
|
||||
}
|
||||
}
|
||||
if c == nil {
|
||||
t.Fatalf("tee_signer check not present in doctor output: %#v", got.Checks)
|
||||
}
|
||||
if c.Status == "fail" {
|
||||
t.Errorf("tee_signer = fail for a client_secret app; want skip/warn/pass (msg=%q)", c.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderDoctorHuman(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
checks := []checkResult{
|
||||
pass("cli_version", "1.0.50"),
|
||||
warn("tee_signer", "tpm2 signer present but TEE unavailable", "add your user to the 'tss' group"),
|
||||
fail("identity_ready", "no usable identity", "run: lark-cli auth status --verify"),
|
||||
skip("endpoint_open", "skipped (--offline)"),
|
||||
}
|
||||
renderDoctorHuman(&buf, "local", checks, false, false)
|
||||
out := buf.String()
|
||||
|
||||
for _, want := range []string{
|
||||
"lark-cli doctor", "workspace: local",
|
||||
"[PASS]", "cli_version", "1.0.50",
|
||||
"[WARN]", "tee_signer", "↳ add your user to the 'tss' group",
|
||||
"[FAIL]", "identity_ready", "↳ run: lark-cli auth status --verify",
|
||||
"[SKIP]", "endpoint_open",
|
||||
"problems found", "1 passed", "1 warning(s)", "1 failed", "1 skipped",
|
||||
} {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Errorf("output missing %q\n---\n%s", want, out)
|
||||
}
|
||||
}
|
||||
if strings.Contains(out, "\033[") {
|
||||
t.Errorf("color=false but ANSI escapes present:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func assertCheck(t *testing.T, checks []checkResult, name, status string) {
|
||||
t.Helper()
|
||||
for _, check := range checks {
|
||||
|
||||
@@ -26,7 +26,6 @@ func TestRunList_TextOutput(t *testing.T) {
|
||||
"KEY", "AUTH", "PARAMS", "DESCRIPTION",
|
||||
"im.message.receive_v1",
|
||||
"im.message.message_read_v1",
|
||||
"task.task.update_user_access_v2",
|
||||
} {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Errorf("list output missing %q; full output:\n%s", want, out)
|
||||
@@ -56,17 +55,4 @@ func TestRunList_JSONOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var foundTask bool
|
||||
for _, row := range rows {
|
||||
if row["key"] == "task.task.update_user_access_v2" {
|
||||
foundTask = true
|
||||
if row["single_consumer"] != true {
|
||||
t.Errorf("task row single_consumer = %v, want true", row["single_consumer"])
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundTask {
|
||||
t.Fatal("event list JSON missing task.task.update_user_access_v2")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,34 +96,6 @@ func TestRunSchema_JSONOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSchema_TaskUpdateUserAccessJSON(t *testing.T) {
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test"})
|
||||
|
||||
if err := runSchema(f, "task.task.update_user_access_v2", true); err != nil {
|
||||
t.Fatalf("runSchema json: %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("output is not valid JSON: %v\n%s", err, stdout.String())
|
||||
}
|
||||
if payload["jq_root_path"] != ".event" {
|
||||
t.Errorf("jq_root_path = %v, want .event", payload["jq_root_path"])
|
||||
}
|
||||
if payload["single_consumer"] != true {
|
||||
t.Errorf("single_consumer = %v, want true", payload["single_consumer"])
|
||||
}
|
||||
resolved := payload["resolved_output_schema"].(map[string]interface{})
|
||||
props := resolved["properties"].(map[string]interface{})
|
||||
eventProps := props["event"].(map[string]interface{})["properties"].(map[string]interface{})
|
||||
if got := eventProps["task_guid"].(map[string]interface{})["format"]; got != "task_guid" {
|
||||
t.Errorf("task_guid format = %v, want task_guid", got)
|
||||
}
|
||||
if _, ok := eventProps["event_types"].(map[string]interface{})["items"].(map[string]interface{})["enum"]; !ok {
|
||||
t.Fatalf("event_types enum missing in schema: %#v", eventProps["event_types"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchema_RendersSubscriptionKeyMarker(t *testing.T) {
|
||||
const syntheticKey = "test.evt_sub"
|
||||
t.Cleanup(func() { eventlib.UnregisterKeyForTest(syntheticKey) })
|
||||
|
||||
@@ -7,7 +7,6 @@ package events
|
||||
import (
|
||||
"github.com/larksuite/cli/events/im"
|
||||
"github.com/larksuite/cli/events/minutes"
|
||||
"github.com/larksuite/cli/events/task"
|
||||
"github.com/larksuite/cli/events/vc"
|
||||
"github.com/larksuite/cli/events/whiteboard"
|
||||
"github.com/larksuite/cli/internal/event"
|
||||
@@ -18,7 +17,6 @@ func init() {
|
||||
all := [][]event.KeyDefinition{
|
||||
im.Keys(),
|
||||
minutes.Keys(),
|
||||
task.Keys(),
|
||||
vc.Keys(),
|
||||
whiteboard.Keys(),
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package task
|
||||
|
||||
// TaskUpdateUserAccessV2Data is the Task v2 update event payload under the
|
||||
// standard Lark V2 event envelope.
|
||||
type TaskUpdateUserAccessV2Data struct {
|
||||
EventTypes []string `json:"event_types,omitempty" desc:"Task commit types included in this event" enum:"task_create,task_deleted,task_summary_update,task_desc_update,task_assignees_update,task_followers_update,task_reminders_update,task_start_due_update,task_completed_update"`
|
||||
TaskGUID string `json:"task_guid,omitempty" desc:"Task GUID that changed" kind:"task_guid"`
|
||||
}
|
||||
|
||||
var taskUpdateUserAccessCommitTypes = []string{
|
||||
"task_create",
|
||||
"task_deleted",
|
||||
"task_summary_update",
|
||||
"task_desc_update",
|
||||
"task_assignees_update",
|
||||
"task_followers_update",
|
||||
"task_reminders_update",
|
||||
"task_start_due_update",
|
||||
"task_completed_update",
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/event"
|
||||
)
|
||||
|
||||
const taskSubscriptionPath = "/open-apis/task/v2/task_v2/task_subscription?user_id_type=open_id"
|
||||
|
||||
func taskSubscriptionPreConsume(ctx context.Context, rt event.APIClient, _ map[string]string) (func() error, error) {
|
||||
if rt == nil {
|
||||
return nil, errs.NewInternalError(errs.SubtypeUnknown,
|
||||
"runtime API client is required for pre-consume subscription")
|
||||
}
|
||||
|
||||
if _, err := rt.CallAPI(ctx, "POST", taskSubscriptionPath, nil); err != nil {
|
||||
if _, ok := errs.ProblemOf(err); ok {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errs.NewNetworkError(
|
||||
errs.SubtypeNetworkTransport,
|
||||
"failed to subscribe task event",
|
||||
).WithCause(err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
)
|
||||
|
||||
type stubAPIClient struct {
|
||||
err error
|
||||
|
||||
method string
|
||||
path string
|
||||
body interface{}
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *stubAPIClient) CallAPI(_ context.Context, method, path string, body interface{}) (json.RawMessage, error) {
|
||||
s.method = method
|
||||
s.path = path
|
||||
s.body = body
|
||||
s.calls++
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return json.RawMessage(`{"code":0,"msg":"success","data":{}}`), nil
|
||||
}
|
||||
|
||||
func TestTaskSubscriptionPreConsumeCallsSubscribeAPI(t *testing.T) {
|
||||
rt := &stubAPIClient{}
|
||||
cleanup, err := taskSubscriptionPreConsume(context.Background(), rt, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("taskSubscriptionPreConsume error = %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
t.Fatal("cleanup = non-nil, want nil because task subscription has no unsubscribe API")
|
||||
}
|
||||
if rt.calls != 1 {
|
||||
t.Fatalf("calls = %d, want 1", rt.calls)
|
||||
}
|
||||
if rt.method != "POST" {
|
||||
t.Errorf("method = %q, want POST", rt.method)
|
||||
}
|
||||
if rt.path != taskSubscriptionPath {
|
||||
t.Errorf("path = %q, want %q", rt.path, taskSubscriptionPath)
|
||||
}
|
||||
if rt.body != nil {
|
||||
t.Errorf("body = %#v, want nil", rt.body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskSubscriptionPreConsumeRequiresRuntime(t *testing.T) {
|
||||
_, err := taskSubscriptionPreConsume(context.Background(), nil, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected typed error, got %T: %v", err, err)
|
||||
}
|
||||
if p.Category != errs.CategoryInternal {
|
||||
t.Errorf("category = %s, want %s", p.Category, errs.CategoryInternal)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeUnknown {
|
||||
t.Errorf("subtype = %s, want %s", p.Subtype, errs.SubtypeUnknown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskSubscriptionPreConsumePassesThroughAPIError(t *testing.T) {
|
||||
wantErr := errs.NewValidationError(errs.SubtypeFailedPrecondition, "subscription already exists")
|
||||
rt := &stubAPIClient{err: wantErr}
|
||||
|
||||
_, err := taskSubscriptionPreConsume(context.Background(), rt, nil)
|
||||
if err != wantErr {
|
||||
t.Fatalf("err identity changed: got %T %v, want original %T %v", err, err, wantErr, wantErr)
|
||||
}
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("err = %v, want %v", err, wantErr)
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected typed error, got %T: %v", err, err)
|
||||
}
|
||||
if p.Category != errs.CategoryValidation {
|
||||
t.Errorf("category = %s, want %s", p.Category, errs.CategoryValidation)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeFailedPrecondition {
|
||||
t.Errorf("subtype = %s, want %s", p.Subtype, errs.SubtypeFailedPrecondition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskSubscriptionPreConsumeWrapsUntypedAPIError(t *testing.T) {
|
||||
cause := errors.New("connection reset")
|
||||
rt := &stubAPIClient{err: cause}
|
||||
|
||||
_, err := taskSubscriptionPreConsume(context.Background(), rt, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !errors.Is(err, cause) {
|
||||
t.Fatalf("err = %v, want cause %v", err, cause)
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected typed error, got %T: %v", err, err)
|
||||
}
|
||||
if p.Category != errs.CategoryNetwork {
|
||||
t.Errorf("category = %s, want %s", p.Category, errs.CategoryNetwork)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeNetworkTransport {
|
||||
t.Errorf("subtype = %s, want %s", p.Subtype, errs.SubtypeNetworkTransport)
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Package task registers Task-domain EventKeys.
|
||||
package task
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/larksuite/cli/internal/event"
|
||||
)
|
||||
|
||||
const eventTypeTaskUpdateUserAccessV2 = "task.task.update_user_access_v2"
|
||||
|
||||
// Keys returns all Task-domain EventKey definitions.
|
||||
func Keys() []event.KeyDefinition {
|
||||
return []event.KeyDefinition{
|
||||
{
|
||||
Key: eventTypeTaskUpdateUserAccessV2,
|
||||
DisplayName: "Task updated",
|
||||
Description: "Triggered when tasks visible to the current user or app are created, deleted, or updated",
|
||||
EventType: eventTypeTaskUpdateUserAccessV2,
|
||||
Schema: event.SchemaDef{
|
||||
Native: &event.SchemaSpec{Type: reflect.TypeOf(TaskUpdateUserAccessV2Data{})},
|
||||
},
|
||||
PreConsume: taskSubscriptionPreConsume,
|
||||
Scopes: []string{"task:task:read"},
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
RequiredConsoleEvents: []string{eventTypeTaskUpdateUserAccessV2},
|
||||
SingleConsumer: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package task
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/event"
|
||||
"github.com/larksuite/cli/internal/event/schemas"
|
||||
)
|
||||
|
||||
func TestKeysTaskUpdateUserAccessMetadata(t *testing.T) {
|
||||
keys := Keys()
|
||||
if len(keys) != 1 {
|
||||
t.Fatalf("len(Keys()) = %d, want 1", len(keys))
|
||||
}
|
||||
|
||||
def := keys[0]
|
||||
if def.Key != eventTypeTaskUpdateUserAccessV2 {
|
||||
t.Errorf("Key = %q, want %q", def.Key, eventTypeTaskUpdateUserAccessV2)
|
||||
}
|
||||
if def.EventType != eventTypeTaskUpdateUserAccessV2 {
|
||||
t.Errorf("EventType = %q, want %q", def.EventType, eventTypeTaskUpdateUserAccessV2)
|
||||
}
|
||||
if def.Schema.Native == nil {
|
||||
t.Fatal("Schema.Native is nil")
|
||||
}
|
||||
if def.Schema.Native.Type != reflect.TypeOf(TaskUpdateUserAccessV2Data{}) {
|
||||
t.Errorf("native type = %v, want TaskUpdateUserAccessV2Data", def.Schema.Native.Type)
|
||||
}
|
||||
if def.Process != nil {
|
||||
t.Fatal("Native Task EventKey must not set Process")
|
||||
}
|
||||
if def.PreConsume == nil {
|
||||
t.Fatal("PreConsume is nil")
|
||||
}
|
||||
if !def.SingleConsumer {
|
||||
t.Fatal("SingleConsumer = false, want true")
|
||||
}
|
||||
if !reflect.DeepEqual(def.Scopes, []string{"task:task:read"}) {
|
||||
t.Errorf("Scopes = %#v", def.Scopes)
|
||||
}
|
||||
if !reflect.DeepEqual(def.AuthTypes, []string{"user", "bot"}) {
|
||||
t.Errorf("AuthTypes = %#v", def.AuthTypes)
|
||||
}
|
||||
if !reflect.DeepEqual(def.RequiredConsoleEvents, []string{eventTypeTaskUpdateUserAccessV2}) {
|
||||
t.Errorf("RequiredConsoleEvents = %#v", def.RequiredConsoleEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskUpdateUserAccessSchemaAnnotations(t *testing.T) {
|
||||
raw := schemas.WrapV2Envelope(schemas.FromType(reflect.TypeOf(TaskUpdateUserAccessV2Data{})))
|
||||
var schema map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &schema); err != nil {
|
||||
t.Fatalf("unmarshal schema: %v", err)
|
||||
}
|
||||
|
||||
eventProps := schema["properties"].(map[string]interface{})["event"].(map[string]interface{})["properties"].(map[string]interface{})
|
||||
taskGUID := eventProps["task_guid"].(map[string]interface{})
|
||||
if got := taskGUID["format"]; got != "task_guid" {
|
||||
t.Errorf("task_guid format = %v, want task_guid", got)
|
||||
}
|
||||
|
||||
eventTypes := eventProps["event_types"].(map[string]interface{})
|
||||
items := eventTypes["items"].(map[string]interface{})
|
||||
rawEnum, ok := items["enum"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("event_types item enum missing: %#v", items["enum"])
|
||||
}
|
||||
got := make(map[string]bool, len(rawEnum))
|
||||
for _, v := range rawEnum {
|
||||
got[v.(string)] = true
|
||||
}
|
||||
for _, want := range taskUpdateUserAccessCommitTypes {
|
||||
if !got[want] {
|
||||
t.Errorf("event_types enum missing %q; enum=%v", want, rawEnum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskUpdateUserAccessRegistersCleanly(t *testing.T) {
|
||||
const key = eventTypeTaskUpdateUserAccessV2
|
||||
event.UnregisterKeyForTest(key)
|
||||
t.Cleanup(func() { event.UnregisterKeyForTest(key) })
|
||||
|
||||
for _, def := range Keys() {
|
||||
event.RegisterKey(def)
|
||||
}
|
||||
if _, ok := event.Lookup(key); !ok {
|
||||
t.Fatalf("event.Lookup(%q) not registered", key)
|
||||
}
|
||||
}
|
||||
18
go.mod
18
go.mod
@@ -7,6 +7,8 @@ require (
|
||||
github.com/bmatcuk/doublestar/v4 v4.10.0
|
||||
github.com/charmbracelet/huh v1.0.0
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/facebookincubator/flog v0.0.0-20190930132826-d2511d0ce33c
|
||||
github.com/facebookincubator/sks v0.0.0-20251112220143-6823f23937b4
|
||||
github.com/gofrs/flock v0.8.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/itchyny/gojq v0.12.17
|
||||
@@ -27,7 +29,10 @@ require (
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require github.com/ebitengine/purego v0.10.1
|
||||
|
||||
require (
|
||||
github.com/StackExchange/wmi v1.2.1 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
@@ -42,12 +47,23 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/go-ole/go-ole v1.2.5 // indirect
|
||||
github.com/godbus/dbus/v5 v5.2.2 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/certificate-transparency-go v1.1.2 // indirect
|
||||
github.com/google/certtostore v1.0.3-0.20230404221207-8d01647071cc // indirect
|
||||
github.com/google/deck v0.0.0-20230104221208-105ad94aa8ae // indirect
|
||||
github.com/google/go-attestation v0.5.1 // indirect
|
||||
github.com/google/go-tpm v0.9.0 // indirect
|
||||
github.com/google/go-tspi v0.3.0 // indirect
|
||||
github.com/gopherjs/gopherjs v1.17.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.0.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/itchyny/timefmt-go v0.1.6 // indirect
|
||||
github.com/jgoguen/go-utils v0.0.0-20200211015258-b42ad41486fd // indirect
|
||||
github.com/jtolds/gls v4.20.0+incompatible // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@@ -57,10 +73,12 @@ require (
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/smarty/assertions v1.15.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
)
|
||||
|
||||
@@ -31,6 +31,11 @@ type AppRegistrationResult struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
UserInfo *AppRegUserInfo
|
||||
// AuthMethods is the authoritative auth method(s) the app must use, as
|
||||
// decided by the user/admin at confirmation (20260409 `auth_method` field).
|
||||
// It may differ from what the client requested — e.g. selecting an existing
|
||||
// client_secret app. Empty on older servers.
|
||||
AuthMethods []string
|
||||
}
|
||||
|
||||
// AppRegUserInfo contains user info returned from app registration.
|
||||
@@ -39,8 +44,81 @@ type AppRegUserInfo struct {
|
||||
TenantBrand string // "feishu" or "lark"
|
||||
}
|
||||
|
||||
// RequestAppRegistration initiates the app registration device flow.
|
||||
func RequestAppRegistration(httpClient *http.Client, brand core.LarkBrand, errOut io.Writer) (*AppRegistrationResponse, error) {
|
||||
// AppRegistrationInit is the response from the app registration init endpoint.
|
||||
type AppRegistrationInit struct {
|
||||
Nonce string
|
||||
SupportedAuthMethods []string // e.g. ["client_secret", "private_key_jwt"]
|
||||
}
|
||||
|
||||
// AppRegistrationBeginOptions parametrizes the registration begin request.
|
||||
// A zero value selects the legacy client_secret flow, preserving prior behavior.
|
||||
type AppRegistrationBeginOptions struct {
|
||||
AuthMethod string // "" => client_secret; core.AuthMethodPrivateKeyJWT
|
||||
AuthAttestation string // private_key_jwt: the TEE-signed attestation JWT
|
||||
RestoreAppID string // when set, asks the server to re-register this existing app
|
||||
}
|
||||
|
||||
// RequestAppRegistrationInit performs the init step of the registration flow,
|
||||
// returning a server nonce (to be embedded in a TEE-signed attestation JWT) and
|
||||
// the auth methods the server supports for this archetype.
|
||||
func RequestAppRegistrationInit(httpClient *http.Client) (*AppRegistrationInit, error) {
|
||||
// Registration always begins against the feishu accounts host (mirrors begin).
|
||||
endpoint := core.ResolveEndpoints(core.BrandFeishu).Accounts + PathAppRegistration
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("action", "init")
|
||||
form.Set("archetype", "PersonalAgent")
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
logHTTPResponse(resp)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("app registration init failed: read body: %w", err)
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, fmt.Errorf("app registration init failed: HTTP %d – response not JSON", resp.StatusCode)
|
||||
}
|
||||
|
||||
if _, hasError := data["error"]; resp.StatusCode >= 400 || hasError {
|
||||
msg := getStr(data, "error_description")
|
||||
if msg == "" {
|
||||
msg = getStr(data, "error")
|
||||
}
|
||||
if msg == "" {
|
||||
msg = "Unknown error"
|
||||
}
|
||||
return nil, fmt.Errorf("app registration init failed: %s", msg)
|
||||
}
|
||||
|
||||
out := &AppRegistrationInit{Nonce: getStr(data, "nonce")}
|
||||
if methods, ok := data["supported_auth_methods"].([]interface{}); ok {
|
||||
for _, m := range methods {
|
||||
if s, ok := m.(string); ok {
|
||||
out.SupportedAuthMethods = append(out.SupportedAuthMethods, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if out.Nonce == "" {
|
||||
return nil, fmt.Errorf("app registration init failed: server returned no nonce")
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// RequestAppRegistration initiates the app registration device flow (begin step).
|
||||
func RequestAppRegistration(httpClient *http.Client, brand core.LarkBrand, opts AppRegistrationBeginOptions, errOut io.Writer) (*AppRegistrationResponse, error) {
|
||||
if errOut == nil {
|
||||
errOut = io.Discard
|
||||
}
|
||||
@@ -49,11 +127,24 @@ func RequestAppRegistration(httpClient *http.Client, brand core.LarkBrand, errOu
|
||||
regEp := core.ResolveEndpoints(core.BrandFeishu) // registration begin always uses feishu
|
||||
endpoint := regEp.Accounts + PathAppRegistration
|
||||
|
||||
authMethod := opts.AuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = core.AuthMethodClientSecret
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("action", "begin")
|
||||
form.Set("archetype", "PersonalAgent")
|
||||
form.Set("auth_method", "client_secret")
|
||||
form.Set("auth_method", authMethod)
|
||||
form.Set("request_user_info", "open_id tenant_brand")
|
||||
if opts.AuthAttestation != "" {
|
||||
form.Set("auth_attestation", opts.AuthAttestation)
|
||||
}
|
||||
// Restore flow: carry the existing app id so the server re-registers it
|
||||
// rather than creating a new app.
|
||||
if opts.RestoreAppID != "" {
|
||||
form.Set("app_id", opts.RestoreAppID)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
@@ -95,7 +186,24 @@ func RequestAppRegistration(httpClient *http.Client, brand core.LarkBrand, errOu
|
||||
|
||||
userCode := getStr(data, "user_code")
|
||||
verificationUri := getStr(data, "verification_uri")
|
||||
verificationUriComplete := fmt.Sprintf("%s/page/cli?user_code=%s", ep.Open, userCode)
|
||||
// Prefer the server-provided complete URL (currently /page/launcher); fall
|
||||
// back to building it from verification_uri, then to /page/launcher. The old
|
||||
// hard-coded /page/cli is stale — the server now returns /page/launcher.
|
||||
verificationUriComplete := getStr(data, "verification_uri_complete")
|
||||
if verificationUriComplete == "" {
|
||||
base := verificationUri
|
||||
if base == "" {
|
||||
base = ep.Open + "/page/launcher"
|
||||
}
|
||||
// The server may return verification_uri with its own query (e.g.
|
||||
// client_id when registering against an existing app), so join with
|
||||
// the same ?/& logic as BuildVerificationURL.
|
||||
sep := "?"
|
||||
if strings.Contains(base, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
verificationUriComplete = base + sep + "user_code=" + url.QueryEscape(userCode)
|
||||
}
|
||||
|
||||
return &AppRegistrationResponse{
|
||||
DeviceCode: getStr(data, "device_code"),
|
||||
@@ -107,6 +215,26 @@ func RequestAppRegistration(httpClient *http.Client, brand core.LarkBrand, errOu
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseAuthMethods normalizes the poll response `auth_method` field, which the
|
||||
// server returns as a JSON array of strings (e.g. ["private_key_jwt"]) — or, on
|
||||
// some variants, a single space-separated string.
|
||||
func parseAuthMethods(v interface{}) []string {
|
||||
switch t := v.(type) {
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, m := range t {
|
||||
if s, ok := m.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case string:
|
||||
return strings.Fields(t)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// BuildVerificationURL appends CLI tracking parameters to the verification URL.
|
||||
func BuildVerificationURL(baseURL, cliVersion string) string {
|
||||
sep := "&"
|
||||
@@ -187,6 +315,7 @@ func PollAppRegistration(ctx context.Context, httpClient *http.Client, brand cor
|
||||
result := &AppRegistrationResult{
|
||||
ClientID: getStr(data, "client_id"),
|
||||
ClientSecret: getStr(data, "client_secret"),
|
||||
AuthMethods: parseAuthMethods(data["auth_method"]),
|
||||
}
|
||||
if userInfoRaw, ok := data["user_info"].(map[string]interface{}); ok {
|
||||
result.UserInfo = &AppRegUserInfo{
|
||||
|
||||
@@ -4,8 +4,14 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
@@ -31,3 +37,184 @@ func Test_BuildVerificationURL(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// captureClient returns an http.Client that records the last request's form body
|
||||
// and replies with the given JSON payload.
|
||||
func captureClient(gotBody *url.Values, respJSON string) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
v, _ := url.ParseQuery(string(b))
|
||||
*gotBody = v
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(respJSON)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAppRegistrationInit_ParsesNonceAndMethods(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, `{"nonce":"n-123","supported_auth_methods":["client_secret","private_key_jwt"]}`)
|
||||
|
||||
out, err := RequestAppRegistrationInit(hc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out.Nonce != "n-123" {
|
||||
t.Errorf("nonce = %q, want n-123", out.Nonce)
|
||||
}
|
||||
if len(out.SupportedAuthMethods) != 2 || out.SupportedAuthMethods[1] != "private_key_jwt" {
|
||||
t.Errorf("methods = %v", out.SupportedAuthMethods)
|
||||
}
|
||||
if body.Get("action") != "init" {
|
||||
t.Errorf("action = %q, want init", body.Get("action"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAppRegistrationInit_ErrorOnMissingNonce(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, `{"supported_auth_methods":["client_secret"]}`)
|
||||
if _, err := RequestAppRegistrationInit(hc); err == nil {
|
||||
t.Fatal("expected error when server returns no nonce")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestAppRegistrationInit_EmptySupportedAuthMethods covers the older-server
|
||||
// back-compat path: an empty supported_auth_methods array parses to an empty
|
||||
// slice, so the init guard in cmd/config/init_interactive.go
|
||||
// (`len(SupportedAuthMethods) > 0 && !slices.Contains(...)`) stays false and does
|
||||
// NOT reject the requested private_key_jwt. This aligns with
|
||||
// resolveFinalAuthMethod(nil/[], private_key_jwt) == private_key_jwt
|
||||
// (see cmd/config TestResolveFinalAuthMethod).
|
||||
func TestRequestAppRegistrationInit_EmptySupportedAuthMethods(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, `{"nonce":"n-1","supported_auth_methods":[]}`)
|
||||
|
||||
out, err := RequestAppRegistrationInit(hc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out.Nonce != "n-1" {
|
||||
t.Errorf("nonce = %q, want n-1", out.Nonce)
|
||||
}
|
||||
if len(out.SupportedAuthMethods) != 0 {
|
||||
t.Errorf("SupportedAuthMethods = %v, want empty", out.SupportedAuthMethods)
|
||||
}
|
||||
// Reproduce the init guard expression on the real parsed result: an empty
|
||||
// slice must NOT reject private_key_jwt.
|
||||
rejected := len(out.SupportedAuthMethods) > 0 &&
|
||||
!slices.Contains(out.SupportedAuthMethods, core.AuthMethodPrivateKeyJWT)
|
||||
if rejected {
|
||||
t.Error("empty SupportedAuthMethods must allow private_key_jwt (older-server back-compat)")
|
||||
}
|
||||
}
|
||||
|
||||
const beginRespJSON = `{"device_code":"dc","user_code":"uc","verification_uri":"https://example/verify","expires_in":300,"interval":5}`
|
||||
|
||||
func TestRequestAppRegistration_BeginDefaultsToClientSecret(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, beginRespJSON)
|
||||
|
||||
if _, err := RequestAppRegistration(hc, core.BrandFeishu, AppRegistrationBeginOptions{}, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body.Get("action") != "begin" {
|
||||
t.Errorf("action = %q", body.Get("action"))
|
||||
}
|
||||
if body.Get("auth_method") != "client_secret" {
|
||||
t.Errorf("auth_method = %q, want client_secret (default)", body.Get("auth_method"))
|
||||
}
|
||||
if body.Has("auth_attestation") {
|
||||
t.Errorf("auth_attestation should be absent for client_secret, got %q", body.Get("auth_attestation"))
|
||||
}
|
||||
// Normal (non-restore) begin must NOT carry app_id.
|
||||
if body.Has("app_id") {
|
||||
t.Errorf("app_id should be absent when RestoreAppID is empty, got %q", body.Get("app_id"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestAppRegistration_BeginRestoreAppID verifies the restore flow sends the
|
||||
// existing app id on begin so the server re-registers that app.
|
||||
func TestRequestAppRegistration_BeginRestoreAppID(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, beginRespJSON)
|
||||
|
||||
opts := AppRegistrationBeginOptions{RestoreAppID: "cli_restore_me"}
|
||||
if _, err := RequestAppRegistration(hc, core.BrandFeishu, opts, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body.Get("action") != "begin" {
|
||||
t.Errorf("action = %q, want begin", body.Get("action"))
|
||||
}
|
||||
if body.Get("app_id") != "cli_restore_me" {
|
||||
t.Errorf("app_id = %q, want cli_restore_me", body.Get("app_id"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAppRegistration_VerificationURICompleteFallback(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
resp string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "bare verification_uri",
|
||||
resp: `{"device_code":"dc","user_code":"uc","verification_uri":"https://example/verify","expires_in":300,"interval":5}`,
|
||||
want: "https://example/verify?user_code=uc",
|
||||
},
|
||||
{
|
||||
name: "verification_uri with existing query",
|
||||
resp: `{"device_code":"dc","user_code":"uc","verification_uri":"https://example/verify?client_id=cli_x","expires_in":300,"interval":5}`,
|
||||
want: "https://example/verify?client_id=cli_x&user_code=uc",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, tc.resp)
|
||||
got, err := RequestAppRegistration(hc, core.BrandFeishu, AppRegistrationBeginOptions{}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got.VerificationUriComplete != tc.want {
|
||||
t.Errorf("VerificationUriComplete = %q, want %q", got.VerificationUriComplete, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuthMethods(t *testing.T) {
|
||||
if got := parseAuthMethods([]interface{}{"private_key_jwt", "client_secret"}); len(got) != 2 || got[0] != "private_key_jwt" {
|
||||
t.Errorf("array form = %v", got)
|
||||
}
|
||||
if got := parseAuthMethods("client_secret private_key_jwt"); len(got) != 2 || got[1] != "private_key_jwt" {
|
||||
t.Errorf("string form = %v", got)
|
||||
}
|
||||
if got := parseAuthMethods(nil); got != nil {
|
||||
t.Errorf("nil form = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAppRegistration_BeginPrivateKeyJWT(t *testing.T) {
|
||||
var body url.Values
|
||||
hc := captureClient(&body, beginRespJSON)
|
||||
|
||||
opts := AppRegistrationBeginOptions{
|
||||
AuthMethod: core.AuthMethodPrivateKeyJWT,
|
||||
AuthAttestation: "header.claims.sig",
|
||||
}
|
||||
if _, err := RequestAppRegistration(hc, core.BrandFeishu, opts, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body.Get("auth_method") != "private_key_jwt" {
|
||||
t.Errorf("auth_method = %q, want private_key_jwt", body.Get("auth_method"))
|
||||
}
|
||||
if body.Get("auth_attestation") != "header.claims.sig" {
|
||||
t.Errorf("auth_attestation = %q", body.Get("auth_attestation"))
|
||||
}
|
||||
}
|
||||
|
||||
63
internal/auth/client_auth.go
Normal file
63
internal/auth/client_auth.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/larksuite/cli/internal/auth/jwt"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// ClientAuth describes how to authenticate the OAuth client at the token
|
||||
// endpoint: with a client_secret (default) or a TEE-signed client_assertion
|
||||
// (private_key_jwt).
|
||||
type ClientAuth struct {
|
||||
AppID string
|
||||
AppSecret string
|
||||
AuthMethod string // "" == client_secret; core.AuthMethodPrivateKeyJWT
|
||||
Signer keysigner.Signer
|
||||
KeyLabel string
|
||||
}
|
||||
|
||||
// ClientAuthFromConfig builds a ClientAuth from resolved config, picking up the
|
||||
// active key signer for private_key_jwt apps.
|
||||
func ClientAuthFromConfig(cfg *core.CliConfig) ClientAuth {
|
||||
if cfg == nil {
|
||||
return ClientAuth{}
|
||||
}
|
||||
return ClientAuth{
|
||||
AppID: cfg.AppID,
|
||||
AppSecret: cfg.AppSecret,
|
||||
AuthMethod: cfg.AuthMethod,
|
||||
KeyLabel: cfg.KeyLabel,
|
||||
Signer: keysigner.Active(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c ClientAuth) isPrivateKeyJWT() bool { return c.AuthMethod == core.AuthMethodPrivateKeyJWT }
|
||||
|
||||
// applyClientAssertion adds client_assertion(+type) to a token-endpoint form for
|
||||
// private_key_jwt and returns true. For client_secret it returns false, leaving
|
||||
// the caller to apply its own secret-based authentication. audience is the token
|
||||
// endpoint URL (the assertion's aud claim).
|
||||
func (c ClientAuth) applyClientAssertion(ctx context.Context, form url.Values, audience string) (bool, error) {
|
||||
if !c.isPrivateKeyJWT() {
|
||||
return false, nil
|
||||
}
|
||||
if c.Signer == nil {
|
||||
return false, fmt.Errorf("private_key_jwt requires a key signer, but none is available on this build")
|
||||
}
|
||||
assertion, err := jwt.SignClientAssertion(ctx, c.Signer, keysigner.KeyRef{Label: c.KeyLabel}, c.AppID, audience, time.Now())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
form.Set("client_assertion_type", jwt.ClientAssertionType)
|
||||
form.Set("client_assertion", assertion)
|
||||
return true, nil
|
||||
}
|
||||
109
internal/auth/client_auth_test.go
Normal file
109
internal/auth/client_auth_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/auth/jwt"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// fakeAuthSigner is a real in-memory ECDSA P-256 signer for client-auth tests.
|
||||
type fakeAuthSigner struct{ key *ecdsa.PrivateKey }
|
||||
|
||||
func newFakeAuthSigner(t *testing.T) *fakeAuthSigner {
|
||||
t.Helper()
|
||||
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &fakeAuthSigner{key: k}
|
||||
}
|
||||
|
||||
func (f *fakeAuthSigner) EnsureKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return f.key.Public(), nil
|
||||
}
|
||||
func (f *fakeAuthSigner) PublicKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return f.key.Public(), nil
|
||||
}
|
||||
func (f *fakeAuthSigner) Sign(_ context.Context, _ keysigner.KeyRef, in []byte) ([]byte, string, error) {
|
||||
h := sha256.Sum256(in)
|
||||
r, s, err := ecdsa.Sign(rand.Reader, f.key, h[:])
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
sig := make([]byte, 64)
|
||||
r.FillBytes(sig[:32])
|
||||
s.FillBytes(sig[32:])
|
||||
return sig, keysigner.AlgES256, nil
|
||||
}
|
||||
|
||||
func TestClientAuth_applyClientAssertion_ClientSecret(t *testing.T) {
|
||||
ca := ClientAuth{AppID: "cli_a", AppSecret: "sec"} // AuthMethod "" => client_secret
|
||||
form := url.Values{}
|
||||
used, err := ca.applyClientAssertion(context.Background(), form, "https://aud/token")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if used {
|
||||
t.Error("client_secret must not produce a client_assertion")
|
||||
}
|
||||
if form.Has("client_assertion") || form.Has("client_assertion_type") {
|
||||
t.Errorf("form should be untouched, got %v", form)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientAuth_applyClientAssertion_PrivateKeyJWT(t *testing.T) {
|
||||
ca := ClientAuth{
|
||||
AppID: "cli_a",
|
||||
AuthMethod: core.AuthMethodPrivateKeyJWT,
|
||||
Signer: newFakeAuthSigner(t),
|
||||
KeyLabel: "k",
|
||||
}
|
||||
form := url.Values{}
|
||||
used, err := ca.applyClientAssertion(context.Background(), form, "https://accounts.feishu.cn/open-apis/authen/v2/oauth/token")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !used {
|
||||
t.Fatal("expected client_assertion to be applied")
|
||||
}
|
||||
if form.Get("client_assertion_type") != jwt.ClientAssertionType {
|
||||
t.Errorf("client_assertion_type = %q", form.Get("client_assertion_type"))
|
||||
}
|
||||
if form.Get("client_assertion") == "" {
|
||||
t.Error("client_assertion is empty")
|
||||
}
|
||||
if form.Has("client_secret") {
|
||||
t.Error("client_secret must NOT be present for private_key_jwt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientAuth_applyClientAssertion_NilSigner(t *testing.T) {
|
||||
ca := ClientAuth{AppID: "cli_a", AuthMethod: core.AuthMethodPrivateKeyJWT} // Signer nil
|
||||
if _, err := ca.applyClientAssertion(context.Background(), url.Values{}, "aud"); err == nil {
|
||||
t.Fatal("expected error when private_key_jwt has no signer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientAuthFromConfig(t *testing.T) {
|
||||
ca := ClientAuthFromConfig(&core.CliConfig{
|
||||
AppID: "cli_x",
|
||||
AppSecret: "s",
|
||||
AuthMethod: core.AuthMethodPrivateKeyJWT,
|
||||
KeyLabel: "label-1",
|
||||
})
|
||||
if ca.AppID != "cli_x" || ca.AppSecret != "s" || ca.AuthMethod != core.AuthMethodPrivateKeyJWT || ca.KeyLabel != "label-1" {
|
||||
t.Errorf("ClientAuth = %+v", ca)
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func ResolveOAuthEndpoints(brand core.LarkBrand) OAuthEndpoints {
|
||||
}
|
||||
|
||||
// RequestDeviceAuthorization requests a device authorization code.
|
||||
func RequestDeviceAuthorization(httpClient *http.Client, appId, appSecret string, brand core.LarkBrand, scope string, errOut io.Writer) (*DeviceAuthResponse, error) {
|
||||
func RequestDeviceAuthorization(ctx context.Context, httpClient *http.Client, ca ClientAuth, brand core.LarkBrand, scope string, errOut io.Writer) (*DeviceAuthResponse, error) {
|
||||
if errOut == nil {
|
||||
errOut = io.Discard
|
||||
}
|
||||
@@ -77,18 +77,26 @@ func RequestDeviceAuthorization(httpClient *http.Client, appId, appSecret string
|
||||
}
|
||||
}
|
||||
|
||||
basicAuth := base64.StdEncoding.EncodeToString([]byte(appId + ":" + appSecret))
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("client_id", appId)
|
||||
form.Set("client_id", ca.AppID)
|
||||
form.Set("scope", scope)
|
||||
|
||||
req, err := http.NewRequest("POST", endpoints.DeviceAuthorization, strings.NewReader(form.Encode()))
|
||||
// private_key_jwt authenticates the client with a signed assertion in the
|
||||
// body; client_secret uses HTTP Basic.
|
||||
usedAssertion, err := ca.applyClientAssertion(ctx, form, core.OpenAPIAudience(brand))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", endpoints.DeviceAuthorization, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", "Basic "+basicAuth)
|
||||
if !usedAssertion {
|
||||
basicAuth := base64.StdEncoding.EncodeToString([]byte(ca.AppID + ":" + ca.AppSecret))
|
||||
req.Header.Set("Authorization", "Basic "+basicAuth)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -139,7 +147,7 @@ func RequestDeviceAuthorization(httpClient *http.Client, appId, appSecret string
|
||||
}
|
||||
|
||||
// PollDeviceToken polls the token endpoint until authorization completes or times out.
|
||||
func PollDeviceToken(ctx context.Context, httpClient *http.Client, appId, appSecret string, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *DeviceFlowResult {
|
||||
func PollDeviceToken(ctx context.Context, httpClient *http.Client, ca ClientAuth, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *DeviceFlowResult {
|
||||
if errOut == nil {
|
||||
errOut = io.Discard
|
||||
}
|
||||
@@ -171,10 +179,16 @@ func PollDeviceToken(ctx context.Context, httpClient *http.Client, appId, appSec
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
form.Set("device_code", deviceCode)
|
||||
form.Set("client_id", appId)
|
||||
form.Set("client_secret", appSecret)
|
||||
form.Set("client_id", ca.AppID)
|
||||
usedAssertion, caErr := ca.applyClientAssertion(ctx, form, core.OpenAPIAudience(brand))
|
||||
if caErr != nil {
|
||||
return &DeviceFlowResult{OK: false, Error: "invalid_client", Message: caErr.Error()}
|
||||
}
|
||||
if !usedAssertion {
|
||||
form.Set("client_secret", ca.AppSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", endpoints.Token, strings.NewReader(form.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", endpoints.Token, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -83,7 +85,7 @@ func TestRequestDeviceAuthorization_LogsResponse(t *testing.T) {
|
||||
})
|
||||
t.Cleanup(restore)
|
||||
|
||||
_, err := RequestDeviceAuthorization(httpmock.NewClient(reg), "cli_a", "secret_b", core.BrandFeishu, "", nil)
|
||||
_, err := RequestDeviceAuthorization(context.Background(), httpmock.NewClient(reg), ClientAuth{AppID: "cli_a", AppSecret: "secret_b"}, core.BrandFeishu, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("RequestDeviceAuthorization() error: %v", err)
|
||||
}
|
||||
@@ -106,6 +108,66 @@ func TestRequestDeviceAuthorization_LogsResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// captureRT records the last request + body and returns a canned device-auth response.
|
||||
func captureDeviceAuthClient(gotReq **http.Request, gotBody *string, respJSON string) *http.Client {
|
||||
return &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
*gotReq = req
|
||||
if req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
*gotBody = string(b)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(respJSON)),
|
||||
}, nil
|
||||
})}
|
||||
}
|
||||
|
||||
const deviceAuthRespJSON = `{"device_code":"dc","user_code":"uc","verification_uri":"https://example/verify","expires_in":300,"interval":5}`
|
||||
|
||||
func TestRequestDeviceAuthorization_PrivateKeyJWT_UsesAssertionNotBasic(t *testing.T) {
|
||||
var req *http.Request
|
||||
var body string
|
||||
client := captureDeviceAuthClient(&req, &body, deviceAuthRespJSON)
|
||||
|
||||
ca := ClientAuth{AppID: "cli_a", AuthMethod: core.AuthMethodPrivateKeyJWT, Signer: newFakeAuthSigner(t), KeyLabel: "k"}
|
||||
if _, err := RequestDeviceAuthorization(context.Background(), client, ca, core.BrandFeishu, "im:message:send", nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if req.Header.Get("Authorization") != "" {
|
||||
t.Errorf("private_key_jwt must NOT send Basic auth, got %q", req.Header.Get("Authorization"))
|
||||
}
|
||||
form, _ := url.ParseQuery(body)
|
||||
if form.Get("client_assertion") == "" {
|
||||
t.Error("missing client_assertion")
|
||||
}
|
||||
if form.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
|
||||
t.Errorf("client_assertion_type = %q", form.Get("client_assertion_type"))
|
||||
}
|
||||
if form.Has("client_secret") {
|
||||
t.Error("client_secret must not be present for private_key_jwt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestDeviceAuthorization_ClientSecret_UsesBasic(t *testing.T) {
|
||||
var req *http.Request
|
||||
var body string
|
||||
client := captureDeviceAuthClient(&req, &body, deviceAuthRespJSON)
|
||||
|
||||
ca := ClientAuth{AppID: "cli_a", AppSecret: "sec"} // client_secret
|
||||
if _, err := RequestDeviceAuthorization(context.Background(), client, ca, core.BrandFeishu, "", nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(req.Header.Get("Authorization"), "Basic ") {
|
||||
t.Errorf("client_secret should use Basic auth, got %q", req.Header.Get("Authorization"))
|
||||
}
|
||||
form, _ := url.ParseQuery(body)
|
||||
if form.Has("client_assertion") {
|
||||
t.Error("client_secret must not send a client_assertion")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatAuthCmdline_TruncatesExtraArgs verifies that long command lines are truncated.
|
||||
func TestFormatAuthCmdline_TruncatesExtraArgs(t *testing.T) {
|
||||
got := keychain.FormatAuthCmdline([]string{
|
||||
@@ -205,7 +267,7 @@ func TestPollDeviceToken_DefaultsZeroIntervalToFiveSeconds(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
result := PollDeviceToken(ctx, client, "cli_a", "secret_b", core.BrandFeishu, "device-code", 0, 10, nil)
|
||||
result := PollDeviceToken(ctx, client, ClientAuth{AppID: "cli_a", AppSecret: "secret_b"}, core.BrandFeishu, "device-code", 0, 10, nil)
|
||||
if result == nil {
|
||||
t.Fatal("PollDeviceToken() returned nil result")
|
||||
}
|
||||
|
||||
153
internal/auth/jwt/jwt.go
Normal file
153
internal/auth/jwt/jwt.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Package jwt builds compact JWS tokens signed by a keysigner.Signer.
|
||||
//
|
||||
// It deliberately depends only on the standard library plus the existing
|
||||
// google/uuid dependency — no third-party JWT library is introduced, keeping
|
||||
// go.mod free of new dependencies. The actual signing (and, for ECDSA, the
|
||||
// ASN.1->r||s conversion) is delegated to the Signer implementation.
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
func b64(b []byte) string { return base64.RawURLEncoding.EncodeToString(b) }
|
||||
|
||||
// buildSignedJWT builds a compact JWS:
|
||||
//
|
||||
// base64url(header).base64url(claims).base64url(signature)
|
||||
//
|
||||
// alg is written into the header (it is part of the signed input) and verified
|
||||
// against the alg the signer reports, guarding against a header/key mismatch.
|
||||
// typ defaults to "JWT": the server's client_assertion generalizedValidation
|
||||
// REQUIRES `typ == "JWT"` (rejects otherwise with "malformed client assertion
|
||||
// jwt"), even though the spec examples (§8.1/§8.2) show only alg.
|
||||
func buildSignedJWT(ctx context.Context, signer keysigner.Signer, ref keysigner.KeyRef, alg string, header, claims map[string]any) (string, error) {
|
||||
if signer == nil {
|
||||
return "", fmt.Errorf("jwt: no signer available (private_key_jwt unsupported on this build)")
|
||||
}
|
||||
if header == nil {
|
||||
header = map[string]any{}
|
||||
}
|
||||
header["alg"] = alg
|
||||
if _, ok := header["typ"]; !ok {
|
||||
header["typ"] = "JWT"
|
||||
}
|
||||
|
||||
hb, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("jwt: marshal header: %w", err)
|
||||
}
|
||||
cb, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("jwt: marshal claims: %w", err)
|
||||
}
|
||||
|
||||
signingInput := b64(hb) + "." + b64(cb)
|
||||
sig, gotAlg, err := signer.Sign(ctx, ref, []byte(signingInput))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("jwt: sign: %w", err)
|
||||
}
|
||||
if gotAlg != alg {
|
||||
return "", fmt.Errorf("jwt: signer alg %q does not match header alg %q", gotAlg, alg)
|
||||
}
|
||||
return signingInput + "." + b64(sig), nil
|
||||
}
|
||||
|
||||
// newJTI returns a random unique token identifier.
|
||||
func newJTI() string { return uuid.NewString() }
|
||||
|
||||
// attestationTTL bounds the attestation JWT's lifetime. The init nonce (60s,
|
||||
// single-use) is the real anti-replay constraint; this is a modest margin for
|
||||
// clock skew on top of the immediate init→sign→begin round-trip.
|
||||
const attestationTTL = 2 * time.Minute
|
||||
|
||||
// attestationClaims builds the registration attestation claim set per the App
|
||||
// Registration JWT spec: jti, iat, exp (all required) and the init-issued nonce.
|
||||
func attestationClaims(nonce string, now time.Time) map[string]any {
|
||||
return map[string]any{
|
||||
"jti": newJTI(),
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(attestationTTL).Unix(),
|
||||
"nonce": nonce,
|
||||
}
|
||||
}
|
||||
|
||||
// clientAssertionClaims builds an RFC 7523 client_assertion claim set used to
|
||||
// mint tokens in place of client_secret. aud is the brand's token endpoint URL.
|
||||
func clientAssertionClaims(clientID, aud string, now time.Time, ttl time.Duration) map[string]any {
|
||||
return map[string]any{
|
||||
"iss": clientID,
|
||||
"sub": clientID,
|
||||
"aud": aud,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(ttl).Unix(),
|
||||
"jti": newJTI(),
|
||||
}
|
||||
}
|
||||
|
||||
// ClientAssertionType is the RFC 7523 client_assertion_type value used for JWT
|
||||
// bearer client authentication at the token endpoint.
|
||||
const ClientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||
|
||||
// defaultAssertionTTL bounds a client_assertion's lifetime.
|
||||
const defaultAssertionTTL = 5 * time.Minute
|
||||
|
||||
// SignAttestation signs the registration attestation JWT. The public key is
|
||||
// embedded in the JWS "jwk" header so the registration backend can bind it to
|
||||
// the app during action=begin; the claims carry the server nonce as a
|
||||
// proof-of-possession challenge.
|
||||
func SignAttestation(ctx context.Context, signer keysigner.Signer, ref keysigner.KeyRef, nonce string, now time.Time) (string, error) {
|
||||
if signer == nil {
|
||||
return "", fmt.Errorf("jwt: no signer available (private_key_jwt unsupported on this build)")
|
||||
}
|
||||
pub, err := signer.EnsureKey(ctx, ref)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("jwt: ensure key: %w", err)
|
||||
}
|
||||
alg, err := keysigner.AlgForKey(pub)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
jwk, err := keysigner.PublicKeyJWK(pub)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buildSignedJWT(ctx, signer, ref, alg, map[string]any{"jwk": jwk}, attestationClaims(nonce, now))
|
||||
}
|
||||
|
||||
// SignClientAssertion mints a short-lived RFC 7523 client_assertion: it reads the
|
||||
// registered key (it must already exist — bound at registration; a missing key is
|
||||
// an error, not a reason to create a new unbound one), derives the JWS alg from
|
||||
// the public key, and signs an assertion whose audience is the brand's Open API
|
||||
// host. The server, holding the public key bound at registration, verifies it in
|
||||
// place of client_secret. The assertion header carries only alg (no jwk/kid);
|
||||
// the server locates the key via iss/sub = client_id.
|
||||
//
|
||||
// This is the model-independent glue: the assertion JWT is identical whether the
|
||||
// server augments an existing grant (device_code/refresh_token) with client
|
||||
// authentication or uses a dedicated jwt-bearer grant — only where the caller
|
||||
// attaches it differs.
|
||||
func SignClientAssertion(ctx context.Context, signer keysigner.Signer, ref keysigner.KeyRef, clientID, audience string, now time.Time) (string, error) {
|
||||
if signer == nil {
|
||||
return "", fmt.Errorf("jwt: no signer available (private_key_jwt unsupported on this build)")
|
||||
}
|
||||
pub, err := signer.PublicKey(ctx, ref)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("jwt: public key: %w", err)
|
||||
}
|
||||
alg, err := keysigner.AlgForKey(pub)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buildSignedJWT(ctx, signer, ref, alg, map[string]any{}, clientAssertionClaims(clientID, audience, now, defaultAssertionTTL))
|
||||
}
|
||||
254
internal/auth/jwt/jwt_test.go
Normal file
254
internal/auth/jwt/jwt_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// fakeSigner is a real in-memory ECDSA P-256 signer, so tests exercise the full
|
||||
// JWS path and the produced token is actually cryptographically verifiable.
|
||||
type fakeSigner struct{ key *ecdsa.PrivateKey }
|
||||
|
||||
func newFakeSigner(t *testing.T) *fakeSigner {
|
||||
t.Helper()
|
||||
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &fakeSigner{key: k}
|
||||
}
|
||||
|
||||
func (f *fakeSigner) EnsureKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return f.key.Public(), nil
|
||||
}
|
||||
func (f *fakeSigner) PublicKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return f.key.Public(), nil
|
||||
}
|
||||
func (f *fakeSigner) Sign(_ context.Context, _ keysigner.KeyRef, in []byte) ([]byte, string, error) {
|
||||
h := sha256.Sum256(in)
|
||||
r, s, err := ecdsa.Sign(rand.Reader, f.key, h[:])
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
// JOSE ES256: fixed-width big-endian r||s (32 bytes each for P-256).
|
||||
sig := make([]byte, 64)
|
||||
r.FillBytes(sig[:32])
|
||||
s.FillBytes(sig[32:])
|
||||
return sig, keysigner.AlgES256, nil
|
||||
}
|
||||
|
||||
func TestBuildSignedJWT_VerifiableES256(t *testing.T) {
|
||||
f := newFakeSigner(t)
|
||||
now := time.Unix(1700000000, 0)
|
||||
|
||||
tok, err := buildSignedJWT(context.Background(), f, keysigner.KeyRef{Label: "x"}, keysigner.AlgES256,
|
||||
map[string]any{}, clientAssertionClaims("cli_app", "https://accounts.example/token", now, 5*time.Minute))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
parts := strings.Split(tok, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("want 3 JWS parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
hb, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
t.Fatalf("header not base64url: %v", err)
|
||||
}
|
||||
var hdr map[string]any
|
||||
if err := json.Unmarshal(hb, &hdr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hdr["alg"] != "ES256" || hdr["typ"] != "JWT" {
|
||||
t.Errorf("header = %v, want alg=ES256 typ=JWT (server generalizedValidation requires typ)", hdr)
|
||||
}
|
||||
|
||||
cb, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(cb, &claims); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if claims["iss"] != "cli_app" || claims["sub"] != "cli_app" || claims["aud"] != "https://accounts.example/token" {
|
||||
t.Errorf("claims = %v", claims)
|
||||
}
|
||||
|
||||
// Cryptographically verify the signature against the signing input.
|
||||
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
t.Fatalf("sig not base64url: %v", err)
|
||||
}
|
||||
if len(sig) != 64 {
|
||||
t.Fatalf("ES256 sig len = %d, want 64", len(sig))
|
||||
}
|
||||
r := new(big.Int).SetBytes(sig[:32])
|
||||
s := new(big.Int).SetBytes(sig[32:])
|
||||
h := sha256.Sum256([]byte(parts[0] + "." + parts[1]))
|
||||
if !ecdsa.Verify(f.key.Public().(*ecdsa.PublicKey), h[:], r, s) {
|
||||
t.Error("signature did not verify")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSignedJWT_NilSigner(t *testing.T) {
|
||||
if _, err := buildSignedJWT(context.Background(), nil, keysigner.KeyRef{}, "ES256", nil, nil); err == nil {
|
||||
t.Fatal("expected error for nil signer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSignedJWT_AlgMismatch(t *testing.T) {
|
||||
f := newFakeSigner(t) // always reports ES256
|
||||
if _, err := buildSignedJWT(context.Background(), f, keysigner.KeyRef{}, keysigner.AlgRS256, nil, nil); err == nil {
|
||||
t.Fatal("expected error when header alg != signer alg")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSignedJWT_MarshalErrors(t *testing.T) {
|
||||
f := newFakeSigner(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := buildSignedJWT(ctx, f, keysigner.KeyRef{}, keysigner.AlgES256,
|
||||
map[string]any{"bad": func() {}}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "jwt: marshal header") {
|
||||
t.Fatalf("header marshal error = %v, want prefix %q", err, "jwt: marshal header")
|
||||
}
|
||||
|
||||
_, err = buildSignedJWT(ctx, f, keysigner.KeyRef{}, keysigner.AlgES256,
|
||||
nil, map[string]any{"bad": make(chan int)})
|
||||
if err == nil || !strings.Contains(err.Error(), "jwt: marshal claims") {
|
||||
t.Fatalf("claims marshal error = %v, want prefix %q", err, "jwt: marshal claims")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignClientAssertion(t *testing.T) {
|
||||
f := newFakeSigner(t)
|
||||
now := time.Unix(1700000000, 0)
|
||||
const aud = "https://accounts.feishu.cn/open-apis/authen/v2/oauth/token"
|
||||
|
||||
tok, err := SignClientAssertion(context.Background(), f, keysigner.KeyRef{Label: "k"}, "cli_app", aud, now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
parts := strings.Split(tok, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("want 3 parts, got %d", len(parts))
|
||||
}
|
||||
cb, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(cb, &claims); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if claims["iss"] != "cli_app" || claims["aud"] != aud {
|
||||
t.Errorf("claims = %v", claims)
|
||||
}
|
||||
|
||||
// Signature must verify against the key's public half.
|
||||
sig, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
r := new(big.Int).SetBytes(sig[:32])
|
||||
s := new(big.Int).SetBytes(sig[32:])
|
||||
h := sha256.Sum256([]byte(parts[0] + "." + parts[1]))
|
||||
if !ecdsa.Verify(f.key.Public().(*ecdsa.PublicKey), h[:], r, s) {
|
||||
t.Error("client_assertion signature did not verify")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignClientAssertion_NilSigner(t *testing.T) {
|
||||
if _, err := SignClientAssertion(context.Background(), nil, keysigner.KeyRef{}, "cli_app", "aud", time.Unix(0, 0)); err == nil {
|
||||
t.Fatal("expected error for nil signer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignAttestation(t *testing.T) {
|
||||
f := newFakeSigner(t)
|
||||
now := time.Unix(1700000000, 0)
|
||||
|
||||
tok, err := SignAttestation(context.Background(), f, keysigner.KeyRef{Label: "k"}, "nonce-abc", now)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parts := strings.Split(tok, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("want 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
hb, _ := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
var hdr map[string]any
|
||||
if err := json.Unmarshal(hb, &hdr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jwk, ok := hdr["jwk"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("attestation header missing jwk: %v", hdr)
|
||||
}
|
||||
if jwk["kty"] != "EC" || jwk["crv"] != "P-256" || jwk["use"] != "sig" {
|
||||
t.Errorf("jwk = %v", jwk)
|
||||
}
|
||||
|
||||
cb, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(cb, &claims); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if claims["nonce"] != "nonce-abc" {
|
||||
t.Errorf("nonce = %v", claims["nonce"])
|
||||
}
|
||||
// jti, iat, exp are all required by the attestation spec.
|
||||
iat, iatOK := claims["iat"].(float64)
|
||||
exp, expOK := claims["exp"].(float64)
|
||||
if !iatOK || !expOK || exp <= iat {
|
||||
t.Errorf("claims iat/exp invalid: iat=%v exp=%v", claims["iat"], claims["exp"])
|
||||
}
|
||||
if jti, _ := claims["jti"].(string); jti == "" {
|
||||
t.Error("claims jti empty")
|
||||
}
|
||||
|
||||
// Signature verifies against the embedded key.
|
||||
sig, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
r := new(big.Int).SetBytes(sig[:32])
|
||||
s := new(big.Int).SetBytes(sig[32:])
|
||||
h := sha256.Sum256([]byte(parts[0] + "." + parts[1]))
|
||||
if !ecdsa.Verify(f.key.Public().(*ecdsa.PublicKey), h[:], r, s) {
|
||||
t.Error("attestation signature did not verify")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignAttestation_NilSigner(t *testing.T) {
|
||||
if _, err := SignAttestation(context.Background(), nil, keysigner.KeyRef{}, "n", time.Unix(0, 0)); err == nil {
|
||||
t.Fatal("expected error for nil signer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimFactories(t *testing.T) {
|
||||
now := time.Unix(1700000000, 0)
|
||||
|
||||
a := attestationClaims("nonce-xyz", now)
|
||||
if a["nonce"] != "nonce-xyz" || a["iat"] != now.Unix() {
|
||||
t.Errorf("attestation claims = %v", a)
|
||||
}
|
||||
if a["exp"] != now.Add(attestationTTL).Unix() {
|
||||
t.Errorf("attestation exp = %v, want %v", a["exp"], now.Add(attestationTTL).Unix())
|
||||
}
|
||||
if jti, _ := a["jti"].(string); jti == "" {
|
||||
t.Error("attestation jti empty")
|
||||
}
|
||||
|
||||
c := clientAssertionClaims("cli_app", "aud", now, time.Minute)
|
||||
if c["exp"].(int64) != now.Add(time.Minute).Unix() {
|
||||
t.Errorf("client_assertion exp = %v", c["exp"])
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/errclass"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
)
|
||||
|
||||
@@ -37,7 +38,10 @@ type UATCallOptions struct {
|
||||
AppId string
|
||||
AppSecret string
|
||||
Domain core.LarkBrand
|
||||
ErrOut io.Writer // diagnostic/status output (caller injects f.IOStreams.ErrOut)
|
||||
AuthMethod string // "" == client_secret; core.AuthMethodPrivateKeyJWT
|
||||
KeyLabel string // TEE key handle for private_key_jwt
|
||||
Signer keysigner.Signer // active signer for private_key_jwt
|
||||
ErrOut io.Writer // diagnostic/status output (caller injects f.IOStreams.ErrOut)
|
||||
}
|
||||
|
||||
// UATStatus represents the status of a user access token.
|
||||
@@ -61,6 +65,9 @@ func NewUATCallOptions(cfg *core.CliConfig, errOut io.Writer) UATCallOptions {
|
||||
AppId: cfg.AppID,
|
||||
AppSecret: cfg.AppSecret,
|
||||
Domain: cfg.Brand,
|
||||
AuthMethod: cfg.AuthMethod,
|
||||
KeyLabel: cfg.KeyLabel,
|
||||
Signer: keysigner.Active(),
|
||||
ErrOut: errOut,
|
||||
}
|
||||
}
|
||||
@@ -193,7 +200,14 @@ func doRefreshToken(httpClient *http.Client, opts UATCallOptions, stored *Stored
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", stored.RefreshToken)
|
||||
form.Set("client_id", opts.AppId)
|
||||
form.Set("client_secret", opts.AppSecret)
|
||||
ca := ClientAuth{AppID: opts.AppId, AppSecret: opts.AppSecret, AuthMethod: opts.AuthMethod, Signer: opts.Signer, KeyLabel: opts.KeyLabel}
|
||||
usedAssertion, caErr := ca.applyClientAssertion(context.Background(), form, core.OpenAPIAudience(opts.Domain))
|
||||
if caErr != nil {
|
||||
return nil, caErr
|
||||
}
|
||||
if !usedAssertion {
|
||||
form.Set("client_secret", opts.AppSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", endpoints.Token, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
|
||||
@@ -38,3 +38,23 @@ func TestNewUATCallOptions(t *testing.T) {
|
||||
t.Error("ErrOut not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewUATCallOptions_PrivateKeyJWT verifies the auth-method fields propagate
|
||||
// so the refresh path can mint a client_assertion instead of sending a secret.
|
||||
func TestNewUATCallOptions_PrivateKeyJWT(t *testing.T) {
|
||||
cfg := &core.CliConfig{
|
||||
AppID: "cli_pk",
|
||||
Brand: core.BrandFeishu,
|
||||
UserOpenId: "ou_test",
|
||||
AuthMethod: core.AuthMethodPrivateKeyJWT,
|
||||
KeyLabel: "agent-key",
|
||||
}
|
||||
opts := NewUATCallOptions(cfg, &bytes.Buffer{})
|
||||
|
||||
if opts.AuthMethod != core.AuthMethodPrivateKeyJWT {
|
||||
t.Errorf("AuthMethod = %q, want private_key_jwt", opts.AuthMethod)
|
||||
}
|
||||
if opts.KeyLabel != "agent-key" {
|
||||
t.Errorf("KeyLabel = %q, want agent-key", opts.KeyLabel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,16 @@ func NewIOStreams(in io.Reader, out, errOut io.Writer) *IOStreams {
|
||||
return &IOStreams{In: in, Out: out, ErrOut: errOut, IsTerminal: isTerminal, StderrIsTerminal: stderrIsTerminal}
|
||||
}
|
||||
|
||||
// StdoutIsTerminal reports whether Out is an interactive terminal. Unlike
|
||||
// IsTerminal — which reflects stdin and drives prompt decisions — this is the
|
||||
// correct check for OUTPUT formatting: `cmd | jq` must still emit machine output
|
||||
// from an interactive shell (stdin is a TTY there, but stdout is the pipe).
|
||||
// Buffers (tests) and redirects are not *os.File terminals, so they yield false.
|
||||
func (s *IOStreams) StdoutIsTerminal() bool {
|
||||
f, ok := s.Out.(*os.File)
|
||||
return ok && term.IsTerminal(int(f.Fd()))
|
||||
}
|
||||
|
||||
// SystemIO creates an IOStreams wired to the process's standard file descriptors.
|
||||
//
|
||||
//nolint:forbidigo // entry point for real stdio
|
||||
|
||||
28
internal/cmdutil/iostreams_test.go
Normal file
28
internal/cmdutil/iostreams_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStdoutIsTerminal(t *testing.T) {
|
||||
// Buffer-backed output (tests, captured output) is never a terminal.
|
||||
if (&IOStreams{Out: &bytes.Buffer{}}).StdoutIsTerminal() {
|
||||
t.Error("bytes.Buffer Out should not be a terminal")
|
||||
}
|
||||
// An os.Pipe write end is an *os.File but not a terminal — mirrors `cmd | jq`,
|
||||
// the case the stdin-based IsTerminal would get wrong.
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
defer w.Close()
|
||||
if (&IOStreams{Out: w}).StdoutIsTerminal() {
|
||||
t.Error("os.Pipe Out should not be a terminal")
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,13 @@ type AppUser struct {
|
||||
UserName string `json:"userName"`
|
||||
}
|
||||
|
||||
// Auth methods for app credentials. An empty AppConfig.AuthMethod means the
|
||||
// default, client_secret.
|
||||
const (
|
||||
AuthMethodClientSecret = "client_secret" // app_id + app_secret
|
||||
AuthMethodPrivateKeyJWT = "private_key_jwt" // TEE-signed client_assertion; no app secret
|
||||
)
|
||||
|
||||
// AppConfig is a per-app configuration entry (stored format — secrets may be unresolved).
|
||||
type AppConfig struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -46,6 +53,15 @@ type AppConfig struct {
|
||||
DefaultAs Identity `json:"defaultAs,omitempty"` // AsUser | AsBot | AsAuto
|
||||
StrictMode *StrictMode `json:"strictMode,omitempty"`
|
||||
Users []AppUser `json:"users"`
|
||||
|
||||
// AuthMethod selects how tokens are minted. Empty == AuthMethodClientSecret
|
||||
// (back-compat). AuthMethodPrivateKeyJWT uses a TEE-held key (see KeyRef) to
|
||||
// sign client_assertion JWTs instead of sending an app secret.
|
||||
AuthMethod string `json:"authMethod,omitempty"`
|
||||
// KeyRef references the non-exportable signing key for private_key_jwt.
|
||||
// Source is "tee" and ID is the backend key label; the actual key never
|
||||
// leaves the secure backend, so this is a handle, not secret material.
|
||||
KeyRef *SecretRef `json:"keyRef,omitempty"`
|
||||
}
|
||||
|
||||
// ProfileName returns the display name for this app config.
|
||||
@@ -161,7 +177,9 @@ type CliConfig struct {
|
||||
UserOpenId string
|
||||
UserName string
|
||||
Lang i18n.Lang
|
||||
SupportedIdentities uint8 `json:"-"` // bitflag: 1=user, 2=bot; set by credential provider
|
||||
SupportedIdentities uint8 `json:"-"` // bitflag: 1=user, 2=bot; set by credential provider
|
||||
AuthMethod string // "" == client_secret; AuthMethodPrivateKeyJWT
|
||||
KeyLabel string // resolved TEE key handle for private_key_jwt
|
||||
}
|
||||
|
||||
// identityBotBit is the bit flag for bot identity in SupportedIdentities.
|
||||
@@ -247,31 +265,58 @@ func ResolveConfigFromMulti(raw *MultiAppConfig, kc keychain.KeychainAccess, pro
|
||||
WithHint("available profiles: %s", formatProfileNames(raw.ProfileNames()))
|
||||
}
|
||||
|
||||
if err := ValidateSecretKeyMatch(app.AppId, app.AppSecret); err != nil {
|
||||
return nil, errs.NewConfigError(errs.SubtypeNotConfigured, "appId and appSecret keychain key are out of sync").
|
||||
WithHint("%s", err.Error()).
|
||||
WithCause(err)
|
||||
// Validate the auth method first so a malformed profile fails here rather
|
||||
// than silently degrading to client_secret (unknown method) or failing later
|
||||
// at token-signing. Empty stays empty — downstream treats it as client_secret
|
||||
// (back-compat).
|
||||
switch app.AuthMethod {
|
||||
case "", AuthMethodClientSecret, AuthMethodPrivateKeyJWT:
|
||||
default:
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidConfig, "unknown authMethod %q", app.AuthMethod).
|
||||
WithHint("supported: %s, %s (empty defaults to %s)", AuthMethodClientSecret, AuthMethodPrivateKeyJWT, AuthMethodClientSecret)
|
||||
}
|
||||
|
||||
secret, err := ResolveSecretInput(app.AppSecret, kc)
|
||||
if err != nil {
|
||||
if errs.IsTyped(err) {
|
||||
return nil, err
|
||||
// private_key_jwt carries no secret: validate the key handle and skip secret
|
||||
// resolution entirely, so a stale/broken AppSecret ref never produces a
|
||||
// confusing secret-resolution error for an otherwise-valid pkjwt profile.
|
||||
var secret string
|
||||
if app.AuthMethod == AuthMethodPrivateKeyJWT {
|
||||
if app.KeyRef == nil || app.KeyRef.Source != "tee" || app.KeyRef.ID == "" {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidConfig, "private_key_jwt requires a key handle (keyRef) but none is configured").
|
||||
WithHint("re-run: lark-cli config init --new --auth-method private_key_jwt")
|
||||
}
|
||||
subtype := errs.SubtypeNotConfigured
|
||||
if isMalformedConfigError(err) {
|
||||
subtype = errs.SubtypeInvalidConfig
|
||||
} else {
|
||||
if err := ValidateSecretKeyMatch(app.AppId, app.AppSecret); err != nil {
|
||||
return nil, errs.NewConfigError(errs.SubtypeNotConfigured, "appId and appSecret keychain key are out of sync").
|
||||
WithHint("%s", err.Error()).
|
||||
WithCause(err)
|
||||
}
|
||||
var resolveErr error
|
||||
secret, resolveErr = ResolveSecretInput(app.AppSecret, kc)
|
||||
if resolveErr != nil {
|
||||
if errs.IsTyped(resolveErr) {
|
||||
return nil, resolveErr
|
||||
}
|
||||
subtype := errs.SubtypeNotConfigured
|
||||
if isMalformedConfigError(resolveErr) {
|
||||
subtype = errs.SubtypeInvalidConfig
|
||||
}
|
||||
return nil, errs.NewConfigError(subtype, "%s", resolveErr.Error()).WithCause(resolveErr)
|
||||
}
|
||||
return nil, errs.NewConfigError(subtype, "%s", err.Error()).WithCause(err)
|
||||
}
|
||||
|
||||
cfg := &CliConfig{
|
||||
ProfileName: app.ProfileName(),
|
||||
AppID: app.AppId,
|
||||
AppSecret: secret,
|
||||
Brand: app.Brand,
|
||||
Lang: app.Lang,
|
||||
AuthMethod: app.AuthMethod,
|
||||
DefaultAs: app.DefaultAs,
|
||||
}
|
||||
if app.KeyRef != nil {
|
||||
cfg.KeyLabel = app.KeyRef.ID
|
||||
}
|
||||
if len(app.Users) > 0 {
|
||||
cfg.UserOpenId = app.Users[0].UserOpenId
|
||||
cfg.UserName = app.Users[0].UserName
|
||||
|
||||
@@ -133,6 +133,108 @@ func TestResolveConfigFromMulti_AcceptsPlainSecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveConfigFromMulti_RejectsUnknownAuthMethod ensures an unsupported
|
||||
// authMethod fails at resolution rather than silently degrading to client_secret.
|
||||
func TestResolveConfigFromMulti_RejectsUnknownAuthMethod(t *testing.T) {
|
||||
raw := &MultiAppConfig{
|
||||
Apps: []AppConfig{
|
||||
{
|
||||
AppId: "cli_abc",
|
||||
AppSecret: PlainSecret("my-secret"),
|
||||
Brand: BrandFeishu,
|
||||
AuthMethod: "bogus_method",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ResolveConfigFromMulti(raw, nil, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown authMethod")
|
||||
}
|
||||
var cfgErr *errs.ConfigError
|
||||
if !errors.As(err, &cfgErr) {
|
||||
t.Fatalf("expected ConfigError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveConfigFromMulti_PrivateKeyJWTRequiresKeyRef ensures private_key_jwt
|
||||
// without a key handle fails at resolution rather than later at token-signing.
|
||||
func TestResolveConfigFromMulti_PrivateKeyJWTRequiresKeyRef(t *testing.T) {
|
||||
raw := &MultiAppConfig{
|
||||
Apps: []AppConfig{
|
||||
{
|
||||
AppId: "cli_abc",
|
||||
AppSecret: SecretInput{}, // private_key_jwt carries no app secret
|
||||
Brand: BrandFeishu,
|
||||
AuthMethod: AuthMethodPrivateKeyJWT,
|
||||
// KeyRef intentionally nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ResolveConfigFromMulti(raw, nil, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for private_key_jwt without keyRef")
|
||||
}
|
||||
var cfgErr *errs.ConfigError
|
||||
if !errors.As(err, &cfgErr) {
|
||||
t.Fatalf("expected ConfigError, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// Control: same config WITH a keyRef resolves cleanly and sets KeyLabel.
|
||||
raw.Apps[0].KeyRef = &SecretRef{Source: "tee", ID: "larksuite-cli-agent"}
|
||||
cfg, err := ResolveConfigFromMulti(raw, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with keyRef present: %v", err)
|
||||
}
|
||||
if cfg.KeyLabel != "larksuite-cli-agent" {
|
||||
t.Errorf("KeyLabel = %q, want larksuite-cli-agent", cfg.KeyLabel)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveConfigFromMulti_PKJWTSkipsSecretResolution ensures a private_key_jwt
|
||||
// profile that carries a stale/broken AppSecret ref still resolves cleanly: the
|
||||
// auth method is judged before any secret handling, so the stale ref is ignored
|
||||
// instead of producing a confusing secret-resolution failure.
|
||||
func TestResolveConfigFromMulti_PKJWTSkipsSecretResolution(t *testing.T) {
|
||||
raw := &MultiAppConfig{
|
||||
Apps: []AppConfig{{
|
||||
AppId: "cli_pk",
|
||||
// Stale keychain ref whose ID does not match appId — would trip
|
||||
// ValidateSecretKeyMatch / ResolveSecretInput if it were reached.
|
||||
AppSecret: SecretInput{Ref: &SecretRef{Source: "keychain", ID: "appsecret:cli_OTHER"}},
|
||||
Brand: BrandFeishu,
|
||||
AuthMethod: AuthMethodPrivateKeyJWT,
|
||||
KeyRef: &SecretRef{Source: "tee", ID: "agent-key"},
|
||||
Users: []AppUser{},
|
||||
}},
|
||||
}
|
||||
cfg, err := ResolveConfigFromMulti(raw, stubKeychain{}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("pkjwt with stale secret ref must skip secret resolution, got %v", err)
|
||||
}
|
||||
if cfg.AuthMethod != AuthMethodPrivateKeyJWT || cfg.KeyLabel != "agent-key" {
|
||||
t.Errorf("got authMethod=%q keyLabel=%q", cfg.AuthMethod, cfg.KeyLabel)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveConfigFromMulti_PKJWTRejectsBadKeyRef ensures the stricter keyRef
|
||||
// check (Source=="tee" && ID!="") rejects malformed handles.
|
||||
func TestResolveConfigFromMulti_PKJWTRejectsBadKeyRef(t *testing.T) {
|
||||
for i, ref := range []*SecretRef{
|
||||
{Source: "keychain", ID: "x"}, // wrong source
|
||||
{Source: "tee", ID: ""}, // empty id
|
||||
} {
|
||||
raw := &MultiAppConfig{Apps: []AppConfig{{
|
||||
AppId: "cli_pk", Brand: BrandFeishu,
|
||||
AuthMethod: AuthMethodPrivateKeyJWT, KeyRef: ref, Users: []AppUser{},
|
||||
}}}
|
||||
if _, err := ResolveConfigFromMulti(raw, stubKeychain{}, ""); err == nil {
|
||||
t.Errorf("case %d: expected ConfigError for bad keyRef", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveConfigFromMulti_CarriesLang(t *testing.T) {
|
||||
raw := &MultiAppConfig{
|
||||
Apps: []AppConfig{
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
package core
|
||||
|
||||
import "strings"
|
||||
|
||||
// LarkBrand represents the Lark platform brand.
|
||||
// "feishu" targets China-mainland, "lark" targets international.
|
||||
// Any other string is treated as a custom base URL.
|
||||
@@ -60,3 +62,10 @@ func ResolveEndpoints(brand LarkBrand) Endpoints {
|
||||
func ResolveOpenBaseURL(brand LarkBrand) string {
|
||||
return ResolveEndpoints(brand).Open
|
||||
}
|
||||
|
||||
// OpenAPIAudience returns the client_assertion `aud` value for the brand: the
|
||||
// bare Open API host per the App Authentication JWT spec — "open.feishu.cn" or
|
||||
// "open.larksuite.com" — not the full token endpoint URL.
|
||||
func OpenAPIAudience(brand LarkBrand) string {
|
||||
return strings.TrimPrefix(ResolveOpenBaseURL(brand), "https://")
|
||||
}
|
||||
|
||||
@@ -57,3 +57,12 @@ func TestResolveOpenBaseURL(t *testing.T) {
|
||||
t.Errorf("ResolveOpenBaseURL(lark) = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIAudience(t *testing.T) {
|
||||
if got := OpenAPIAudience(BrandFeishu); got != "open.feishu.cn" {
|
||||
t.Errorf("OpenAPIAudience(feishu) = %q, want open.feishu.cn", got)
|
||||
}
|
||||
if got := OpenAPIAudience(BrandLark); got != "open.larksuite.com" {
|
||||
t.Errorf("OpenAPIAudience(lark) = %q, want open.larksuite.com", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/larksuite/cli/internal/keychain"
|
||||
|
||||
extcred "github.com/larksuite/cli/extension/credential"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// classifyTATResponseCode wraps a deterministic (non-transient) failure from the
|
||||
@@ -175,6 +176,23 @@ func (p *DefaultTokenProvider) doResolveTAT(ctx context.Context) (*TokenResult,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// private_key_jwt apps have no app secret: mint via the jwt-bearer grant
|
||||
// using a TEE-signed client_assertion instead.
|
||||
if acct.AuthMethod == core.AuthMethodPrivateKeyJWT {
|
||||
signer := keysigner.Active()
|
||||
if signer == nil {
|
||||
return nil, errs.NewConfigError(errs.SubtypeInvalidClient,
|
||||
"profile uses private_key_jwt but no TEE key signer is available on this build").
|
||||
WithHint("install a build with the platform key-signer extension, or reconfigure the app to use an app secret")
|
||||
}
|
||||
token, err := FetchTATWithAssertion(ctx, httpClient, acct.Brand, acct.AppID, signer, acct.KeyLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenResult{Token: token}, nil
|
||||
}
|
||||
|
||||
token, err := FetchTAT(ctx, httpClient, acct.Brand, acct.AppID, acct.AppSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -11,8 +11,13 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/auth"
|
||||
"github.com/larksuite/cli/internal/auth/jwt"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// FetchTAT performs a single HTTP POST to mint a tenant access token via the
|
||||
@@ -100,3 +105,96 @@ func FetchTAT(ctx context.Context, httpClient *http.Client, brand core.LarkBrand
|
||||
}
|
||||
return "", classifyTATResponseCode(result.Code, result.Error, desc, string(brand), appID)
|
||||
}
|
||||
|
||||
// FetchTATWithAssertion mints a tenant access token for a private_key_jwt app via
|
||||
// the RFC 7523 jwt-bearer grant: it signs a short-lived client_assertion with the
|
||||
// TEE-held key and posts it to the unified OAuth token endpoint, replacing the
|
||||
// app_secret entirely.
|
||||
//
|
||||
// The unified v2 token endpoint returns the minted token as access_token
|
||||
// (tenant_access_token is accepted as a fallback).
|
||||
func FetchTATWithAssertion(ctx context.Context, httpClient *http.Client, brand core.LarkBrand, clientID string, signer keysigner.Signer, keyLabel string) (string, error) {
|
||||
if signer == nil {
|
||||
return "", fmt.Errorf("private_key_jwt requires a key signer, but none is available on this build")
|
||||
}
|
||||
ep := core.ResolveEndpoints(brand)
|
||||
endpoint := ep.Open + auth.PathOAuthTokenV2
|
||||
|
||||
assertion, err := jwt.SignClientAssertion(ctx, signer, keysigner.KeyRef{Label: keyLabel}, clientID, core.OpenAPIAudience(brand), time.Now())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
form.Set("client_id", clientID)
|
||||
form.Set("client_assertion_type", jwt.ClientAssertionType)
|
||||
form.Set("client_assertion", assertion)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read token response: %w", err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
TenantAccessToken string `json:"tenant_access_token"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &result) // best-effort; error body may not be JSON
|
||||
|
||||
token := result.AccessToken
|
||||
if token == "" {
|
||||
token = result.TenantAccessToken
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK && token != "" && result.Error == "" && result.Code == 0 {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Surface the server's reason, preferring the OAuth `error` code (e.g.
|
||||
// unauthorized_client) which is more diagnostic than the description alone.
|
||||
detail := result.ErrorDescription
|
||||
if detail == "" {
|
||||
detail = result.Msg
|
||||
}
|
||||
if detail == "" {
|
||||
detail = strings.TrimSpace(string(body))
|
||||
}
|
||||
if result.Error != "" {
|
||||
return "", classifyAssertionError(result.Error, resp.StatusCode, detail)
|
||||
}
|
||||
return "", fmt.Errorf("token endpoint HTTP %d (code=%d): %s", resp.StatusCode, result.Code, detail)
|
||||
}
|
||||
|
||||
// classifyAssertionError maps the OAuth token endpoint's `error` field to a
|
||||
// typed or untyped error. Only deterministic client-credential rejections get a
|
||||
// typed errs.ConfigError (so runProbePKJWT can tell "this key is not bound to
|
||||
// this app" apart from upstream noise); every other error (e.g.
|
||||
// temporarily_unavailable) stays untyped and is swallowed by the probe. detail
|
||||
// carries only the server's error_description / msg / body text — it never
|
||||
// echoes the client_assertion or private key (the assertion lives only in the
|
||||
// request form).
|
||||
func classifyAssertionError(oauthError string, httpStatus int, detail string) error {
|
||||
switch oauthError {
|
||||
case "invalid_client", "unauthorized_client", "invalid_grant":
|
||||
return errs.NewConfigError(errs.SubtypeInvalidClient,
|
||||
"token endpoint rejected the key (%s): %s", oauthError, detail)
|
||||
default:
|
||||
return fmt.Errorf("token endpoint HTTP %d (%s): %s", httpStatus, oauthError, detail)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,15 +5,24 @@ package credential
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/keysigner"
|
||||
)
|
||||
|
||||
// stubRoundTripper lets us assert request shape and return canned responses.
|
||||
@@ -307,3 +316,147 @@ func (r *urlRewriteRT) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req2.Header = req.Header
|
||||
return http.DefaultTransport.RoundTrip(req2)
|
||||
}
|
||||
|
||||
// fakeTATSigner is a real in-memory ECDSA P-256 signer for assertion tests.
|
||||
type fakeTATSigner struct{ key *ecdsa.PrivateKey }
|
||||
|
||||
func newFakeTATSigner(t *testing.T) *fakeTATSigner {
|
||||
t.Helper()
|
||||
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &fakeTATSigner{key: k}
|
||||
}
|
||||
|
||||
func (f *fakeTATSigner) EnsureKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return f.key.Public(), nil
|
||||
}
|
||||
func (f *fakeTATSigner) PublicKey(context.Context, keysigner.KeyRef) (crypto.PublicKey, error) {
|
||||
return f.key.Public(), nil
|
||||
}
|
||||
func (f *fakeTATSigner) Sign(_ context.Context, _ keysigner.KeyRef, in []byte) ([]byte, string, error) {
|
||||
h := sha256.Sum256(in)
|
||||
r, s, err := ecdsa.Sign(rand.Reader, f.key, h[:])
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
sig := make([]byte, 64)
|
||||
r.FillBytes(sig[:32])
|
||||
s.FillBytes(sig[32:])
|
||||
return sig, keysigner.AlgES256, nil
|
||||
}
|
||||
|
||||
func TestFetchTATWithAssertion_Success(t *testing.T) {
|
||||
rt := &stubRoundTripper{respCode: 200, respBody: `{"access_token":"t-jwt","token_type":"Bearer","expires_in":7200}`}
|
||||
hc := &http.Client{Transport: rt}
|
||||
|
||||
token, err := FetchTATWithAssertion(context.Background(), hc, core.BrandFeishu, "cli_app", newFakeTATSigner(t), "agent-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "t-jwt" {
|
||||
t.Errorf("token = %q, want t-jwt", token)
|
||||
}
|
||||
if rt.gotReq.URL.String() != "https://open.feishu.cn/open-apis/authen/v2/oauth/token" {
|
||||
t.Errorf("url = %s", rt.gotReq.URL.String())
|
||||
}
|
||||
|
||||
form, err := url.ParseQuery(rt.gotBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if form.Get("grant_type") != "urn:ietf:params:oauth:grant-type:jwt-bearer" {
|
||||
t.Errorf("grant_type = %q", form.Get("grant_type"))
|
||||
}
|
||||
if form.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
|
||||
t.Errorf("client_assertion_type = %q", form.Get("client_assertion_type"))
|
||||
}
|
||||
if form.Get("client_assertion") == "" {
|
||||
t.Error("client_assertion is empty")
|
||||
}
|
||||
if form.Has("client_secret") {
|
||||
t.Error("client_secret must NOT be sent for private_key_jwt")
|
||||
}
|
||||
|
||||
// The assertion's aud must be the bare Open host per the App Authentication
|
||||
// JWT spec — not the full token endpoint URL.
|
||||
jwtParts := strings.Split(form.Get("client_assertion"), ".")
|
||||
if len(jwtParts) != 3 {
|
||||
t.Fatalf("malformed client_assertion: %q", form.Get("client_assertion"))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(jwtParts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("assertion payload not base64url: %v", err)
|
||||
}
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if claims["aud"] != "open.feishu.cn" {
|
||||
t.Errorf("client_assertion aud = %v, want open.feishu.cn", claims["aud"])
|
||||
}
|
||||
if claims["iss"] != "cli_app" || claims["sub"] != "cli_app" {
|
||||
t.Errorf("client_assertion iss/sub = %v/%v, want cli_app", claims["iss"], claims["sub"])
|
||||
}
|
||||
if form.Get("client_id") != "cli_app" {
|
||||
t.Errorf("client_id = %q", form.Get("client_id"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchTATWithAssertion_NilSigner(t *testing.T) {
|
||||
hc := &http.Client{Transport: &stubRoundTripper{respCode: 200, respBody: `{}`}}
|
||||
if _, err := FetchTATWithAssertion(context.Background(), hc, core.BrandFeishu, "cli_app", nil, "k"); err == nil {
|
||||
t.Fatal("expected error when signer is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchTATWithAssertion_ServerError(t *testing.T) {
|
||||
rt := &stubRoundTripper{respCode: 200, respBody: `{"error":"invalid_client","error_description":"unknown key"}`}
|
||||
hc := &http.Client{Transport: rt}
|
||||
if _, err := FetchTATWithAssertion(context.Background(), hc, core.BrandFeishu, "cli_app", newFakeTATSigner(t), "k"); err == nil {
|
||||
t.Fatal("expected error for invalid_client response")
|
||||
}
|
||||
}
|
||||
|
||||
// Deterministic OAuth client rejections must be typed (ConfigError /
|
||||
// SubtypeInvalidClient) so runProbePKJWT can tell "the key is not bound to this
|
||||
// app" apart from transport noise.
|
||||
func TestFetchTATWithAssertion_DeterministicReject_Typed(t *testing.T) {
|
||||
for _, oauthErr := range []string{"invalid_client", "unauthorized_client", "invalid_grant"} {
|
||||
rt := &stubRoundTripper{respCode: 401, respBody: `{"error":"` + oauthErr + `","error_description":"bad key"}`}
|
||||
hc := &http.Client{Transport: rt}
|
||||
_, err := FetchTATWithAssertion(context.Background(), hc, core.BrandFeishu, "cli_app", newFakeTATSigner(t), "k")
|
||||
if err == nil {
|
||||
t.Fatalf("%s: expected error", oauthErr)
|
||||
}
|
||||
if !errs.IsTyped(err) {
|
||||
t.Errorf("%s: must be typed, got %T", oauthErr, err)
|
||||
}
|
||||
var cfgErr *errs.ConfigError
|
||||
if !errors.As(err, &cfgErr) || cfgErr.Subtype != errs.SubtypeInvalidClient {
|
||||
t.Errorf("%s: want ConfigError/InvalidClient, got %T %v", oauthErr, err, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unrecognized OAuth errors and non-payload noise stay UNTYPED so the probe
|
||||
// treats them as upstream noise and stays silent.
|
||||
func TestFetchTATWithAssertion_AmbiguousError_Untyped(t *testing.T) {
|
||||
cases := []string{
|
||||
`{"error":"temporarily_unavailable","error_description":"retry"}`,
|
||||
`{"code":99999,"msg":"weird"}`,
|
||||
`not json`,
|
||||
}
|
||||
for _, body := range cases {
|
||||
rt := &stubRoundTripper{respCode: 503, respBody: body}
|
||||
hc := &http.Client{Transport: rt}
|
||||
_, err := FetchTATWithAssertion(context.Background(), hc, core.BrandFeishu, "cli_app", newFakeTATSigner(t), "k")
|
||||
if err == nil {
|
||||
t.Fatalf("body %q: expected error", body)
|
||||
}
|
||||
if errs.IsTyped(err) {
|
||||
t.Errorf("body %q: must be UNTYPED, got typed %T", body, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ type Account struct {
|
||||
UserName string
|
||||
Lang i18n.Lang
|
||||
SupportedIdentities uint8
|
||||
AuthMethod string // "" == client_secret; core.AuthMethodPrivateKeyJWT
|
||||
KeyLabel string // resolved TEE key handle for private_key_jwt
|
||||
}
|
||||
|
||||
const runtimePlaceholderAppSecret = "__LARKSUITE_CLI_TOKEN_ONLY__"
|
||||
@@ -69,6 +71,8 @@ func AccountFromCliConfig(cfg *core.CliConfig) *Account {
|
||||
UserName: cfg.UserName,
|
||||
Lang: cfg.Lang,
|
||||
SupportedIdentities: cfg.SupportedIdentities,
|
||||
AuthMethod: cfg.AuthMethod,
|
||||
KeyLabel: cfg.KeyLabel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,6 +91,8 @@ func (a *Account) ToCliConfig() *core.CliConfig {
|
||||
UserName: a.UserName,
|
||||
Lang: a.Lang,
|
||||
SupportedIdentities: a.SupportedIdentities,
|
||||
AuthMethod: a.AuthMethod,
|
||||
KeyLabel: a.KeyLabel,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,9 @@ func diagnoseBot(ctx context.Context, f *cmdutil.Factory, cfg *core.CliConfig, v
|
||||
Hint: "check strict mode or the active credential provider",
|
||||
}
|
||||
}
|
||||
if cfg.SupportedIdentities == 0 && !credential.HasRealAppSecret(cfg.AppSecret) {
|
||||
// private_key_jwt apps have no app secret — the bot/tenant token is minted via
|
||||
// a TEE-signed client_assertion — so absence of a secret is NOT "unconfigured".
|
||||
if cfg.SupportedIdentities == 0 && !credential.HasRealAppSecret(cfg.AppSecret) && cfg.AuthMethod != core.AuthMethodPrivateKeyJWT {
|
||||
return Identity{
|
||||
Status: StatusNotConfigured,
|
||||
Message: "Bot identity: not configured (missing app secret or bot token)",
|
||||
|
||||
212
internal/keysigner/keysigner.go
Normal file
212
internal/keysigner/keysigner.go
Normal file
@@ -0,0 +1,212 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Package keysigner defines the pluggable signing abstraction used by the
|
||||
// private_key_jwt registration and authentication flow.
|
||||
//
|
||||
// The open-source core only declares the Signer interface and pure-stdlib key
|
||||
// helpers. The platform implementations that hold a non-exportable private key
|
||||
// (TPM 2.0 via facebookincubator/sks on Linux/Windows, a non-extractable
|
||||
// Keychain key on macOS) live OUTSIDE this core — in a build-tagged module or
|
||||
// extension — and register themselves via Register from init(). This keeps
|
||||
// CGO-heavy and license-sensitive dependencies out of the open-source build.
|
||||
package keysigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// KeyRef identifies a non-exportable signing key held by a backend
|
||||
// (TEE/TPM/Keychain). It is a stable handle (label), never the key material.
|
||||
type KeyRef struct {
|
||||
// Label is the backend key label/tag (e.g. "larksuite-cli-agent").
|
||||
Label string
|
||||
}
|
||||
|
||||
// Signer signs JWS signing inputs with a non-exportable key.
|
||||
type Signer interface {
|
||||
// EnsureKey returns the public key for ref, creating the key if absent.
|
||||
EnsureKey(ctx context.Context, ref KeyRef) (crypto.PublicKey, error)
|
||||
// PublicKey returns the public key for ref without creating it.
|
||||
PublicKey(ctx context.Context, ref KeyRef) (crypto.PublicKey, error)
|
||||
// Sign signs signingInput and returns a JOSE-format signature plus the JWS
|
||||
// alg ("ES256"/"RS256"). Implementations apply the alg's hash and, for
|
||||
// ECDSA, MUST return the fixed-width r||s form required by RFC 7518 §3.4
|
||||
// (not ASN.1 DER), because the backend (TPM/Keychain) typically yields DER.
|
||||
Sign(ctx context.Context, ref KeyRef, signingInput []byte) (sig []byte, alg string, err error)
|
||||
}
|
||||
|
||||
// Supported JWS algorithms.
|
||||
const (
|
||||
AlgES256 = "ES256"
|
||||
AlgRS256 = "RS256"
|
||||
)
|
||||
|
||||
// DefaultKeyLabel is the backend key label lark-cli uses for its device signing
|
||||
// key. One non-exportable key is created on first private_key_jwt registration
|
||||
// and reused across subsequent app registrations on the same device.
|
||||
const DefaultKeyLabel = "larksuite-cli-agent"
|
||||
|
||||
// HardwareInfo describes the secure hardware backing a Signer, as reported by a
|
||||
// HardwareProber. It is advisory/diagnostic: it tells a user whether
|
||||
// private_key_jwt can use a real TEE on this device.
|
||||
type HardwareInfo struct {
|
||||
Backend string // backing technology, e.g. "tpm2" or "keychain"
|
||||
Available bool // the hardware is present and usable for signing
|
||||
VendorName string // hardware vendor/manufacturer, when known
|
||||
VendorInfo string // additional vendor detail, when known
|
||||
Reason string // when Available is false, a human-readable cause
|
||||
}
|
||||
|
||||
// HardwareProber is an optional capability a Signer may implement to report on
|
||||
// the secure hardware backing it (TPM/TEE vendor and availability) WITHOUT
|
||||
// creating or using a key. Probing never mutates key state.
|
||||
type HardwareProber interface {
|
||||
ProbeHardware(ctx context.Context) (HardwareInfo, error)
|
||||
}
|
||||
|
||||
// ProbeActiveHardware probes the active signer's secure hardware. ok is false
|
||||
// when there is no active signer or it does not implement HardwareProber — in
|
||||
// which case private_key_jwt is unsupported on this build. When ok is true, info
|
||||
// reports availability and, if unavailable, info.Reason explains why.
|
||||
func ProbeActiveHardware(ctx context.Context) (info HardwareInfo, ok bool, err error) {
|
||||
return probeHardware(ctx, Active())
|
||||
}
|
||||
|
||||
// probeHardware is the registry-independent core of ProbeActiveHardware, so it
|
||||
// can be unit-tested without touching the global signer.
|
||||
func probeHardware(ctx context.Context, s Signer) (HardwareInfo, bool, error) {
|
||||
p, ok := s.(HardwareProber)
|
||||
if !ok {
|
||||
return HardwareInfo{}, false, nil
|
||||
}
|
||||
info, err := p.ProbeHardware(ctx)
|
||||
return info, true, err
|
||||
}
|
||||
|
||||
// cleanProbeError renders err's message with redundant re-wraps collapsed. Some
|
||||
// backends (e.g. facebookincubator/sks) wrap an error twice with the SAME "%w"
|
||||
// prefix, yielding "P: P: cause"; this peels each outer layer whose only
|
||||
// contribution is to repeat the prefix already present in the wrapped error,
|
||||
// leaving a single "P: cause". A layer that adds genuinely new context is kept.
|
||||
func cleanProbeError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
msg := err.Error()
|
||||
for {
|
||||
inner := errors.Unwrap(err)
|
||||
if inner == nil {
|
||||
break
|
||||
}
|
||||
innerMsg := inner.Error()
|
||||
prefix, ok := strings.CutSuffix(msg, innerMsg)
|
||||
if !ok || prefix == "" || !strings.HasPrefix(innerMsg, prefix) {
|
||||
break
|
||||
}
|
||||
msg, err = innerMsg, inner
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// AlgForKey returns the JWS alg for a public key: EC P-256 -> ES256, RSA -> RS256.
|
||||
// The signer backend chooses the key type (the macOS keychain signer uses an
|
||||
// RSA-2048 key, hence RS256).
|
||||
func AlgForKey(pub crypto.PublicKey) (string, error) {
|
||||
switch k := pub.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if k.Curve == elliptic.P256() {
|
||||
return AlgES256, nil
|
||||
}
|
||||
return "", fmt.Errorf("keysigner: unsupported EC curve %q (only P-256/ES256)", k.Curve.Params().Name)
|
||||
case *rsa.PublicKey:
|
||||
return AlgRS256, nil
|
||||
default:
|
||||
return "", fmt.Errorf("keysigner: unsupported public key type %T", pub)
|
||||
}
|
||||
}
|
||||
|
||||
// ecdsaDERToJOSE converts an ASN.1 DER-encoded ECDSA signature — the form most
|
||||
// TEE/TPM backends emit (e.g. facebookincubator/sks marshals the TPM's r,s with
|
||||
// asn1.Marshal) — into the fixed-width r||s form JWS requires for ES256
|
||||
// (RFC 7518 §3.4). byteLen is the curve coordinate size (32 for P-256), so the
|
||||
// result is exactly 2*byteLen bytes with r and s each left-zero-padded.
|
||||
//
|
||||
// This is intentionally part of the pure-stdlib core (not a platform signer) so
|
||||
// it can be unit-tested with a software key on any machine, including TPM-less CI.
|
||||
func ecdsaDERToJOSE(der []byte, byteLen int) ([]byte, error) {
|
||||
var sig struct{ R, S *big.Int }
|
||||
rest, err := asn1.Unmarshal(der, &sig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keysigner: parse ECDSA DER signature: %w", err)
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("keysigner: %d trailing byte(s) after ECDSA DER signature", len(rest))
|
||||
}
|
||||
if sig.R == nil || sig.S == nil || sig.R.Sign() <= 0 || sig.S.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("keysigner: ECDSA signature has non-positive r/s")
|
||||
}
|
||||
// Guard before FillBytes, which panics if the scalar does not fit in byteLen.
|
||||
if sig.R.BitLen() > byteLen*8 || sig.S.BitLen() > byteLen*8 {
|
||||
return nil, fmt.Errorf("keysigner: ECDSA r/s exceeds %d-byte coordinate", byteLen)
|
||||
}
|
||||
out := make([]byte, 2*byteLen)
|
||||
sig.R.FillBytes(out[:byteLen])
|
||||
sig.S.FillBytes(out[byteLen:])
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// EncodePublicKey marshals pub to PKIX DER and base64-encodes it (std encoding),
|
||||
// matching the public-key form the registration backend binds to the app.
|
||||
func EncodePublicKey(pub crypto.PublicKey) (string, error) {
|
||||
der, err := x509.MarshalPKIXPublicKey(pub)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("keysigner: encode public key: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(der), nil
|
||||
}
|
||||
|
||||
// PublicKeyJWK returns the RFC 7517 JSON Web Key for pub, used to embed the
|
||||
// public key in the attestation JWT's "jwk" header so the registration backend
|
||||
// can bind it to the app. EC keys use base64url fixed-width coordinates
|
||||
// (RFC 7518 §6.2.1); RSA keys use base64url-encoded modulus and exponent.
|
||||
func PublicKeyJWK(pub crypto.PublicKey) (map[string]any, error) {
|
||||
switch k := pub.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if k.Curve != elliptic.P256() {
|
||||
return nil, fmt.Errorf("keysigner: JWK supports EC P-256 only, got %q", k.Curve.Params().Name)
|
||||
}
|
||||
const coordLen = 32 // P-256 field element size
|
||||
x := make([]byte, coordLen)
|
||||
y := make([]byte, coordLen)
|
||||
k.X.FillBytes(x)
|
||||
k.Y.FillBytes(y)
|
||||
return map[string]any{
|
||||
"use": "sig",
|
||||
"kty": "EC",
|
||||
"crv": "P-256",
|
||||
"x": base64.RawURLEncoding.EncodeToString(x),
|
||||
"y": base64.RawURLEncoding.EncodeToString(y),
|
||||
}, nil
|
||||
case *rsa.PublicKey:
|
||||
return map[string]any{
|
||||
"use": "sig",
|
||||
"kty": "RSA",
|
||||
"n": base64.RawURLEncoding.EncodeToString(k.N.Bytes()),
|
||||
"e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(k.E)).Bytes()),
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("keysigner: unsupported public key type %T for JWK", pub)
|
||||
}
|
||||
}
|
||||
240
internal/keysigner/keysigner_test.go
Normal file
240
internal/keysigner/keysigner_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package keysigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAlgForKey(t *testing.T) {
|
||||
ec, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if alg, err := AlgForKey(ec.Public()); err != nil || alg != AlgES256 {
|
||||
t.Errorf("P-256: alg=%q err=%v, want ES256/nil", alg, err)
|
||||
}
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if alg, err := AlgForKey(rsaKey.Public()); err != nil || alg != AlgRS256 {
|
||||
t.Errorf("RSA: alg=%q err=%v, want RS256/nil", alg, err)
|
||||
}
|
||||
|
||||
ec384, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := AlgForKey(ec384.Public()); err == nil {
|
||||
t.Error("P-384: expected unsupported-curve error")
|
||||
}
|
||||
|
||||
if _, err := AlgForKey("not a key"); err == nil {
|
||||
t.Error("string: expected unsupported-type error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodePublicKeyRoundTrip(t *testing.T) {
|
||||
ec, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
enc, err := EncodePublicKey(ec.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
der, err := base64.StdEncoding.DecodeString(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("not valid base64: %v", err)
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(der)
|
||||
if err != nil {
|
||||
t.Fatalf("not valid PKIX: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(pub, ec.Public()) {
|
||||
t.Error("public key did not round-trip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeyJWK_EC(t *testing.T) {
|
||||
ec, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jwk, err := PublicKeyJWK(ec.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if jwk["kty"] != "EC" || jwk["crv"] != "P-256" {
|
||||
t.Errorf("jwk = %v, want kty=EC crv=P-256", jwk)
|
||||
}
|
||||
if jwk["use"] != "sig" {
|
||||
t.Errorf("jwk use = %v, want sig", jwk["use"])
|
||||
}
|
||||
x, _ := jwk["x"].(string)
|
||||
xb, err := base64.RawURLEncoding.DecodeString(x)
|
||||
if err != nil || len(xb) != 32 {
|
||||
t.Errorf("x = %q (decoded %d bytes), want 32-byte base64url", x, len(xb))
|
||||
}
|
||||
if _, ok := jwk["y"].(string); !ok {
|
||||
t.Error("jwk missing y")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeyJWK_RSA(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jwk, err := PublicKeyJWK(rsaKey.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if jwk["kty"] != "RSA" || jwk["n"] == "" || jwk["e"] == "" {
|
||||
t.Errorf("jwk = %v, want kty=RSA with n,e", jwk)
|
||||
}
|
||||
if jwk["use"] != "sig" {
|
||||
t.Errorf("jwk use = %v, want sig", jwk["use"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeyJWK_UnsupportedCurve(t *testing.T) {
|
||||
ec384, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := PublicKeyJWK(ec384.Public()); err == nil {
|
||||
t.Error("P-384: expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestECDSADERToJOSE(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Iterate so we hit signatures whose r or s has its high bit set (ASN.1 pads
|
||||
// those with a leading 0x00) and whose scalars are short (need left-zero
|
||||
// padding) — verifying fixed-width conversion in both directions.
|
||||
for i := 0; i < 64; i++ {
|
||||
digest := sha256.Sum256([]byte{byte(i), byte(i >> 8), 'j', 'w', 't'})
|
||||
der, err := ecdsa.SignASN1(rand.Reader, key, digest[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jose, err := ecdsaDERToJOSE(der, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("iter %d: %v", i, err)
|
||||
}
|
||||
if len(jose) != 64 {
|
||||
t.Fatalf("iter %d: len(jose)=%d, want 64 (fixed-width r||s)", i, len(jose))
|
||||
}
|
||||
r := new(big.Int).SetBytes(jose[:32])
|
||||
s := new(big.Int).SetBytes(jose[32:])
|
||||
if !ecdsa.Verify(&key.PublicKey, digest[:], r, s) {
|
||||
t.Fatalf("iter %d: converted r||s did not verify against the public key", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestECDSADERToJOSE_Errors(t *testing.T) {
|
||||
if _, err := ecdsaDERToJOSE([]byte{0x01, 0x02, 0x03}, 32); err == nil {
|
||||
t.Error("garbage DER: expected error")
|
||||
}
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
digest := sha256.Sum256([]byte("trailing"))
|
||||
der, err := ecdsa.SignASN1(rand.Reader, key, digest[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ecdsaDERToJOSE(append(der, 0x00), 32); err == nil {
|
||||
t.Error("DER with trailing byte: expected error")
|
||||
}
|
||||
}
|
||||
|
||||
type stubSigner struct{}
|
||||
|
||||
func (stubSigner) EnsureKey(context.Context, KeyRef) (crypto.PublicKey, error) { return nil, nil }
|
||||
func (stubSigner) PublicKey(context.Context, KeyRef) (crypto.PublicKey, error) { return nil, nil }
|
||||
func (stubSigner) Sign(context.Context, KeyRef, []byte) ([]byte, string, error) { return nil, "", nil }
|
||||
|
||||
func TestCleanProbeError(t *testing.T) {
|
||||
cause := errors.New("open /dev/tpmrm0: permission denied")
|
||||
const p = "sks: error fetching Secure Hardware Vendor Data: "
|
||||
|
||||
// sks double-wraps with the same %w prefix → collapse to a single prefix.
|
||||
doubled := fmt.Errorf(p+"%w", fmt.Errorf(p+"%w", cause))
|
||||
if got, want := cleanProbeError(doubled), p+cause.Error(); got != want {
|
||||
t.Errorf("doubled: got %q, want %q", got, want)
|
||||
}
|
||||
// Triple wrap collapses too.
|
||||
if got, want := cleanProbeError(fmt.Errorf(p+"%w", doubled)), p+cause.Error(); got != want {
|
||||
t.Errorf("tripled: got %q, want %q", got, want)
|
||||
}
|
||||
// A layer adding genuinely new context is preserved.
|
||||
if got, want := cleanProbeError(fmt.Errorf("load: %w", cause)), "load: "+cause.Error(); got != want {
|
||||
t.Errorf("distinct prefix: got %q, want %q", got, want)
|
||||
}
|
||||
// nil and unwrapped-leaf cases.
|
||||
if got := cleanProbeError(nil); got != "" {
|
||||
t.Errorf("nil: got %q, want empty", got)
|
||||
}
|
||||
if got := cleanProbeError(cause); got != cause.Error() {
|
||||
t.Errorf("leaf: got %q, want %q", got, cause.Error())
|
||||
}
|
||||
}
|
||||
|
||||
type proberSigner struct {
|
||||
stubSigner
|
||||
info HardwareInfo
|
||||
}
|
||||
|
||||
func (p proberSigner) ProbeHardware(context.Context) (HardwareInfo, error) { return p.info, nil }
|
||||
|
||||
func TestProbeHardware(t *testing.T) {
|
||||
// nil signer and a signer that does not implement HardwareProber both yield ok=false.
|
||||
if _, ok, _ := probeHardware(context.Background(), nil); ok {
|
||||
t.Error("nil signer: ok should be false")
|
||||
}
|
||||
if _, ok, _ := probeHardware(context.Background(), stubSigner{}); ok {
|
||||
t.Error("non-prober signer: ok should be false")
|
||||
}
|
||||
|
||||
want := HardwareInfo{Backend: "tpm2", Available: true, VendorName: "ACME"}
|
||||
info, ok, err := probeHardware(context.Background(), proberSigner{info: want})
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("prober: ok=%v err=%v, want true/nil", ok, err)
|
||||
}
|
||||
if info != want {
|
||||
t.Errorf("info = %+v, want %+v", info, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry(t *testing.T) {
|
||||
if Active() != nil {
|
||||
t.Skip("a signer is already registered in this build")
|
||||
}
|
||||
Register(stubSigner{})
|
||||
if _, ok := Active().(stubSigner); !ok {
|
||||
t.Error("Active did not return the registered signer")
|
||||
}
|
||||
}
|
||||
29
internal/keysigner/registry.go
Normal file
29
internal/keysigner/registry.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package keysigner
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
active Signer
|
||||
)
|
||||
|
||||
// Register sets the active Signer. It is typically called from the init() of a
|
||||
// build-tagged or extension package that provides the platform TEE/Keychain
|
||||
// implementation. The last registration wins (one backend per platform).
|
||||
func Register(s Signer) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
active = s
|
||||
}
|
||||
|
||||
// Active returns the registered Signer, or nil if none is available — in which
|
||||
// case private_key_jwt is unsupported on this build and only client_secret auth
|
||||
// can be used.
|
||||
func Active() Signer {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return active
|
||||
}
|
||||
613
internal/keysigner/signer_keychain_darwin.go
Normal file
613
internal/keysigner/signer_keychain_darwin.go
Normal file
@@ -0,0 +1,613 @@
|
||||
//go:build darwin
|
||||
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// macOS non-exportable Keychain signer (compiled into every darwin build).
|
||||
//
|
||||
// It does NOT use the Secure Enclave / hardware TEE (which would require
|
||||
// code-signing entitlements that are unfriendly to open source). Instead it
|
||||
// generates an RSA-2048 key in software, imports it into a dedicated app
|
||||
// keychain as NON-EXTRACTABLE (`security import -x`), then deletes the software
|
||||
// copy — so the private key can sign but can never be exported. Signing is
|
||||
// RSASSA-PKCS1v15-SHA256 (RS256).
|
||||
//
|
||||
// Unlike the original revision, this implementation calls the Security and
|
||||
// CoreFoundation frameworks via RUNTIME FFI (github.com/ebitengine/purego)
|
||||
// instead of cgo. The security model is identical (the private key is still a
|
||||
// non-extractable keychain key and every signature is produced by the OS via
|
||||
// SecKeyCreateSignature), but the binary builds with CGO_ENABLED=0 and can be
|
||||
// cross-compiled for darwin from any host — so release binaries no longer
|
||||
// require a native macOS build runner.
|
||||
//
|
||||
// Build with: go build (cgo-free; compiled into every darwin build, no tag)
|
||||
package keysigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
)
|
||||
|
||||
// ---- Security / CoreFoundation runtime bindings (purego, no cgo) ----
|
||||
|
||||
const (
|
||||
cfFrameworkPath = "/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation"
|
||||
secFrameworkPath = "/System/Library/Frameworks/Security.framework/Security"
|
||||
|
||||
// kCFStringEncodingUTF8 (CFStringBuiltInEncodings).
|
||||
cfStringEncodingUTF8 = 0x08000100
|
||||
|
||||
// OSStatus values.
|
||||
errSecSuccess = 0
|
||||
)
|
||||
|
||||
var (
|
||||
ffiOnce sync.Once
|
||||
ffiErr error
|
||||
|
||||
cfDataCreate func(alloc uintptr, bytes *byte, length int) uintptr
|
||||
cfDataGetLength func(d uintptr) int
|
||||
cfDataGetBytePtr func(d uintptr) unsafe.Pointer
|
||||
cfStringCreate func(alloc uintptr, cstr *byte, encoding uint32) uintptr
|
||||
cfArrayCreate func(alloc uintptr, values *uintptr, numValues int, cb uintptr) uintptr
|
||||
cfDictCreateMutable func(alloc uintptr, capacity int, keyCB, valCB uintptr) uintptr
|
||||
cfDictSetValue func(dict, key, val uintptr)
|
||||
cfRelease func(ref uintptr)
|
||||
cfErrorGetCode func(e uintptr) int
|
||||
secKeychainOpen func(path *byte, out *uintptr) int32
|
||||
secItemCopyMatching func(query uintptr, result *uintptr) int32
|
||||
secItemUpdate func(query, attrs uintptr) int32
|
||||
secKeyCreateSignature func(key, algo, data uintptr, errOut *uintptr) uintptr
|
||||
|
||||
// CFTypeRef data-symbol constants (deref to obtain the held ref value).
|
||||
kSecClass uintptr
|
||||
kSecClassKey uintptr
|
||||
kSecAttrKeyClass uintptr
|
||||
kSecAttrKeyClassPrivate uintptr
|
||||
kSecAttrKeyType uintptr
|
||||
kSecAttrKeyTypeRSA uintptr
|
||||
kSecAttrApplicationLabel uintptr
|
||||
kSecReturnRef uintptr
|
||||
kSecMatchSearchList uintptr
|
||||
kSecAttrLabel uintptr
|
||||
kCFBooleanTrue uintptr
|
||||
algRSAPKCS1SHA256 uintptr
|
||||
|
||||
// Struct-symbol constants (passed BY ADDRESS, not dereferenced).
|
||||
cbTypeArray uintptr
|
||||
cbDictKey uintptr
|
||||
cbDictValue uintptr
|
||||
)
|
||||
|
||||
// loadFFI resolves the framework functions and constants once. Any failure
|
||||
// (framework missing, symbol absent) is returned to every caller so signing
|
||||
// fails cleanly rather than crashing.
|
||||
func loadFFI() error {
|
||||
ffiOnce.Do(func() {
|
||||
cf, err := purego.Dlopen(cfFrameworkPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
ffiErr = fmt.Errorf("keysigner: dlopen CoreFoundation: %w", err)
|
||||
return
|
||||
}
|
||||
sec, err := purego.Dlopen(secFrameworkPath, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
ffiErr = fmt.Errorf("keysigner: dlopen Security: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cfDataCreate, cf, "CFDataCreate")
|
||||
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||
purego.RegisterLibFunc(&cfStringCreate, cf, "CFStringCreateWithCString")
|
||||
purego.RegisterLibFunc(&cfArrayCreate, cf, "CFArrayCreate")
|
||||
purego.RegisterLibFunc(&cfDictCreateMutable, cf, "CFDictionaryCreateMutable")
|
||||
purego.RegisterLibFunc(&cfDictSetValue, cf, "CFDictionarySetValue")
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
purego.RegisterLibFunc(&cfErrorGetCode, cf, "CFErrorGetCode")
|
||||
purego.RegisterLibFunc(&secKeychainOpen, sec, "SecKeychainOpen")
|
||||
purego.RegisterLibFunc(&secItemCopyMatching, sec, "SecItemCopyMatching")
|
||||
purego.RegisterLibFunc(&secItemUpdate, sec, "SecItemUpdate")
|
||||
purego.RegisterLibFunc(&secKeyCreateSignature, sec, "SecKeyCreateSignature")
|
||||
|
||||
// CFStringRef/CFBooleanRef constants: Dlsym gives the address of the
|
||||
// exported variable; deref once to read the ref it holds.
|
||||
derefs := []struct {
|
||||
dst *uintptr
|
||||
handle uintptr
|
||||
name string
|
||||
}{
|
||||
{&kSecClass, sec, "kSecClass"},
|
||||
{&kSecClassKey, sec, "kSecClassKey"},
|
||||
{&kSecAttrKeyClass, sec, "kSecAttrKeyClass"},
|
||||
{&kSecAttrKeyClassPrivate, sec, "kSecAttrKeyClassPrivate"},
|
||||
{&kSecAttrKeyType, sec, "kSecAttrKeyType"},
|
||||
{&kSecAttrKeyTypeRSA, sec, "kSecAttrKeyTypeRSA"},
|
||||
{&kSecAttrApplicationLabel, sec, "kSecAttrApplicationLabel"},
|
||||
{&kSecReturnRef, sec, "kSecReturnRef"},
|
||||
{&kSecMatchSearchList, sec, "kSecMatchSearchList"},
|
||||
{&kSecAttrLabel, sec, "kSecAttrLabel"},
|
||||
{&kCFBooleanTrue, cf, "kCFBooleanTrue"},
|
||||
{&algRSAPKCS1SHA256, sec, "kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA256"},
|
||||
}
|
||||
for _, d := range derefs {
|
||||
sym, e := purego.Dlsym(d.handle, d.name)
|
||||
if e != nil || sym == 0 {
|
||||
ffiErr = fmt.Errorf("keysigner: dlsym %s: %v", d.name, e)
|
||||
return
|
||||
}
|
||||
// deref of a stable dylib data-symbol address (not Go-managed memory), so safe.
|
||||
*d.dst = *(*uintptr)(unsafe.Pointer(sym)) //nolint:govet // unsafeptr: see comment above
|
||||
}
|
||||
|
||||
// Callback structs are passed by address (no deref).
|
||||
addrs := []struct {
|
||||
dst *uintptr
|
||||
handle uintptr
|
||||
name string
|
||||
}{
|
||||
{&cbTypeArray, cf, "kCFTypeArrayCallBacks"},
|
||||
{&cbDictKey, cf, "kCFTypeDictionaryKeyCallBacks"},
|
||||
{&cbDictValue, cf, "kCFTypeDictionaryValueCallBacks"},
|
||||
}
|
||||
for _, a := range addrs {
|
||||
sym, e := purego.Dlsym(a.handle, a.name)
|
||||
if e != nil || sym == 0 {
|
||||
ffiErr = fmt.Errorf("keysigner: dlsym %s: %v", a.name, e)
|
||||
return
|
||||
}
|
||||
*a.dst = sym
|
||||
}
|
||||
})
|
||||
return ffiErr
|
||||
}
|
||||
|
||||
// cstr returns a pointer to a NUL-terminated copy of s. The backing array stays
|
||||
// alive while the returned pointer is reachable.
|
||||
func cstr(s string) *byte {
|
||||
b := append([]byte(s), 0)
|
||||
return &b[0]
|
||||
}
|
||||
|
||||
// cfBytes wraps Go bytes in a CFData (CFDataCreate copies the bytes). Caller
|
||||
// releases the returned CFDataRef.
|
||||
func cfBytes(b []byte) uintptr {
|
||||
var p *byte
|
||||
if len(b) > 0 {
|
||||
p = &b[0]
|
||||
}
|
||||
d := cfDataCreate(0, p, len(b))
|
||||
runtime.KeepAlive(b)
|
||||
return d
|
||||
}
|
||||
|
||||
// keychainSearchArray opens the dedicated keychain file and wraps it in a
|
||||
// CFArray for kSecMatchSearchList. Caller releases the returned array.
|
||||
//
|
||||
// NOTE: SecKeychainOpen / the file-based keychain are deprecated by Apple in
|
||||
// favor of the data-protection keychain. They still function on current macOS;
|
||||
// migrating off them is tracked separately and is independent of the cgo→purego
|
||||
// change (the original cgo version used the same APIs).
|
||||
func keychainSearchArray(keychainPath string) (uintptr, error) {
|
||||
var kc uintptr
|
||||
if st := secKeychainOpen(cstr(keychainPath), &kc); st != errSecSuccess {
|
||||
return 0, keychainError("open keychain", int(st))
|
||||
}
|
||||
vals := [1]uintptr{kc}
|
||||
arr := cfArrayCreate(0, &vals[0], 1, cbTypeArray)
|
||||
cfRelease(kc) // the array retains it
|
||||
if arr == 0 {
|
||||
return 0, fmt.Errorf("keysigner: CFArrayCreate(search list) failed")
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// findPrivateKey locates the non-extractable private key by its application
|
||||
// label within the dedicated keychain. Caller releases the returned SecKeyRef.
|
||||
func findPrivateKey(appLabel []byte, keychainPath string) (uintptr, error) {
|
||||
search, err := keychainSearchArray(keychainPath)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer cfRelease(search)
|
||||
|
||||
labelData := cfBytes(appLabel)
|
||||
defer cfRelease(labelData)
|
||||
|
||||
q := cfDictCreateMutable(0, 0, cbDictKey, cbDictValue)
|
||||
if q == 0 {
|
||||
return 0, fmt.Errorf("keysigner: CFDictionaryCreateMutable(query) failed")
|
||||
}
|
||||
defer cfRelease(q)
|
||||
cfDictSetValue(q, kSecClass, kSecClassKey)
|
||||
cfDictSetValue(q, kSecAttrKeyClass, kSecAttrKeyClassPrivate)
|
||||
cfDictSetValue(q, kSecAttrKeyType, kSecAttrKeyTypeRSA)
|
||||
cfDictSetValue(q, kSecAttrApplicationLabel, labelData)
|
||||
cfDictSetValue(q, kSecReturnRef, kCFBooleanTrue)
|
||||
cfDictSetValue(q, kSecMatchSearchList, search)
|
||||
|
||||
var keyRef uintptr
|
||||
if st := secItemCopyMatching(q, &keyRef); st != errSecSuccess {
|
||||
return 0, keychainError("find private key", int(st))
|
||||
}
|
||||
return keyRef, nil
|
||||
}
|
||||
|
||||
// securityBin is invoked by absolute path so a poisoned PATH cannot hijack it.
|
||||
const securityBin = "/usr/bin/security"
|
||||
|
||||
// keychainSigner implements Signer using a macOS non-exportable Keychain key.
|
||||
type keychainSigner struct{}
|
||||
|
||||
func init() { Register(keychainSigner{}) }
|
||||
|
||||
// ProbeHardware reports the macOS Keychain backend backing this signer. The
|
||||
// keychain signer is compiled into every darwin build and needs no special
|
||||
// hardware, so it reports available whenever the Security tooling is present.
|
||||
// It performs no key access, so it never prompts. Implementing HardwareProber
|
||||
// is what lets `doctor` report the signer as present rather than treating the
|
||||
// (prober-less) signer as "no TEE signer in this build".
|
||||
func (keychainSigner) ProbeHardware(_ context.Context) (HardwareInfo, error) {
|
||||
info := HardwareInfo{Backend: "keychain", VendorName: "macOS Keychain"}
|
||||
// A missing security tool is a status (Available=false via Reason), not a
|
||||
// probe error — so we deliberately return a nil error here.
|
||||
if _, err := vfs.Stat(securityBin); err != nil {
|
||||
info.Reason = securityBin + " not found"
|
||||
return info, nil //nolint:nilerr // absence is reported via Reason, not as an error
|
||||
}
|
||||
info.Available = true
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (keychainSigner) EnsureKey(_ context.Context, ref KeyRef) (crypto.PublicKey, error) {
|
||||
if md, err := readKeyMetadata(ref.Label); err == nil {
|
||||
return decodePublicKey(md.PublicKey)
|
||||
} else if !os.IsNotExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
return createKeychainKey(ref.Label)
|
||||
}
|
||||
|
||||
func (keychainSigner) PublicKey(_ context.Context, ref KeyRef) (crypto.PublicKey, error) {
|
||||
md, err := readKeyMetadata(ref.Label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodePublicKey(md.PublicKey)
|
||||
}
|
||||
|
||||
func (keychainSigner) Sign(_ context.Context, ref KeyRef, signingInput []byte) ([]byte, string, error) {
|
||||
if err := loadFFI(); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
md, err := readKeyMetadata(ref.Label)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
appLabel, err := hex.DecodeString(md.AppLabel)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("keysigner: decode app label: %w", err)
|
||||
}
|
||||
if len(appLabel) == 0 {
|
||||
// Guard the &appLabel[0] pointer below against corrupted metadata.
|
||||
return nil, "", fmt.Errorf("keysigner: key metadata for %q has empty app label", ref.Label)
|
||||
}
|
||||
keychain, err := ensureKeychain()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
keyRef, err := findPrivateKey(appLabel, keychain)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer cfRelease(keyRef)
|
||||
|
||||
digest := sha256.Sum256(signingInput)
|
||||
digestData := cfBytes(digest[:])
|
||||
defer cfRelease(digestData)
|
||||
|
||||
var errRef uintptr
|
||||
sigRef := secKeyCreateSignature(keyRef, algRSAPKCS1SHA256, digestData, &errRef)
|
||||
if sigRef == 0 {
|
||||
code := 0
|
||||
if errRef != 0 {
|
||||
code = cfErrorGetCode(errRef)
|
||||
cfRelease(errRef)
|
||||
}
|
||||
return nil, "", fmt.Errorf("keysigner: SecKeyCreateSignature failed (CFError %d)", code)
|
||||
}
|
||||
defer cfRelease(sigRef)
|
||||
|
||||
n := cfDataGetLength(sigRef)
|
||||
bp := cfDataGetBytePtr(sigRef)
|
||||
out := make([]byte, n)
|
||||
copy(out, unsafe.Slice((*byte)(bp), n))
|
||||
// RS256: the SecKey PKCS1v15-SHA256 signature is the JOSE signature as-is.
|
||||
return out, AlgRS256, nil
|
||||
}
|
||||
|
||||
// keyMetadata records the public key + the keychain application-label used to
|
||||
// locate the non-extractable private key.
|
||||
type keyMetadata struct {
|
||||
PublicKey string `json:"public_key"` // PKIX DER, std base64 (see EncodePublicKey)
|
||||
AppLabel string `json:"app_label"` // hex(sha1(PKCS1 public key))
|
||||
}
|
||||
|
||||
func createKeychainKey(label string) (crypto.PublicKey, error) {
|
||||
metadataPath, err := keyMetadataPath(label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keysigner: generate RSA key: %w", err)
|
||||
}
|
||||
appLabel := sha1.Sum(x509.MarshalPKCS1PublicKey(&privateKey.PublicKey))
|
||||
|
||||
pemFile, err := vfs.CreateTemp("", "lark-keysigner-*.pem")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keysigner: temp key file: %w", err)
|
||||
}
|
||||
pemPath := pemFile.Name()
|
||||
defer vfs.Remove(pemPath)
|
||||
if err := pemFile.Chmod(0600); err != nil {
|
||||
pemFile.Close()
|
||||
return nil, err
|
||||
}
|
||||
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
if _, err := pemFile.WriteString("-----BEGIN RSA PRIVATE KEY-----\n" +
|
||||
base64Wrap(der) + "-----END RSA PRIVATE KEY-----\n"); err != nil {
|
||||
pemFile.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := pemFile.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
executable, err := vfs.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keysigner: resolve executable: %w", err)
|
||||
}
|
||||
keychain, err := ensureKeychain()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// -x: import as NON-EXTRACTABLE; the software copy (pemPath) is then removed.
|
||||
importCmd := exec.Command(securityBin, "import", pemPath, "-k", keychain, "-t", "priv", "-f", "openssl", "-x", "-A", "-T", executable)
|
||||
if out, err := importCmd.CombinedOutput(); err != nil {
|
||||
return nil, fmt.Errorf("keysigner: import non-extractable key: %w: %s", err, summarizeCmdOutput(out))
|
||||
}
|
||||
if err := setKeychainKeyLabel(appLabel[:], keychain, label); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encodedPub, err := EncodePublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := writeKeyMetadata(metadataPath, keyMetadata{PublicKey: encodedPub, AppLabel: hex.EncodeToString(appLabel[:])}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &privateKey.PublicKey, nil
|
||||
}
|
||||
|
||||
func setKeychainKeyLabel(appLabel []byte, keychain, label string) error {
|
||||
if err := loadFFI(); err != nil {
|
||||
return err
|
||||
}
|
||||
search, err := keychainSearchArray(keychain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cfRelease(search)
|
||||
|
||||
labelData := cfBytes(appLabel)
|
||||
defer cfRelease(labelData)
|
||||
|
||||
q := cfDictCreateMutable(0, 0, cbDictKey, cbDictValue)
|
||||
if q == 0 {
|
||||
return fmt.Errorf("keysigner: CFDictionaryCreateMutable(query) failed")
|
||||
}
|
||||
defer cfRelease(q)
|
||||
cfDictSetValue(q, kSecClass, kSecClassKey)
|
||||
cfDictSetValue(q, kSecAttrKeyClass, kSecAttrKeyClassPrivate)
|
||||
cfDictSetValue(q, kSecAttrKeyType, kSecAttrKeyTypeRSA)
|
||||
cfDictSetValue(q, kSecAttrApplicationLabel, labelData)
|
||||
cfDictSetValue(q, kSecMatchSearchList, search)
|
||||
|
||||
cfLabel := cfStringCreate(0, cstr(label), cfStringEncodingUTF8)
|
||||
if cfLabel == 0 {
|
||||
return fmt.Errorf("keysigner: CFStringCreateWithCString failed")
|
||||
}
|
||||
defer cfRelease(cfLabel)
|
||||
attrs := cfDictCreateMutable(0, 0, cbDictKey, cbDictValue)
|
||||
if attrs == 0 {
|
||||
return fmt.Errorf("keysigner: CFDictionaryCreateMutable(attrs) failed")
|
||||
}
|
||||
defer cfRelease(attrs)
|
||||
cfDictSetValue(attrs, kSecAttrLabel, cfLabel)
|
||||
|
||||
if st := secItemUpdate(q, attrs); st != errSecSuccess {
|
||||
return keychainError("set keychain key label", int(st))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodePublicKey(encoded string) (crypto.PublicKey, error) {
|
||||
der, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keysigner: decode public key: %w", err)
|
||||
}
|
||||
return x509.ParsePKIXPublicKey(der)
|
||||
}
|
||||
|
||||
// base64Wrap PEM-wraps DER bytes at 64 columns.
|
||||
func base64Wrap(der []byte) string {
|
||||
enc := base64.StdEncoding.EncodeToString(der)
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(enc); i += 64 {
|
||||
end := i + 64
|
||||
if end > len(enc) {
|
||||
end = len(enc)
|
||||
}
|
||||
b.WriteString(enc[i:end])
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func readKeyMetadata(label string) (*keyMetadata, error) {
|
||||
path, err := keyMetadataPath(label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := vfs.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err // preserves os.ErrNotExist for EnsureKey
|
||||
}
|
||||
var md keyMetadata
|
||||
if err := json.Unmarshal(data, &md); err != nil {
|
||||
return nil, fmt.Errorf("keysigner: parse key metadata: %w", err)
|
||||
}
|
||||
return &md, nil
|
||||
}
|
||||
|
||||
func writeKeyMetadata(path string, md keyMetadata) error {
|
||||
if err := vfs.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(md, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return vfs.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func ensureKeychain() (string, error) {
|
||||
keychainPath, err := keychainFilePath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
password, err := keychainPassword()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := vfs.Stat(keychainPath); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("keysigner: stat keychain: %w", err)
|
||||
}
|
||||
if err := vfs.MkdirAll(filepath.Dir(keychainPath), 0700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, args := range [][]string{
|
||||
{"create-keychain", "-p", password, keychainPath},
|
||||
{"set-keychain-settings", keychainPath},
|
||||
{"unlock-keychain", "-p", password, keychainPath},
|
||||
} {
|
||||
if out, err := exec.Command(securityBin, args...).CombinedOutput(); err != nil {
|
||||
return "", fmt.Errorf("keysigner: security %s: %w: %s", args[0], err, summarizeCmdOutput(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
return keychainPath, nil
|
||||
}
|
||||
|
||||
func keysignerDir() (string, error) {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("keysigner: resolve config dir: %w", err)
|
||||
}
|
||||
return filepath.Join(configDir, "lark-cli", "keysigner"), nil
|
||||
}
|
||||
|
||||
func keychainFilePath() (string, error) {
|
||||
dir, err := keysignerDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(dir, "lark-cli.keychain"), nil
|
||||
}
|
||||
|
||||
func keychainPassword() (string, error) {
|
||||
dir, err := keysignerDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
path := filepath.Join(dir, "keychain.pass")
|
||||
if data, err := vfs.ReadFile(path); err == nil {
|
||||
if pw := strings.TrimSpace(string(data)); pw != "" {
|
||||
return pw, nil
|
||||
}
|
||||
return "", fmt.Errorf("keysigner: empty keychain password")
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
pw := hex.EncodeToString(buf)
|
||||
if err := vfs.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := vfs.WriteFile(path, []byte(pw+"\n"), 0600); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return pw, nil
|
||||
}
|
||||
|
||||
func keyMetadataPath(label string) (string, error) {
|
||||
dir, err := keysignerDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := sha256.Sum256([]byte(label))
|
||||
return filepath.Join(dir, "keys", hex.EncodeToString(id[:])+".json"), nil
|
||||
}
|
||||
|
||||
// summarizeCmdOutput bounds external command output before it is embedded in
|
||||
// an error: first line only, capped at 200 chars.
|
||||
func summarizeCmdOutput(out []byte) string {
|
||||
s := strings.TrimSpace(string(out))
|
||||
if i := strings.IndexByte(s, '\n'); i >= 0 {
|
||||
s = strings.TrimSpace(s[:i])
|
||||
}
|
||||
const maxLen = 200
|
||||
if len(s) > maxLen {
|
||||
s = s[:maxLen] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func keychainError(operation string, status int) error {
|
||||
switch status {
|
||||
case -25299:
|
||||
return fmt.Errorf("keysigner: %s: key already exists", operation)
|
||||
case -25300:
|
||||
return fmt.Errorf("keysigner: %s: key not found", operation)
|
||||
case -2:
|
||||
return fmt.Errorf("keysigner: %s: allocation failed", operation)
|
||||
default:
|
||||
return fmt.Errorf("keysigner: %s: Security framework status %d", operation, status)
|
||||
}
|
||||
}
|
||||
62
internal/keysigner/signer_keychain_darwin_test.go
Normal file
62
internal/keysigner/signer_keychain_darwin_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build darwin
|
||||
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package keysigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestKeychainSignerRegistered confirms the keychain_signer build self-registers
|
||||
// (init → Register), so keysigner.Active() is non-nil. No keychain access.
|
||||
func TestKeychainSignerRegistered(t *testing.T) {
|
||||
if _, ok := Active().(keychainSigner); !ok {
|
||||
t.Fatalf("Active() = %T, want keychainSigner (keychain_signer build must self-register)", Active())
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeychainSignerRoundTrip creates a real non-extractable RSA key, signs, and
|
||||
// verifies RS256 against the returned public key. Gated by LARK_KEYCHAIN_IT
|
||||
// because it mutates the dedicated lark-cli keychain store. The signer is now
|
||||
// cgo-free (purego runtime FFI), so it runs with CGO_ENABLED=0. Run with:
|
||||
//
|
||||
// LARK_KEYCHAIN_IT=1 go test -run RoundTrip ./internal/keysigner/
|
||||
func TestKeychainSignerRoundTrip(t *testing.T) {
|
||||
if os.Getenv("LARK_KEYCHAIN_IT") == "" {
|
||||
t.Skip("set LARK_KEYCHAIN_IT=1 to run (mutates the macOS keychain)")
|
||||
}
|
||||
s := keychainSigner{}
|
||||
ref := KeyRef{Label: "lark-cli-keychain-it"}
|
||||
|
||||
pub, err := s.EnsureKey(context.Background(), ref)
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureKey: %v", err)
|
||||
}
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
t.Fatalf("public key = %T, want *rsa.PublicKey", pub)
|
||||
}
|
||||
if alg, err := AlgForKey(pub); err != nil || alg != AlgRS256 {
|
||||
t.Fatalf("AlgForKey = %q, %v; want RS256", alg, err)
|
||||
}
|
||||
|
||||
input := []byte("header.payload")
|
||||
sig, alg, err := s.Sign(context.Background(), ref, input)
|
||||
if err != nil {
|
||||
t.Fatalf("Sign: %v", err)
|
||||
}
|
||||
if alg != AlgRS256 {
|
||||
t.Errorf("Sign alg = %q, want RS256", alg)
|
||||
}
|
||||
h := sha256.Sum256(input)
|
||||
if err := rsa.VerifyPKCS1v15(rsaPub, crypto.SHA256, h[:], sig); err != nil {
|
||||
t.Errorf("RS256 signature did not verify: %v", err)
|
||||
}
|
||||
}
|
||||
135
internal/keysigner/signer_sks.go
Normal file
135
internal/keysigner/signer_sks.go
Normal file
@@ -0,0 +1,135 @@
|
||||
//go:build linux || (windows && amd64)
|
||||
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// TPM 2.0 signer (compiled into every linux and windows/amd64 build, no build
|
||||
// tag required), backed by github.com/facebookincubator/sks.
|
||||
//
|
||||
// sks holds a non-exportable ECDSA P-256 key in the platform TPM and signs
|
||||
// SHA-256 digests. On Linux it talks to /dev/tpmrm0; on Windows it uses the
|
||||
// Microsoft Platform Crypto Provider (CNG). Both backends return an ASN.1 DER
|
||||
// ECDSA signature, which we convert to the fixed-width r||s form JWS requires for
|
||||
// ES256 (see ecdsaDERToJOSE). One key is created on the first private_key_jwt
|
||||
// registration (DefaultKeyLabel) and reused for subsequent app registrations and
|
||||
// every client_assertion on the same device.
|
||||
//
|
||||
// Excluded from windows/arm64: the sks Windows dependency stack (go-ole) has no
|
||||
// arm64 VARIANT and fails to compile, so windows/arm64 falls back to
|
||||
// client_secret only (keysigner.Active() is nil). On darwin the keychain signer
|
||||
// is used instead. CGO is never required.
|
||||
package keysigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/facebookincubator/flog"
|
||||
"github.com/facebookincubator/sks"
|
||||
)
|
||||
|
||||
// p256ByteLen is the P-256 coordinate width. sks regular keys are always ECDSA
|
||||
// P-256, so ES256 signatures are 2*p256ByteLen bytes of r||s.
|
||||
const p256ByteLen = 32
|
||||
|
||||
// keyTag is the sks key tag. Both the Linux and Windows sks backends address
|
||||
// keys by label and ignore the tag, but the macOS backend uses it, so we set a
|
||||
// stable namespaced value for forward compatibility.
|
||||
const keyTag = "com.larksuite.cli"
|
||||
|
||||
// sksSigner implements Signer (and HardwareProber) using a non-exportable
|
||||
// TPM 2.0 ECDSA key via sks.
|
||||
type sksSigner struct{}
|
||||
|
||||
func init() {
|
||||
Register(sksSigner{})
|
||||
// This sks version logs verbose TPM-operation chatter to stderr via flog (a
|
||||
// glog fork it owns exclusively) — e.g. "Loaded TPM device", "Found handle
|
||||
// for key" on every sign. The CLI does not use flog, so silence it
|
||||
// process-wide here; real failures are returned as errors, never relied upon
|
||||
// from these logs. (Newer sks switched to slog, but that lands only on its
|
||||
// go-1.24 line, which we avoid to keep the module on go 1.23.)
|
||||
flog.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
// EnsureKey returns the public key for ref, creating the TPM key if absent.
|
||||
// sks.NewKey is find-or-create: it returns the existing key when one is present.
|
||||
func (sksSigner) EnsureKey(_ context.Context, ref KeyRef) (crypto.PublicKey, error) {
|
||||
key, err := sks.NewKey(ref.Label, keyTag, false, true, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("keysigner: ensure TPM key %q: %w", ref.Label, err)
|
||||
}
|
||||
defer key.Close()
|
||||
return ecdsaPublic(ref.Label, key.Public())
|
||||
}
|
||||
|
||||
// PublicKey returns the public key for ref without creating it. FromLabelTag does
|
||||
// not touch the TPM until Public() loads the sealed key; a missing key yields a
|
||||
// nil public key, which we surface as an error — at runtime the key MUST already
|
||||
// exist (it was bound to the app at registration), so we never silently mint a
|
||||
// new, unbound one here.
|
||||
func (sksSigner) PublicKey(_ context.Context, ref KeyRef) (crypto.PublicKey, error) {
|
||||
pub := sks.FromLabelTag(ref.Label).Public()
|
||||
if pub == nil {
|
||||
return nil, fmt.Errorf("keysigner: TPM key %q not found", ref.Label)
|
||||
}
|
||||
return ecdsaPublic(ref.Label, pub)
|
||||
}
|
||||
|
||||
// Sign signs signingInput with the TPM key and returns a JOSE-format ES256
|
||||
// signature (fixed-width r||s) plus its alg.
|
||||
func (sksSigner) Sign(_ context.Context, ref KeyRef, signingInput []byte) ([]byte, string, error) {
|
||||
key, err := sks.NewKey(ref.Label, keyTag, false, true, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("keysigner: load TPM key %q: %w", ref.Label, err)
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
// ES256 signs the SHA-256 digest of the JWS signing input.
|
||||
digest := sha256.Sum256(signingInput)
|
||||
der, err := key.Sign(nil, digest[:], crypto.SHA256)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("keysigner: TPM sign with key %q: %w", ref.Label, err)
|
||||
}
|
||||
// Both sks backends emit ASN.1 DER; JWS ES256 requires fixed-width r||s
|
||||
// (RFC 7518 §3.4).
|
||||
rs, err := ecdsaDERToJOSE(der, p256ByteLen)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return rs, AlgES256, nil
|
||||
}
|
||||
|
||||
// ProbeHardware reports on the TPM backing this signer without touching any key.
|
||||
// A failure to reach the TPM (no device, permission denied, not TPM 2.0) is
|
||||
// reported as Available=false with Reason set, NOT as a Go error — the probe
|
||||
// still succeeded in determining that the TEE is currently unusable.
|
||||
func (sksSigner) ProbeHardware(_ context.Context) (HardwareInfo, error) {
|
||||
info := HardwareInfo{Backend: "tpm2"}
|
||||
data, err := sks.GetSecureHardwareVendorData()
|
||||
if err != nil {
|
||||
info.Reason = cleanProbeError(err)
|
||||
return info, nil
|
||||
}
|
||||
info.VendorName = data.VendorName
|
||||
info.VendorInfo = data.VendorInfo
|
||||
info.Available = data.IsTPM20CompliantDevice
|
||||
if !info.Available {
|
||||
info.Reason = "secure hardware is not a TPM 2.0 compliant device"
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// ecdsaPublic asserts that an sks public key is an ECDSA key (it always is for
|
||||
// regular sks keys) so the caller gets the concrete type AlgForKey/PublicKeyJWK expect.
|
||||
func ecdsaPublic(label string, pub crypto.PublicKey) (*ecdsa.PublicKey, error) {
|
||||
ecPub, ok := pub.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("keysigner: TPM key %q public is %T, want *ecdsa.PublicKey", label, pub)
|
||||
}
|
||||
return ecPub, nil
|
||||
}
|
||||
122
internal/keysigner/signer_sks_test.go
Normal file
122
internal/keysigner/signer_sks_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
//go:build linux || (windows && amd64)
|
||||
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package keysigner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/facebookincubator/flog"
|
||||
"github.com/facebookincubator/sks"
|
||||
)
|
||||
|
||||
// TestFlogSilenced verifies the mechanism init() relies on to keep sks's flog
|
||||
// TPM chatter off the CLI's stderr: SetOutput redirects flog, and io.Discard
|
||||
// drops it. Cleanup restores io.Discard so init()'s silencing holds for the
|
||||
// rest of the package's tests.
|
||||
func TestFlogSilenced(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
flog.SetOutput(&buf)
|
||||
t.Cleanup(func() { flog.SetOutput(io.Discard) })
|
||||
|
||||
flog.Info("captured-line")
|
||||
if !strings.Contains(buf.String(), "captured-line") {
|
||||
t.Fatalf("flog.SetOutput(buffer) did not capture output: %q", buf.String())
|
||||
}
|
||||
|
||||
flog.SetOutput(io.Discard)
|
||||
buf.Reset()
|
||||
flog.Info("should-be-discarded")
|
||||
if buf.Len() != 0 {
|
||||
t.Errorf("flog output not discarded: %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
// requireTEE skips the test unless the TPM is present and usable. On a Linux
|
||||
// machine with a TPM but a restrictive device owner (`/dev/tpmrm0` is `tss:tss`
|
||||
// by default), grant access with `sudo usermod -aG tss $USER` then re-login, or
|
||||
// run the test under sudo.
|
||||
func requireTEE(t *testing.T) {
|
||||
t.Helper()
|
||||
info, err := sksSigner{}.ProbeHardware(context.Background())
|
||||
if err != nil || !info.Available {
|
||||
reason := info.Reason
|
||||
if err != nil {
|
||||
reason = err.Error()
|
||||
}
|
||||
t.Skipf("TEE not available (%s)", reason)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSKSSignerRoundTrip exercises the full registration→assertion contract
|
||||
// against the real TPM: create the key, read it back without creating, derive
|
||||
// the JWS alg + JWK, sign, and verify the fixed-width r||s output.
|
||||
func TestSKSSignerRoundTrip(t *testing.T) {
|
||||
requireTEE(t)
|
||||
|
||||
var s sksSigner
|
||||
ctx := context.Background()
|
||||
ref := KeyRef{Label: "larksuite-cli-test"}
|
||||
|
||||
// Best-effort cleanup so the test key does not linger in the TPM-sealed store.
|
||||
t.Cleanup(func() {
|
||||
if k, err := sks.NewKey(ref.Label, keyTag, false, true, nil); err == nil {
|
||||
_ = k.Remove()
|
||||
_ = k.Close()
|
||||
}
|
||||
})
|
||||
|
||||
pub, err := s.EnsureKey(ctx, ref)
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureKey: %v", err)
|
||||
}
|
||||
ecPub, ok := pub.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
t.Fatalf("EnsureKey returned %T, want *ecdsa.PublicKey", pub)
|
||||
}
|
||||
|
||||
// PublicKey (no-create) must return the same key bound at EnsureKey.
|
||||
pub2, err := s.PublicKey(ctx, ref)
|
||||
if err != nil {
|
||||
t.Fatalf("PublicKey: %v", err)
|
||||
}
|
||||
if !ecPub.Equal(pub2) {
|
||||
t.Fatal("PublicKey returned a different key than EnsureKey")
|
||||
}
|
||||
|
||||
// The JWT layer derives alg + JWK from the public key; both must work.
|
||||
if alg, err := AlgForKey(pub); err != nil || alg != AlgES256 {
|
||||
t.Fatalf("AlgForKey = %q, %v; want ES256", alg, err)
|
||||
}
|
||||
if _, err := PublicKeyJWK(pub); err != nil {
|
||||
t.Fatalf("PublicKeyJWK: %v", err)
|
||||
}
|
||||
|
||||
// Sign a representative JWS signing input and verify the converted r||s.
|
||||
input := []byte("eyJhbGciOiJFUzI1NiJ9.eyJzdWIiOiJjbGkifQ")
|
||||
sig, alg, err := s.Sign(ctx, ref, input)
|
||||
if err != nil {
|
||||
t.Fatalf("Sign: %v", err)
|
||||
}
|
||||
if alg != AlgES256 {
|
||||
t.Fatalf("Sign alg = %q, want ES256", alg)
|
||||
}
|
||||
if len(sig) != 2*p256ByteLen {
|
||||
t.Fatalf("len(sig) = %d, want %d (fixed-width r||s)", len(sig), 2*p256ByteLen)
|
||||
}
|
||||
digest := sha256.Sum256(input)
|
||||
r := new(big.Int).SetBytes(sig[:p256ByteLen])
|
||||
ss := new(big.Int).SetBytes(sig[p256ByteLen:])
|
||||
if !ecdsa.Verify(ecPub, digest[:], r, ss) {
|
||||
t.Fatal("TPM signature did not verify against the public key")
|
||||
}
|
||||
}
|
||||
@@ -113,8 +113,7 @@ type EnumOption struct {
|
||||
}
|
||||
|
||||
// EnumOptions returns the field's allowed values paired with their descriptions
|
||||
// — from enum (with descriptions backfilled from options when the field carries
|
||||
// both forms), or from options when enum is absent — coerced to the canonical
|
||||
// — from enum, or from options when enum is absent — coerced to the canonical
|
||||
// type and ordered: numeric and boolean values are sorted; string values keep
|
||||
// source order (which can encode priority). Uncoercible literals are dropped.
|
||||
// Returns nil when the field declares no enum constraint.
|
||||
@@ -123,14 +122,9 @@ func (f Field) EnumOptions() []EnumOption {
|
||||
var out []EnumOption
|
||||
switch {
|
||||
case len(f.Enum) > 0:
|
||||
// key by raw literal so enum "1" and option 1 align across JSON types
|
||||
desc := make(map[string]string, len(f.Options))
|
||||
for _, o := range f.Options {
|
||||
desc[fmt.Sprintf("%v", o.Value)] = o.Description
|
||||
}
|
||||
for _, e := range f.Enum {
|
||||
if v, ok := coerceLiteral(ct, e); ok {
|
||||
out = append(out, EnumOption{Value: v, Description: desc[fmt.Sprintf("%v", e)]})
|
||||
out = append(out, EnumOption{Value: v})
|
||||
}
|
||||
}
|
||||
case len(f.Options) > 0:
|
||||
|
||||
@@ -80,39 +80,6 @@ func TestField_EnumOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestField_EnumOptions_BothEnumAndOptions(t *testing.T) {
|
||||
// enum is the value set; descriptions backfilled from options, empty where absent
|
||||
f := Field{Type: "string", Enum: []any{"1", "2", "3", "4", "6"}, Options: []Option{
|
||||
{Value: "1", Description: "from"},
|
||||
{Value: "2", Description: "to"},
|
||||
{Value: "6", Description: "subject"},
|
||||
}}
|
||||
want := []EnumOption{
|
||||
{Value: "1", Description: "from"},
|
||||
{Value: "2", Description: "to"},
|
||||
{Value: "3", Description: ""},
|
||||
{Value: "4", Description: ""},
|
||||
{Value: "6", Description: "subject"},
|
||||
}
|
||||
if got := f.EnumOptions(); !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("EnumOptions(enum+options) = %+v, want %+v", got, want)
|
||||
}
|
||||
|
||||
// enum values stored as strings match option values stored as numbers
|
||||
fi := Field{Type: "integer", Enum: []any{"10", "2", "1"}, Options: []Option{
|
||||
{Value: 1, Description: "one"},
|
||||
{Value: 2, Description: "two"},
|
||||
}}
|
||||
wantI := []EnumOption{
|
||||
{Value: int64(1), Description: "one"},
|
||||
{Value: int64(2), Description: "two"},
|
||||
{Value: int64(10), Description: ""},
|
||||
}
|
||||
if got := fi.EnumOptions(); !reflect.DeepEqual(got, wantI) {
|
||||
t.Errorf("EnumOptions(integer enum+options) = %+v, want %+v", got, wantI)
|
||||
}
|
||||
}
|
||||
|
||||
func TestField_Enum_NumberAndBoolean(t *testing.T) {
|
||||
// number: string-stored floats coerced to float64 and numerically sorted
|
||||
if got := (Field{Type: "number", Enum: []any{"2.5", "1.5", "10"}}).EnumValues(); !reflect.DeepEqual(got, []any{1.5, 2.5, float64(10)}) {
|
||||
|
||||
@@ -472,18 +472,6 @@ func TestConvert_EnumDescriptions(t *testing.T) {
|
||||
if bare.EnumDescriptions != nil {
|
||||
t.Errorf("bare enum must have nil EnumDescriptions, got %v", bare.EnumDescriptions)
|
||||
}
|
||||
|
||||
// enum + options both present -> enumDescriptions backfilled, aligned, "" where absent
|
||||
both := Convert(meta.Field{Type: "string", Enum: []any{"1", "2", "3"}, Options: []meta.Option{
|
||||
{Value: "1", Description: "from"},
|
||||
{Value: "2", Description: "to"},
|
||||
}})
|
||||
if !reflect.DeepEqual(both.Enum, []interface{}{"1", "2", "3"}) {
|
||||
t.Errorf("both Enum = %v", both.Enum)
|
||||
}
|
||||
if !reflect.DeepEqual(both.EnumDescriptions, []string{"from", "to", ""}) {
|
||||
t.Errorf("both EnumDescriptions = %v, want [from to \"\"] aligned with enum", both.EnumDescriptions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMeta_AffordanceFromMethod(t *testing.T) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@larksuite/cli",
|
||||
"version": "1.0.57",
|
||||
"version": "1.0.56",
|
||||
"description": "The official CLI for Lark/Feishu open platform",
|
||||
"bin": {
|
||||
"lark-cli": "scripts/run.js"
|
||||
|
||||
@@ -24,6 +24,10 @@ build_target() {
|
||||
ext=".exe"
|
||||
fi
|
||||
|
||||
# The platform key signers are compiled in by build constraint, no tags:
|
||||
# darwin keychain (//go:build darwin) and linux/windows-amd64 TPM
|
||||
# (//go:build linux || (windows && amd64)). windows/arm64 arch-excludes the TPM
|
||||
# signer (go-ole has no arm64) and falls back to client_secret only.
|
||||
local output="$OUT_DIR/bin/lark-cli-${goos}-${goarch}${ext}"
|
||||
echo "Building ${goos}/${goarch} -> ${output}"
|
||||
CGO_ENABLED=0 GOOS="$goos" GOARCH="$goarch" go build -trimpath -ldflags "$LDFLAGS" -o "$output" ./main.go
|
||||
|
||||
@@ -179,10 +179,7 @@ fi
|
||||
require_in_step "$summary_verify_step" 'workflowPath !== ".github/workflows/ci.yml"' "PR quality summary must verify the triggering workflow path"
|
||||
require_in_step "$summary_verify_step" 'run.event !== "pull_request"' "PR quality summary must only handle pull_request workflow_run events"
|
||||
require_in_step "$summary_verify_step" 'run.repository.id !== context.payload.repository.id' "PR quality summary must verify workflow_run repository id"
|
||||
require_in_step "$summary_verify_step" 'const targetHeadSha = run.head_sha' "PR quality summary must use the CI run head SHA as the verified PR head"
|
||||
require_in_step "$summary_verify_step" 'eventHeadSha && eventHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()' "PR quality summary should tolerate mutable workflow_run PR head metadata"
|
||||
require_in_step "$summary_verify_step" 'factsArtifactPattern' "PR quality summary should use the base-bound facts artifact name when available"
|
||||
require_in_step "$summary_verify_step" 'const baseSha = artifactBaseSha || eventBaseSha || pr.base.sha' "PR quality summary must prefer the CI-time artifact base SHA"
|
||||
require_in_step "$summary_verify_step" 'core.setOutput("artifact_error"' "PR quality summary must expose artifact binding failures"
|
||||
require_in_step "$summary_artifact_step" 'factsArtifactName' "PR quality summary artifact step must use the verified facts artifact binding"
|
||||
require_in_step "$summary_extract_facts_step" 'SEMANTIC_REVIEW_DECISION_OUT' "PR quality summary artifact verifier must write an infrastructure decision on verifier failure"
|
||||
@@ -201,9 +198,8 @@ require_in_step "$verify_step" 'workflowPath !== ".github/workflows/ci.yml"' "se
|
||||
require_in_step "$verify_step" 'run.repository.id !== context.payload.repository.id' "semantic-review must verify workflow_run repository id"
|
||||
require_in_step "$verify_step" 'run.event !== "pull_request"' "semantic-review must only handle pull_request workflow_run events"
|
||||
require_in_step "$verify_step" 'run.conclusion !== "success"' "semantic-review must only consume successful CI runs"
|
||||
require_in_step "$verify_step" 'const eventHeadSha = runPRs[0]?.head?.sha || ""' "semantic-review should inspect workflow_run PR head metadata"
|
||||
require_in_step "$verify_step" 'const targetHeadSha = run.head_sha' "semantic-review target PR head must come from the completed CI run"
|
||||
require_in_step "$verify_step" 'eventHeadSha && eventHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()' "semantic-review should tolerate mutable workflow_run PR head metadata"
|
||||
require_in_step "$verify_step" 'const eventHeadSha = runPRs[0]?.head?.sha || ""' "semantic-review must prefer workflow_run PR head when GitHub provides it"
|
||||
require_in_step "$verify_step" 'const targetHeadSha = eventHeadSha || run.head_sha' "semantic-review target PR head must come from the workflow_run event"
|
||||
require_in_step "$verify_step" 'factsArtifactPattern' "semantic-review must use a base-bound facts artifact name"
|
||||
require_in_step "$verify_step" 'listWorkflowRunArtifacts' "semantic-review must read the workflow_run artifacts before resolving fallback base SHA"
|
||||
require_in_step "$verify_step" 'artifactHeadSha.toLowerCase() !== targetHeadSha.toLowerCase()' "semantic-review must not let the artifact choose a different PR head"
|
||||
@@ -214,8 +210,8 @@ require_in_step "$verify_step" 'commit_sha: targetHeadSha' "semantic-review fall
|
||||
require_in_step "$verify_step" 'github.rest.pulls.list' "semantic-review must have a pull-list fallback when commit association is empty"
|
||||
require_in_step "$verify_step" 'candidatePRs.length > 1' "semantic-review must fail closed when commit-to-PR fallback is ambiguous"
|
||||
require_in_step "$verify_step" 'pr.head.sha !== targetHeadSha' "semantic-review must skip stale PR heads"
|
||||
require_in_step "$verify_step" 'eventBaseSha && parsedBaseSha.toLowerCase() !== eventBaseSha.toLowerCase()' "semantic-review should tolerate mutable workflow_run PR base metadata"
|
||||
require_in_step "$verify_step" 'const baseSha = artifactBaseSha || eventBaseSha || pr.base.sha' "semantic-review must prefer the CI-time artifact base SHA"
|
||||
require_in_step "$verify_step" 'eventBaseSha && parsedBaseSha.toLowerCase() !== eventBaseSha.toLowerCase()' "semantic-review must reject mismatched event and artifact base SHAs"
|
||||
require_in_step "$verify_step" 'const baseSha = eventBaseSha || artifactBaseSha' "semantic-review fallback must use the CI-time artifact base SHA"
|
||||
require_in_step "$verify_step" 'pr.base.sha !== baseSha' "semantic-review must skip stale PR bases"
|
||||
require_in_step "$verify_step" 'core.setOutput("run_id"' "semantic-review must pass verified workflow run id to publisher"
|
||||
require_in_step "$verify_step" 'core.setOutput("head_repo_id"' "semantic-review must pass verified head repo id"
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
@@ -85,9 +84,6 @@ var AppsHTMLPublish = common.Shortcut{
|
||||
// for dry-run "advisory preview" semantics).
|
||||
dry.Set("validation_error", err.Error())
|
||||
}
|
||||
if hits := oversizeHTMLFiles(candidates); len(hits) > 0 {
|
||||
dry.Set("oversize_html", hits)
|
||||
}
|
||||
dry.Set("file_count", len(candidates))
|
||||
var totalSize int64
|
||||
names := make([]string, 0, len(candidates))
|
||||
@@ -144,22 +140,18 @@ type appsHTMLPublishSpec struct {
|
||||
// per-environment .env.* files for every stage).
|
||||
const maxSensitiveListInError = 5
|
||||
|
||||
// truncatedJoin joins items with ", ", capping at max entries and appending
|
||||
// "(and N more)" for the remainder, so an inline error list stays readable when
|
||||
// a payload has many hits.
|
||||
func truncatedJoin(items []string, max int) string {
|
||||
if len(items) <= max {
|
||||
return strings.Join(items, ", ")
|
||||
}
|
||||
return strings.Join(items[:max], ", ") + fmt.Sprintf(" (and %d more)", len(items)-max)
|
||||
}
|
||||
|
||||
// sensitiveCandidatesError builds the Validate-time rejection when --path
|
||||
// contains credential files and --allow-sensitive was not set.
|
||||
func sensitiveCandidatesError(hits []string) error {
|
||||
var sample string
|
||||
if len(hits) <= maxSensitiveListInError {
|
||||
sample = strings.Join(hits, ", ")
|
||||
} else {
|
||||
sample = strings.Join(hits[:maxSensitiveListInError], ", ") +
|
||||
fmt.Sprintf(" (and %d more)", len(hits)-maxSensitiveListInError)
|
||||
}
|
||||
return appsValidationParamError("--path",
|
||||
"--path contains %d credential file(s) that should not be published: %s",
|
||||
len(hits), truncatedJoin(hits, maxSensitiveListInError)).
|
||||
"--path contains %d credential file(s) that should not be published: %s", len(hits), sample).
|
||||
WithHint("remove these files from the publish payload, OR pass --allow-sensitive if shipping them is intentional (e.g. a docs site demoing credential-file formats)")
|
||||
}
|
||||
|
||||
@@ -176,30 +168,6 @@ var maxHTMLPublishTarballBytes int64 = 20 * 1024 * 1024
|
||||
// Mutable for tests.
|
||||
var maxHTMLPublishRawBytes int64 = 200 * 1024 * 1024
|
||||
|
||||
// maxHTMLPublishSingleHTMLFileBytes 单个 .html 文件上限,对齐妙搭服务端 10MB 约束。
|
||||
// 用 var 而非 const,便于单测调小覆盖拦截路径。
|
||||
var maxHTMLPublishSingleHTMLFileBytes int64 = 10 * 1024 * 1024
|
||||
|
||||
// oversizeHTMLFiles 返回 candidates 中扩展名为 .html(大小写不敏感)且单个 Size 超过
|
||||
// maxHTMLPublishSingleHTMLFileBytes 的 RelPath 列表。只针对 .html 文件,不波及图片/字体/JS。
|
||||
func oversizeHTMLFiles(candidates []htmlPublishCandidate) []string {
|
||||
var hits []string
|
||||
for _, c := range candidates {
|
||||
if strings.EqualFold(filepath.Ext(c.RelPath), ".html") && c.Size > maxHTMLPublishSingleHTMLFileBytes {
|
||||
hits = append(hits, c.RelPath)
|
||||
}
|
||||
}
|
||||
return hits
|
||||
}
|
||||
|
||||
// oversizeHTMLFilesError 构造单文件超限的 Validate 风格拒绝。
|
||||
func oversizeHTMLFilesError(hits []string) error {
|
||||
return appsValidationParamError("--path",
|
||||
"--path contains %d HTML file(s) exceeding the %d bytes (10MB) per-file limit: %s",
|
||||
len(hits), maxHTMLPublishSingleHTMLFileBytes, truncatedJoin(hits, maxSensitiveListInError)).
|
||||
WithHint("split or trim oversized HTML file(s); the 10MB cap applies to each single .html file")
|
||||
}
|
||||
|
||||
// ensureIndexHTML 要求 walker 抓到的 candidates 里必须含 index.html。
|
||||
// 目录形态:根目录下必须有 index.html。
|
||||
// 单文件形态:文件名必须就是 index.html。
|
||||
@@ -222,9 +190,6 @@ func runHTMLPublish(ctx context.Context, fio fileio.FileIO, publisher appsHTMLPu
|
||||
if err := ensureIndexHTML(candidates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hits := oversizeHTMLFiles(candidates); len(hits) > 0 {
|
||||
return nil, oversizeHTMLFilesError(hits)
|
||||
}
|
||||
var rawTotal int64
|
||||
for _, c := range candidates {
|
||||
rawTotal += c.Size
|
||||
|
||||
@@ -503,82 +503,3 @@ func TestRunHTMLPublish_RejectsOversizeRawCandidates(t *testing.T) {
|
||||
t.Fatalf("client must not be called when raw cap hit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOversizeHTMLFiles(t *testing.T) {
|
||||
orig := maxHTMLPublishSingleHTMLFileBytes
|
||||
maxHTMLPublishSingleHTMLFileBytes = 100
|
||||
defer func() { maxHTMLPublishSingleHTMLFileBytes = orig }()
|
||||
|
||||
cands := []htmlPublishCandidate{
|
||||
{RelPath: "index.html", Size: 50},
|
||||
{RelPath: "big.html", Size: 4096},
|
||||
{RelPath: "BIG.HTML", Size: 4096}, // 大小写不敏感
|
||||
{RelPath: "huge.png", Size: 9000}, // 非 .html,忽略
|
||||
}
|
||||
hits := oversizeHTMLFiles(cands)
|
||||
if len(hits) != 2 {
|
||||
t.Fatalf("hits=%v, want [big.html BIG.HTML]", hits)
|
||||
}
|
||||
for _, h := range hits {
|
||||
if h == "huge.png" || h == "index.html" {
|
||||
t.Fatalf("unexpected hit %q", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxHTMLPublishSingleHTMLFileBytes_Default(t *testing.T) {
|
||||
if maxHTMLPublishSingleHTMLFileBytes != 10*1024*1024 {
|
||||
t.Fatalf("default=%d, want %d (10MiB)", maxHTMLPublishSingleHTMLFileBytes, 10*1024*1024)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHTMLPublish_RejectsOversizeHTMLFile(t *testing.T) {
|
||||
orig := maxHTMLPublishSingleHTMLFileBytes
|
||||
maxHTMLPublishSingleHTMLFileBytes = 100
|
||||
defer func() { maxHTMLPublishSingleHTMLFileBytes = orig }()
|
||||
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("<html></html>"), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "big.html"), []byte(strings.Repeat("x", 4096)), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
fake := &fakeAppsHTMLPublishClient{}
|
||||
_, err := runHTMLPublish(context.Background(), newTestFIO(), fake, appsHTMLPublishSpec{AppID: "app_x", Path: dir})
|
||||
if err == nil {
|
||||
t.Fatalf("expected per-file oversize error")
|
||||
}
|
||||
problem := requireAppsValidationProblem(t, err)
|
||||
if !strings.Contains(problem.Message, "big.html") || !strings.Contains(problem.Message, "10MB") {
|
||||
t.Fatalf("message=%q, want contains 'big.html' and '10MB'", problem.Message)
|
||||
}
|
||||
if problem.Hint == "" {
|
||||
t.Fatalf("expected non-empty hint")
|
||||
}
|
||||
if len(fake.calls) != 0 {
|
||||
t.Fatalf("client must not be called when an HTML file is oversize")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunHTMLPublish_IgnoresOversizeNonHTML(t *testing.T) {
|
||||
// 单 .html 上限调小,但超限文件是 .png → 不被本护栏拦截,正常发布。
|
||||
orig := maxHTMLPublishSingleHTMLFileBytes
|
||||
maxHTMLPublishSingleHTMLFileBytes = 100
|
||||
defer func() { maxHTMLPublishSingleHTMLFileBytes = orig }()
|
||||
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("<html></html>"), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "big.png"), []byte(strings.Repeat("x", 4096)), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
fake := &fakeAppsHTMLPublishClient{resp: &htmlPublishResponse{URL: "https://miaoda/app_x"}}
|
||||
if _, err := runHTMLPublish(context.Background(), newTestFIO(), fake, appsHTMLPublishSpec{AppID: "app_x", Path: dir}); err != nil {
|
||||
t.Fatalf("non-html oversize must not be blocked by the .html cap: %v", err)
|
||||
}
|
||||
if len(fake.calls) != 1 {
|
||||
t.Fatalf("client should be called; calls=%v", fake.calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,14 +99,7 @@ var AppsInit = common.Shortcut{
|
||||
dry.Set("dir_error", err.Error())
|
||||
dir = defaultCloneDir(appID)
|
||||
} else if isAlreadyInitialized(dir) {
|
||||
if existing, e := ensureInitDirMatchesApp(dir, appID); e != nil {
|
||||
if existing != "" {
|
||||
dry.Set("app_id_mismatch", existing)
|
||||
}
|
||||
dry.Set("dir_error", e.Error())
|
||||
} else {
|
||||
dry.Set("already_initialized", true)
|
||||
}
|
||||
dry.Set("already_initialized", true)
|
||||
} else if e := ensureEmptyDir(dir); e != nil {
|
||||
dry.Set("dir_error", e.Error())
|
||||
}
|
||||
@@ -206,61 +199,6 @@ func isAlreadyInitialized(dir string) bool {
|
||||
return err == nil && !info.IsDir()
|
||||
}
|
||||
|
||||
// readMetaAppID 读取 <dir>/.spark/meta.json 的 app_id,用于判断目标目录是否同一个妙搭应用。
|
||||
// 返回 (appID, isSparkProject, err):
|
||||
// - meta.json 不存在 → ("", false, nil) 非妙搭工程
|
||||
// - 读取/解析失败(损坏/不可读) → ("", false, err) 无法确认是否妙搭工程
|
||||
// - 解析成功 → (trim 后的 app_id, true, nil)(app_id 缺失/为空时为 "")
|
||||
func readMetaAppID(dir string) (string, bool, error) {
|
||||
b, err := os.ReadFile(filepath.Join(dir, metaRelPath)) //nolint:forbidigo // shortcuts cannot import internal/vfs (depguard rule shortcuts-no-vfs); path is under the validated clone dir, and FileIO.Open rejects absolute paths.
|
||||
if os.IsNotExist(err) {
|
||||
return "", false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, appsFileIOError(err, "read %s failed: %v", metaRelPath, err)
|
||||
}
|
||||
var m struct {
|
||||
AppID string `json:"app_id"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return "", false, appsFileIOError(err, "parse %s failed: %v", metaRelPath, err)
|
||||
}
|
||||
return strings.TrimSpace(m.AppID), true, nil
|
||||
}
|
||||
|
||||
// ensureInitDirMatchesApp 校验「已存在的目标目录」能否被 appID 安全复用:
|
||||
// - 不是妙搭工程(无 meta.json) → nil(交给 ensureEmptyDir 判空/非空)
|
||||
// - 是妙搭工程且 app_id 与 appID 一致 → nil(走已初始化短路,复用本地代码)
|
||||
// - 是妙搭工程但 app_id 不一致(含为空) → 报错,提示换目录
|
||||
// - meta.json 损坏/不可读,无法确认 → 报错(fail closed),提示换目录
|
||||
//
|
||||
// 返回值 existing 是目录里已存在的 app_id(仅"已是另一个 app"的拒绝场景非空),供调用方在
|
||||
// dry-run 里回填 app_id_mismatch,避免二次读 meta.json。
|
||||
func ensureInitDirMatchesApp(dir, appID string) (existing string, err error) {
|
||||
existing, isSpark, readErr := readMetaAppID(dir)
|
||||
if readErr != nil {
|
||||
return "", appsValidationParamError("--dir",
|
||||
"target directory %q already exists but its %s is unreadable or corrupted; cannot confirm it belongs to app %s, refusing to use it",
|
||||
dir, metaRelPath, appID).
|
||||
WithHint("choose a different --dir, or repair/remove the directory, before running +init").
|
||||
WithCause(readErr)
|
||||
}
|
||||
if !isSpark || existing == appID {
|
||||
return existing, nil
|
||||
}
|
||||
if existing == "" {
|
||||
// meta 存在但缺 app_id:更可能是同一应用上次 +init 中断留下的半成品,而非另一个 app。
|
||||
return "", appsValidationParamError("--dir",
|
||||
"target directory %q has a %s without an app_id; cannot confirm it belongs to app %s, refusing to use it",
|
||||
dir, metaRelPath, appID).
|
||||
WithHint("remove the directory and re-run +init, or choose a different --dir")
|
||||
}
|
||||
return existing, appsValidationParamError("--dir",
|
||||
"target directory %q is already initialized for a different app (%s); refusing to initialize app %s into it",
|
||||
dir, existing, appID).
|
||||
WithHint("choose a different --dir (or cd into the matching project) before running +init")
|
||||
}
|
||||
|
||||
// ensureMetaAppID patches <dir>/.spark/meta.json to include app_id when the file
|
||||
// exists but lacks (or has an empty) app_id. Other fields are preserved. When
|
||||
// the file does not exist, this is a no-op (we never create it).
|
||||
@@ -440,11 +378,6 @@ func appsInitExecute(ctx context.Context, rctx *common.RuntimeContext) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 异 app 目录护栏:拒绝把当前 app 初始化进另一个 app 的已初始化工程。
|
||||
if _, err := ensureInitDirMatchesApp(dir, appID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Already-initialized short-circuit: a dir containing .spark/meta.json is an
|
||||
// initialized app repo -> skip clone/scaffold/commit, but still refresh
|
||||
// the local env so a re-run picks up the latest startup env vars.
|
||||
|
||||
@@ -363,7 +363,7 @@ func TestAppsInit_AlreadyInitialized_ShortCircuit(t *testing.T) {
|
||||
if err := os.MkdirAll(filepath.Join(dir, ".spark"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, metaRelPath), []byte(`{"app_id":"app_x"}`), 0o644); err != nil {
|
||||
if err := os.WriteFile(filepath.Join(dir, metaRelPath), []byte(`{"app_id":"whatever"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f := &fakeCommandRunner{results: map[string]fakeCallResult{"env-pull": envPullOK(filepath.Join(abs, ".env.local"))}}
|
||||
@@ -394,40 +394,6 @@ func TestAppsInit_AlreadyInitialized_ShortCircuit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppsInit_AlreadyInitialized_AppIDMismatch(t *testing.T) {
|
||||
dir := relCloneDir(t)
|
||||
if err := os.MkdirAll(filepath.Join(dir, ".spark"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 目录是 app_other 的工程,却用 --app-id app_x 初始化 → 必须报错且不拉 env。
|
||||
if err := os.WriteFile(filepath.Join(dir, metaRelPath), []byte(`{"app_id":"app_other"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f := &fakeCommandRunner{}
|
||||
withFakeRunner(t, f)
|
||||
factory, stdout, _ := newAppsExecuteFactory(t)
|
||||
err := runAppsShortcut(t, AppsInit, []string{"+init", "--app-id", "app_x", "--dir", dir, "--as", "user"}, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("mismatched app_id must error")
|
||||
}
|
||||
problem := requireAppsValidationProblem(t, err)
|
||||
if problem.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Fatalf("subtype=%q, want %q", problem.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
var ve *errs.ValidationError
|
||||
if !errors.As(err, &ve) || ve.Param != "--dir" {
|
||||
t.Fatalf("expected *errs.ValidationError with Param=--dir, got %T param=%v", err, ve)
|
||||
}
|
||||
if !strings.Contains(problem.Message, "different app") {
|
||||
t.Fatalf("message=%q, want 'different app'", problem.Message)
|
||||
}
|
||||
for _, c := range f.calls {
|
||||
if containsAll(c, "+env-pull") || containsAll(c, "git", "clone") {
|
||||
t.Errorf("mismatch must not run env-pull/clone; got %v", f.calls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppsInit_HappyPathCleanTree(t *testing.T) {
|
||||
f := &fakeCommandRunner{results: map[string]fakeCallResult{
|
||||
"credential-init": credInitOK("http://u:t@h/app_x.git"),
|
||||
@@ -1502,125 +1468,6 @@ func TestAppsInit_Description_IsAboutCode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMetaAppID(t *testing.T) {
|
||||
writeMeta := func(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
if err := os.MkdirAll(filepath.Join(dir, ".spark"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, metaRelPath), []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
// 不存在 meta.json → ("", false, nil)
|
||||
if got, ok, err := readMetaAppID(t.TempDir()); ok || got != "" || err != nil {
|
||||
t.Fatalf("no meta: got (%q,%v,%v), want (\"\",false,nil)", got, ok, err)
|
||||
}
|
||||
// 存在且有 app_id → (app_id, true, nil)
|
||||
if got, ok, err := readMetaAppID(writeMeta(t, `{"app_id":"app_a"}`)); !ok || got != "app_a" || err != nil {
|
||||
t.Fatalf("with app_id: got (%q,%v,%v), want (\"app_a\",true,nil)", got, ok, err)
|
||||
}
|
||||
// 存在但 app_id 空 → ("", true, nil)
|
||||
if got, ok, err := readMetaAppID(writeMeta(t, `{"name":"x"}`)); !ok || got != "" || err != nil {
|
||||
t.Fatalf("empty app_id: got (%q,%v,%v), want (\"\",true,nil)", got, ok, err)
|
||||
}
|
||||
// 存在但坏 JSON → ("", false, err)(无法确认)
|
||||
if got, ok, err := readMetaAppID(writeMeta(t, `{not json`)); ok || got != "" || err == nil {
|
||||
t.Fatalf("bad json: got (%q,%v,err=%v), want (\"\",false,non-nil)", got, ok, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureInitDirMatchesApp(t *testing.T) {
|
||||
writeMeta := func(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
if err := os.MkdirAll(filepath.Join(dir, ".spark"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, metaRelPath), []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
// 无 meta(非妙搭工程)→ nil(交给 ensureEmptyDir)
|
||||
if _, err := ensureInitDirMatchesApp(t.TempDir(), "app_x"); err != nil {
|
||||
t.Fatalf("no meta should pass: %v", err)
|
||||
}
|
||||
// 同 app_id → (app_id, nil)(走已初始化短路)
|
||||
if existing, err := ensureInitDirMatchesApp(writeMeta(t, `{"app_id":"app_x"}`), "app_x"); err != nil || existing != "app_x" {
|
||||
t.Fatalf("same app should pass: existing=%q err=%v", existing, err)
|
||||
}
|
||||
|
||||
// 不同 app_id → error(换目录),返回 existing=app_other;断言 typed metadata(subtype/param)
|
||||
existing, errMismatch := ensureInitDirMatchesApp(writeMeta(t, `{"app_id":"app_other"}`), "app_x")
|
||||
if errMismatch == nil {
|
||||
t.Fatal("different app should error")
|
||||
}
|
||||
if existing != "app_other" {
|
||||
t.Fatalf("mismatch should return existing app_id, got %q", existing)
|
||||
}
|
||||
problem := requireAppsValidationProblem(t, errMismatch) // 已校验 Category==Validation
|
||||
if problem.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Fatalf("subtype=%q, want %q", problem.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
var ve *errs.ValidationError
|
||||
if !errors.As(errMismatch, &ve) {
|
||||
t.Fatalf("expected *errs.ValidationError, got %T", errMismatch)
|
||||
}
|
||||
if ve.Param != "--dir" {
|
||||
t.Fatalf("param=%q, want --dir", ve.Param)
|
||||
}
|
||||
if !strings.Contains(problem.Message, "different app") || !strings.Contains(problem.Message, "app_other") {
|
||||
t.Fatalf("message=%q, want 'different app' and 'app_other'", problem.Message)
|
||||
}
|
||||
if !strings.Contains(problem.Hint, "different --dir") {
|
||||
t.Fatalf("hint=%q, want 'different --dir'", problem.Hint)
|
||||
}
|
||||
|
||||
// 空 app_id(缺 app_id 标记的半成品)→ error,独立文案(非 "different app"),返回 existing=""
|
||||
emptyExisting, errEmpty := ensureInitDirMatchesApp(writeMeta(t, `{"name":"x"}`), "app_x")
|
||||
if errEmpty == nil {
|
||||
t.Fatal("empty meta app_id should error (cannot confirm same app)")
|
||||
}
|
||||
if emptyExisting != "" {
|
||||
t.Fatalf("empty app_id should return existing=\"\", got %q", emptyExisting)
|
||||
}
|
||||
pEmpty := requireAppsValidationProblem(t, errEmpty)
|
||||
if pEmpty.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Fatalf("empty subtype=%q, want %q", pEmpty.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
if !strings.Contains(pEmpty.Message, "without an app_id") {
|
||||
t.Fatalf("empty app_id should have its own message, msg=%q", pEmpty.Message)
|
||||
}
|
||||
if strings.Contains(pEmpty.Message, "different app") {
|
||||
t.Fatalf("empty app_id must not reuse the different-app wording, msg=%q", pEmpty.Message)
|
||||
}
|
||||
|
||||
// meta 损坏/不可读 → error(fail closed),返回 existing=""
|
||||
badExisting, errBad := ensureInitDirMatchesApp(writeMeta(t, `{not json`), "app_x")
|
||||
if errBad == nil {
|
||||
t.Fatal("corrupted meta should fail closed")
|
||||
}
|
||||
if badExisting != "" {
|
||||
t.Fatalf("corrupted should return existing=\"\", got %q", badExisting)
|
||||
}
|
||||
pBad := requireAppsValidationProblem(t, errBad)
|
||||
if pBad.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Fatalf("corrupted subtype=%q, want %q", pBad.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
if !strings.Contains(pBad.Message, "unreadable or corrupted") {
|
||||
t.Fatalf("corrupted meta msg=%q, want 'unreadable or corrupted'", pBad.Message)
|
||||
}
|
||||
var veBad *errs.ValidationError
|
||||
if !errors.As(errBad, &veBad) || veBad.Param != "--dir" {
|
||||
t.Fatalf("corrupted: expected ValidationError Param=--dir, got %T param=%v", errBad, veBad)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunScaffold_SubprocessFailureIsExternalTool pins the typed
|
||||
// classification of an external-tool failure: a failing git subprocess
|
||||
// surfaces as internal/external_tool with the cause preserved.
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -33,7 +32,6 @@ import (
|
||||
)
|
||||
|
||||
const gitCredentialIssuePath = apiBasePath + "/apps/:app_id/git_info"
|
||||
const gitCredentialHelperReportedShortcut = appsService + ":+git-credential-helper"
|
||||
|
||||
// gitCredentialIssueHint is the actionable next-step attached to a failed
|
||||
// Git-credential issuance. A 5xx is flagged retryable separately at the call site.
|
||||
@@ -304,12 +302,7 @@ func (i factoryIssuer) Issue(ctx context.Context, appID string, profile gitcred.
|
||||
HttpMethod: http.MethodGet,
|
||||
ApiPath: issuePath(appID),
|
||||
}
|
||||
ctx = contextWithGitCredentialHelperShortcut(ctx)
|
||||
var opts []larkcore.RequestOptionFunc
|
||||
if optFn := cmdutil.ShortcutHeaderOpts(ctx); optFn != nil {
|
||||
opts = append(opts, optFn)
|
||||
}
|
||||
resp, err := ac.DoSDKRequest(ctx, req, core.AsUser, opts...)
|
||||
resp, err := ac.DoSDKRequest(ctx, req, core.AsUser)
|
||||
data, err := parseIssueCredentialData(resp, err, errclass.ClassifyContext{
|
||||
Brand: string(cfg.Brand),
|
||||
AppID: cfg.AppID,
|
||||
@@ -321,13 +314,6 @@ func (i factoryIssuer) Issue(ctx context.Context, appID string, profile gitcred.
|
||||
return issuedFromData(appID, data)
|
||||
}
|
||||
|
||||
func contextWithGitCredentialHelperShortcut(ctx context.Context) context.Context {
|
||||
if _, ok := cmdutil.ShortcutNameFromContext(ctx); ok {
|
||||
return ctx
|
||||
}
|
||||
return cmdutil.ContextWithShortcut(ctx, gitCredentialHelperReportedShortcut, uuid.New().String())
|
||||
}
|
||||
|
||||
func runGitCredentialHelper(ctx context.Context, f *cmdutil.Factory, appID, action string) error {
|
||||
if f == nil || f.IOStreams == nil {
|
||||
return nil
|
||||
|
||||
@@ -825,7 +825,7 @@ func TestRunGitCredentialHelperActions(t *testing.T) {
|
||||
func TestFactoryIssuerBranches(t *testing.T) {
|
||||
factory, _, reg := newAppsExecuteFactory(t)
|
||||
expiresAt := time.Now().Add(24 * time.Hour).Unix()
|
||||
issueStub := &httpmock.Stub{
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/spark/v1/apps/app_xxx/git_info",
|
||||
Body: map[string]interface{}{
|
||||
@@ -836,8 +836,7 @@ func TestFactoryIssuerBranches(t *testing.T) {
|
||||
"StatusCode": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(issueStub)
|
||||
})
|
||||
issued, err := (factoryIssuer{f: factory}).Issue(context.Background(), "app_xxx", gitcred.ProfileContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("factory issuer returned error: %v", err)
|
||||
@@ -845,12 +844,6 @@ func TestFactoryIssuerBranches(t *testing.T) {
|
||||
if issued.PAT != "pat-token" {
|
||||
t.Fatalf("PAT = %q", issued.PAT)
|
||||
}
|
||||
if got := issueStub.CapturedHeaders.Get(cmdutil.HeaderShortcut); got != gitCredentialHelperReportedShortcut {
|
||||
t.Fatalf("%s = %q, want %q", cmdutil.HeaderShortcut, got, gitCredentialHelperReportedShortcut)
|
||||
}
|
||||
if got := issueStub.CapturedHeaders.Get(cmdutil.HeaderExecutionId); got == "" {
|
||||
t.Fatalf("%s header missing", cmdutil.HeaderExecutionId)
|
||||
}
|
||||
|
||||
factory.Config = func() (*core.CliConfig, error) { return nil, errors.New("config failed") }
|
||||
if _, err := (factoryIssuer{f: factory}).Issue(context.Background(), "app_xxx", gitcred.ProfileContext{}); err == nil {
|
||||
@@ -887,20 +880,6 @@ func TestFactoryIssuerBranches(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextWithGitCredentialHelperShortcutPreservesExistingShortcut(t *testing.T) {
|
||||
ctx := cmdutil.ContextWithShortcut(context.Background(), "apps:+git-credential-init", "exec-existing")
|
||||
got := contextWithGitCredentialHelperShortcut(ctx)
|
||||
|
||||
name, ok := cmdutil.ShortcutNameFromContext(got)
|
||||
if !ok || name != "apps:+git-credential-init" {
|
||||
t.Fatalf("shortcut = %q ok=%v, want existing shortcut", name, ok)
|
||||
}
|
||||
executionID, ok := cmdutil.ExecutionIdFromContext(got)
|
||||
if !ok || executionID != "exec-existing" {
|
||||
t.Fatalf("execution id = %q ok=%v, want existing execution id", executionID, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitCredentialHelpersAndParsers(t *testing.T) {
|
||||
if issuePath(" app/with space ") != "/open-apis/spark/v1/apps/app%2Fwith%20space/git_info" {
|
||||
t.Fatalf("issuePath escaped incorrectly: %s", issuePath(" app/with space "))
|
||||
|
||||
@@ -85,7 +85,6 @@ type searchUserAPIData struct {
|
||||
Items []searchUserAPIItem `json:"items"`
|
||||
HasMore bool `json:"has_more"`
|
||||
PageToken string `json:"page_token"`
|
||||
Notice string `json:"notice"`
|
||||
}
|
||||
|
||||
type searchUserAPIItem struct {
|
||||
@@ -127,7 +126,6 @@ type searchUser struct {
|
||||
type searchUserResponse struct {
|
||||
Users []searchUser `json:"users"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Notice string `json:"notice,omitempty"`
|
||||
}
|
||||
|
||||
var ContactSearchUser = common.Shortcut{
|
||||
@@ -191,7 +189,6 @@ var ContactSearchUser = common.Shortcut{
|
||||
Execute: executeSearchUser,
|
||||
}
|
||||
|
||||
// executeSearchUser dispatches contact search to single-query or fanout mode.
|
||||
func executeSearchUser(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
if strings.TrimSpace(runtime.Str("queries")) != "" {
|
||||
return executeSearchUserFanout(ctx, runtime)
|
||||
@@ -199,7 +196,6 @@ func executeSearchUser(ctx context.Context, runtime *common.RuntimeContext) erro
|
||||
return executeSearchUserSingle(ctx, runtime)
|
||||
}
|
||||
|
||||
// executeSearchUserSingle performs one contact search and preserves server notices.
|
||||
func executeSearchUserSingle(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
body, err := buildSearchUserBody(runtime)
|
||||
if err != nil {
|
||||
@@ -226,7 +222,7 @@ func executeSearchUserSingle(ctx context.Context, runtime *common.RuntimeContext
|
||||
}
|
||||
|
||||
users, hasMore := projectUsers(respData, runtime.Str("lang"), runtime.Config.Brand)
|
||||
out := searchUserResponse{Users: users, HasMore: hasMore, Notice: respData.Notice}
|
||||
out := searchUserResponse{Users: users, HasMore: hasMore}
|
||||
|
||||
runtime.OutFormat(out, &output.Meta{Count: len(users)}, func(w io.Writer) {
|
||||
if len(users) == 0 {
|
||||
|
||||
@@ -45,17 +45,22 @@ type fanoutResult struct {
|
||||
Query string
|
||||
Users []searchUser
|
||||
HasMore bool
|
||||
Notice string
|
||||
ErrMsg string // empty = success
|
||||
Err error // original failure, kept for typed all-failed propagation
|
||||
}
|
||||
|
||||
// isFanoutSummaryFormat gates the per-fanout stderr summary line.
|
||||
// isFanoutSummaryFormat gates the per-fanout stderr summary line. Includes csv
|
||||
// because that summary lives on stderr and never corrupts the csv stream on
|
||||
// stdout — single-query mode keeps the narrower isHumanReadableFormat predicate
|
||||
// for its refine hint, so adding csv here doesn't regress that path.
|
||||
func isFanoutSummaryFormat(format string) bool {
|
||||
return format == "pretty" || format == "table" || format == "csv"
|
||||
}
|
||||
|
||||
// runOneQuery converts one fanout request into either users or an error summary.
|
||||
// runOneQuery converts every failure mode (transport, HTTP status, parse,
|
||||
// API code) into an ErrMsg string instead of returning a Go error. The
|
||||
// fanout dispatcher (Task 6) relies on this so a single failed query never
|
||||
// short-circuits the remaining workers.
|
||||
func runOneQuery(ctx context.Context, runtime *common.RuntimeContext, index int, query string,
|
||||
filter *searchUserAPIFilter) fanoutResult {
|
||||
// Pre-check ctx so queued workers see cancellation before issuing a
|
||||
@@ -89,10 +94,9 @@ func runOneQuery(ctx context.Context, runtime *common.RuntimeContext, index int,
|
||||
}
|
||||
|
||||
users, hasMore := projectUsers(respData, runtime.Str("lang"), runtime.Config.Brand)
|
||||
return fanoutResult{Index: index, Query: query, Users: users, HasMore: hasMore, Notice: respData.Notice}
|
||||
return fanoutResult{Index: index, Query: query, Users: users, HasMore: hasMore}
|
||||
}
|
||||
|
||||
// fanoutErrorResult records a failed fanout query without stopping other workers.
|
||||
func fanoutErrorResult(index int, query string, err error) fanoutResult {
|
||||
if err == nil {
|
||||
return fanoutResult{Index: index, Query: query}
|
||||
@@ -109,16 +113,17 @@ type querySummary struct {
|
||||
Query string `json:"query"`
|
||||
Error string `json:"error,omitempty"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Notice string `json:"notice,omitempty"`
|
||||
}
|
||||
|
||||
type fanoutResponse struct {
|
||||
Users []fanoutUser `json:"users"`
|
||||
Queries []querySummary `json:"queries"`
|
||||
Notice string `json:"notice,omitempty"`
|
||||
}
|
||||
|
||||
// buildFanoutResponse flattens ordered fanout results and fails only when all queries fail.
|
||||
// buildFanoutResponse walks results by Index (input order), flattens users[]
|
||||
// with matched_query, lists every input in queries[] (including successes),
|
||||
// and returns an error only when every query failed. The error wraps the
|
||||
// first failing query's ErrMsg so the CLI exits non-zero on full failure.
|
||||
func buildFanoutResponse(queries []string, results []fanoutResult) (*fanoutResponse, error) {
|
||||
indexed := make([]fanoutResult, len(queries))
|
||||
for _, r := range results {
|
||||
@@ -137,7 +142,6 @@ func buildFanoutResponse(queries []string, results []fanoutResult) (*fanoutRespo
|
||||
Query: queries[i],
|
||||
Error: r.ErrMsg,
|
||||
HasMore: r.HasMore,
|
||||
Notice: r.Notice,
|
||||
})
|
||||
if r.ErrMsg != "" {
|
||||
failed++
|
||||
@@ -148,9 +152,6 @@ func buildFanoutResponse(queries []string, results []fanoutResult) (*fanoutRespo
|
||||
}
|
||||
continue
|
||||
}
|
||||
if out.Notice == "" {
|
||||
out.Notice = r.Notice
|
||||
}
|
||||
for _, u := range r.Users {
|
||||
out.Users = append(out.Users, fanoutUser{searchUser: u, MatchedQuery: queries[i]})
|
||||
}
|
||||
|
||||
@@ -562,7 +562,6 @@ func mountAndRun(t *testing.T, s common.Shortcut, args []string, f *cmdutil.Fact
|
||||
return parent.Execute()
|
||||
}
|
||||
|
||||
// searchUserStub returns a representative user search response with a notice.
|
||||
func searchUserStub() *httpmock.Stub {
|
||||
return &httpmock.Stub{
|
||||
Method: "POST",
|
||||
@@ -570,7 +569,6 @@ func searchUserStub() *httpmock.Stub {
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"notice": "The query is too long and has been truncated to the first 50 characters for search.",
|
||||
"items": []interface{}{
|
||||
map[string]interface{}{
|
||||
"id": "ou_a",
|
||||
@@ -592,7 +590,6 @@ func searchUserStub() *httpmock.Stub {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchUser_Integration_PrettyRendersExpectedColumns verifies human output columns.
|
||||
func TestSearchUser_Integration_PrettyRendersExpectedColumns(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, searchUserDefaultConfig())
|
||||
reg.Register(searchUserStub())
|
||||
@@ -617,7 +614,6 @@ func TestSearchUser_Integration_PrettyRendersExpectedColumns(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchUser_Integration_JSONStructuredFields verifies normalized JSON and notices.
|
||||
func TestSearchUser_Integration_JSONStructuredFields(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, searchUserDefaultConfig())
|
||||
reg.Register(searchUserStub())
|
||||
@@ -635,9 +631,6 @@ func TestSearchUser_Integration_JSONStructuredFields(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("envelope.data: expected object, got %v\nraw=%s", got["data"], stdout.String())
|
||||
}
|
||||
if data["notice"] != "The query is too long and has been truncated to the first 50 characters for search." {
|
||||
t.Fatalf("data.notice = %v", data["notice"])
|
||||
}
|
||||
users, _ := data["users"].([]interface{})
|
||||
if len(users) != 1 {
|
||||
t.Fatalf("users: expected 1, got %d (output=%s)", len(users), stdout.String())
|
||||
@@ -1365,7 +1358,6 @@ func TestSearchUser_Integration_NoAutoPaginationFlags(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFanout_FilterAppliedToEachQuery verifies shared fanout filters reach every request.
|
||||
func TestFanout_FilterAppliedToEachQuery(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, searchUserDefaultConfig())
|
||||
stub := &httpmock.Stub{
|
||||
@@ -1407,7 +1399,6 @@ func TestFanout_FilterAppliedToEachQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFanout_PartialFailure_ExitZero verifies partial fanout failures keep notices.
|
||||
func TestFanout_PartialFailure_ExitZero(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, searchUserDefaultConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
@@ -1415,7 +1406,6 @@ func TestFanout_PartialFailure_ExitZero(t *testing.T) {
|
||||
BodyFilter: func(b []byte) bool { return strings.Contains(string(b), `"alice"`) },
|
||||
Body: map[string]interface{}{"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"notice": "The query is too long and has been truncated to the first 50 characters for search.",
|
||||
"items": []interface{}{map[string]interface{}{"id": "ou_a"}},
|
||||
"has_more": false,
|
||||
}},
|
||||
@@ -1442,17 +1432,10 @@ func TestFanout_PartialFailure_ExitZero(t *testing.T) {
|
||||
if len(users) != 1 {
|
||||
t.Errorf("users: expected 1 (alice), got %d; stdout=%s", len(users), stdout.String())
|
||||
}
|
||||
if data["notice"] != "The query is too long and has been truncated to the first 50 characters for search." {
|
||||
t.Fatalf("data.notice = %v", data["notice"])
|
||||
}
|
||||
queries := data["queries"].([]interface{})
|
||||
if len(queries) != 2 {
|
||||
t.Fatalf("queries: expected 2, got %d", len(queries))
|
||||
}
|
||||
q0 := queries[0].(map[string]interface{})
|
||||
if q0["notice"] != "The query is too long and has been truncated to the first 50 characters for search." {
|
||||
t.Fatalf("queries[0].notice = %v", q0["notice"])
|
||||
}
|
||||
q1 := queries[1].(map[string]interface{})
|
||||
if !strings.HasPrefix(q1["error"].(string), "HTTP 500") {
|
||||
t.Errorf("queries[1].error: got %q", q1["error"])
|
||||
|
||||
@@ -74,9 +74,6 @@ var DocsSearch = common.Shortcut{
|
||||
"page_token": data["page_token"],
|
||||
"results": normalizedItems,
|
||||
}
|
||||
if notice, _ := data["notice"].(string); notice != "" {
|
||||
resultData["notice"] = notice
|
||||
}
|
||||
|
||||
runtime.OutFormat(resultData, &output.Meta{Count: len(normalizedItems)}, func(w io.Writer) {
|
||||
if len(normalizedItems) == 0 {
|
||||
|
||||
@@ -7,48 +7,8 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
)
|
||||
|
||||
// TestDocsSearchExecutePassesThroughNotice verifies docs +search preserves notices.
|
||||
func TestDocsSearchExecutePassesThroughNotice(t *testing.T) {
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("docs-search-notice"))
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/search/v2/doc_wiki/search",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"notice": notice,
|
||||
"res_units": []interface{}{},
|
||||
"total": 0,
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err := mountAndRunDocs(t, DocsSearch, []string{"+search", "--query", "incident", "--format", "json", "--as", "user"}, f, stdout); err != nil {
|
||||
t.Fatalf("DocsSearch.Execute() error = %v", err)
|
||||
}
|
||||
reg.Verify(t)
|
||||
|
||||
var env map[string]interface{}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &env); err != nil {
|
||||
t.Fatalf("json.Unmarshal(stdout) error = %v\nstdout=%s", err, stdout.String())
|
||||
}
|
||||
data, _ := env["data"].(map[string]interface{})
|
||||
if got, _ := data["notice"].(string); got != notice {
|
||||
t.Fatalf("data.notice = %q, want %q; data=%#v", got, notice, data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddIsoTimeFieldsSupportsJSONNumber verifies JSON numbers get ISO fields.
|
||||
func TestAddIsoTimeFieldsSupportsJSONNumber(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ const (
|
||||
var DriveAddComment = common.Shortcut{
|
||||
Service: "drive",
|
||||
Command: "+add-comment",
|
||||
Description: "Add a comment to doc/docx/file/sheet/slides/base(bitable); file targets support selected extensions and full comments only",
|
||||
Description: "Add a comment to doc/docx/file/sheet/slides; file targets support selected extensions and full comments only",
|
||||
Risk: "write",
|
||||
Scopes: []string{
|
||||
"drive:drive.metadata:readonly",
|
||||
@@ -131,12 +131,12 @@ var DriveAddComment = common.Shortcut{
|
||||
},
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
Flags: []common.Flag{
|
||||
{Name: "doc", Desc: "document URL/token, file URL/token, sheet/slides/base/bitable URL, or wiki URL that resolves to doc/docx/file/sheet/slides/base(bitable)", Required: true},
|
||||
{Name: "type", Desc: "document type: doc, docx, file, sheet, slides, bitable, base (required when --doc is a bare token; auto-detected for URLs; use bitable as the wire value, base is accepted as a compatibility alias)", Enum: []string{"doc", "docx", "file", "sheet", "slides", "bitable", "base"}},
|
||||
{Name: "doc", Desc: "document URL/token, file URL/token, sheet/slides URL, or wiki URL that resolves to doc/docx/file/sheet/slides", Required: true},
|
||||
{Name: "type", Desc: "document type: doc, docx, file, sheet, slides (required when --doc is a bare token; auto-detected for URLs)", Enum: []string{"doc", "docx", "file", "sheet", "slides"}},
|
||||
{Name: "content", Desc: "reply_elements JSON string", Required: true, Input: []string{common.File, common.Stdin}},
|
||||
{Name: "full-comment", Type: "bool", Desc: "create a full-document comment; also the default when no location is provided"},
|
||||
{Name: "selection-with-ellipsis", Desc: "target content locator (plain text or 'start...end')"},
|
||||
{Name: "block-id", Desc: "for docx: anchor block ID; for sheet: <sheetId>!<cell>; for slides: <slide-block-type>!<xml-id>; for base(bitable): <table-id>!<record-id>!<view-id>"},
|
||||
{Name: "block-id", Desc: "for docx: anchor block ID; for sheet: <sheetId>!<cell> (e.g. a281f9!D6); for slides: <slide-block-type>!<xml-id> (e.g. shape!bPq)"},
|
||||
},
|
||||
Validate: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
docRef, err := parseCommentDocRef(runtime.Str("doc"), runtime.Str("type"))
|
||||
@@ -148,17 +148,6 @@ var DriveAddComment = common.Shortcut{
|
||||
return err
|
||||
}
|
||||
|
||||
if docRef.Kind == "base" {
|
||||
if runtime.Bool("full-comment") {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--full-comment is not applicable for base(bitable) comments; use --block-id <table-id>!<record-id>!<view-id>").WithParam("--full-comment")
|
||||
}
|
||||
if strings.TrimSpace(runtime.Str("selection-with-ellipsis")) != "" {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--selection-with-ellipsis is not applicable for base(bitable) comments; use --block-id <table-id>!<record-id>!<view-id>").WithParam("--selection-with-ellipsis")
|
||||
}
|
||||
_, err := parseBaseCommentAnchor(runtime)
|
||||
return err
|
||||
}
|
||||
|
||||
// Sheet comment validation.
|
||||
if docRef.Kind == "sheet" {
|
||||
blockID := strings.TrimSpace(runtime.Str("block-id"))
|
||||
@@ -199,7 +188,7 @@ var DriveAddComment = common.Shortcut{
|
||||
return validateFileCommentMode(mode, "")
|
||||
}
|
||||
if mode == commentModeLocal && docRef.Kind == "doc" {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "local comments only support docx, sheet, slides, and base(bitable); old doc format only supports full comments")
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "local comments only support docx, sheet, and slides; old doc format only supports full comments")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -226,23 +215,6 @@ var DriveAddComment = common.Shortcut{
|
||||
resolvedToken = target.FileToken
|
||||
}
|
||||
|
||||
if resolvedKind == "base" {
|
||||
anchor, err := parseBaseCommentAnchor(runtime)
|
||||
if err != nil {
|
||||
return common.NewDryRunAPI().Set("error", err.Error())
|
||||
}
|
||||
commentBody := buildBaseCommentCreateV2Request(replyElements, anchor)
|
||||
desc := "1-step request: create base(bitable) record-local comment"
|
||||
if isWiki {
|
||||
desc = "2-step orchestration: resolve wiki -> create base(bitable) record-local comment"
|
||||
}
|
||||
return common.NewDryRunAPI().
|
||||
Desc(desc).
|
||||
POST("/open-apis/drive/v1/files/:file_token/new_comments").
|
||||
Body(commentBody).
|
||||
Set("file_token", resolvedToken)
|
||||
}
|
||||
|
||||
// Sheet comment dry-run.
|
||||
if resolvedKind == "sheet" {
|
||||
anchor, _ := parseSheetCellRef(blockID)
|
||||
@@ -380,14 +352,6 @@ var DriveAddComment = common.Shortcut{
|
||||
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
// Sheet comment: direct URL or token fast path.
|
||||
docRef, _ := parseCommentDocRef(runtime.Str("doc"), runtime.Str("type"))
|
||||
if docRef.Kind == "base" {
|
||||
return executeBaseComment(runtime, resolvedCommentTarget{
|
||||
DocID: docRef.Token,
|
||||
FileToken: docRef.Token,
|
||||
FileType: "base",
|
||||
ResolvedBy: "base",
|
||||
})
|
||||
}
|
||||
if docRef.Kind == "sheet" {
|
||||
return executeSheetComment(runtime, docRef)
|
||||
}
|
||||
@@ -411,9 +375,6 @@ var DriveAddComment = common.Shortcut{
|
||||
if target.FileType == "slides" {
|
||||
return executeSlidesComment(runtime, commentDocRef{Kind: "slides", Token: target.FileToken})
|
||||
}
|
||||
if target.FileType == "base" {
|
||||
return executeBaseComment(runtime, target)
|
||||
}
|
||||
if target.FileType == "file" {
|
||||
return executeFileComment(runtime, target)
|
||||
}
|
||||
@@ -521,12 +482,6 @@ func parseCommentDocRef(input, docType string) (commentDocRef, error) {
|
||||
if token, ok := extractURLToken(raw, "/sheets/"); ok {
|
||||
return commentDocRef{Kind: "sheet", Token: token}, nil
|
||||
}
|
||||
if token, ok := extractURLToken(raw, "/base/"); ok {
|
||||
return commentDocRef{Kind: "base", Token: token}, nil
|
||||
}
|
||||
if token, ok := extractURLToken(raw, "/bitable/"); ok {
|
||||
return commentDocRef{Kind: "base", Token: token}, nil
|
||||
}
|
||||
if token, ok := extractURLToken(raw, "/file/"); ok {
|
||||
return commentDocRef{Kind: "file", Token: token}, nil
|
||||
}
|
||||
@@ -540,7 +495,7 @@ func parseCommentDocRef(input, docType string) (commentDocRef, error) {
|
||||
return commentDocRef{Kind: "doc", Token: token}, nil
|
||||
}
|
||||
if strings.Contains(raw, "://") {
|
||||
return commentDocRef{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "unsupported --doc input %q: use a doc/docx/file/sheet/slides/base/bitable URL, a token with --type, or a wiki URL that resolves to doc/docx/file/sheet/slides/base(bitable)", raw).WithParam("--doc")
|
||||
return commentDocRef{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "unsupported --doc input %q: use a doc/docx/file/sheet/slides URL, a token with --type, or a wiki URL that resolves to doc/docx/file/sheet/slides", raw).WithParam("--doc")
|
||||
}
|
||||
if strings.ContainsAny(raw, "/?#") {
|
||||
return commentDocRef{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "unsupported --doc input %q: use a token with --type, or a wiki URL", raw).WithParam("--doc")
|
||||
@@ -549,10 +504,7 @@ func parseCommentDocRef(input, docType string) (commentDocRef, error) {
|
||||
// Bare token: --type is required.
|
||||
docType = strings.TrimSpace(docType)
|
||||
if docType == "" {
|
||||
return commentDocRef{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "--type is required when --doc is a bare token (allowed values: doc, docx, file, sheet, slides, bitable, base; use bitable as the wire value, base is accepted as a compatibility alias)").WithParam("--type")
|
||||
}
|
||||
if docType == "bitable" || docType == "base" {
|
||||
return commentDocRef{Kind: "base", Token: raw}, nil
|
||||
return commentDocRef{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "--type is required when --doc is a bare token (allowed values: doc, docx, file, sheet, slides)").WithParam("--type")
|
||||
}
|
||||
return commentDocRef{Kind: docType, Token: raw}, nil
|
||||
}
|
||||
@@ -563,11 +515,11 @@ func resolveCommentTarget(ctx context.Context, runtime *common.RuntimeContext, i
|
||||
return resolvedCommentTarget{}, err
|
||||
}
|
||||
|
||||
if docRef.Kind == "docx" || docRef.Kind == "doc" || docRef.Kind == "file" || docRef.Kind == "sheet" || docRef.Kind == "slides" || docRef.Kind == "base" {
|
||||
if docRef.Kind == "docx" || docRef.Kind == "doc" || docRef.Kind == "file" || docRef.Kind == "sheet" || docRef.Kind == "slides" {
|
||||
if mode == commentModeLocal {
|
||||
switch docRef.Kind {
|
||||
case "doc":
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "local comments only support docx, sheet, slides, and base(bitable); old doc format only supports full comments")
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "local comments only support docx, sheet, and slides; old doc format only supports full comments")
|
||||
case "file":
|
||||
if err := validateFileCommentMode(mode, ""); err != nil {
|
||||
return resolvedCommentTarget{}, err
|
||||
@@ -605,22 +557,6 @@ func resolveCommentTarget(ctx context.Context, runtime *common.RuntimeContext, i
|
||||
if objType == "slides" && strings.TrimSpace(runtime.Str("selection-with-ellipsis")) != "" {
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but --selection-with-ellipsis is not applicable for slide comments; use --block-id <slide-block-type>!<xml-id>", objType)
|
||||
}
|
||||
if objType == "bitable" || objType == "base" {
|
||||
if runtime.Bool("full-comment") {
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but --full-comment is not applicable for base(bitable) comments; use --block-id <table-id>!<record-id>!<view-id>", objType).WithParam("--full-comment")
|
||||
}
|
||||
if strings.TrimSpace(runtime.Str("selection-with-ellipsis")) != "" {
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but --selection-with-ellipsis is not applicable for base(bitable) comments; use --block-id <table-id>!<record-id>!<view-id>", objType).WithParam("--selection-with-ellipsis")
|
||||
}
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Resolved wiki to base: %s\n", common.MaskToken(objToken))
|
||||
return resolvedCommentTarget{
|
||||
DocID: objToken,
|
||||
FileToken: objToken,
|
||||
FileType: "base",
|
||||
ResolvedBy: "wiki",
|
||||
WikiToken: docRef.Token,
|
||||
}, nil
|
||||
}
|
||||
if objType == "sheet" {
|
||||
// Sheet comments are handled via the sheet fast path in Execute.
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Resolved wiki to %s: %s\n", objType, common.MaskToken(objToken))
|
||||
@@ -656,10 +592,10 @@ func resolveCommentTarget(ctx context.Context, runtime *common.RuntimeContext, i
|
||||
}, nil
|
||||
}
|
||||
if mode == commentModeLocal && objType != "docx" {
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but local comments only support docx, sheet, slides, and base(bitable); for sheet use --block-id <sheetId>!<cell>, for slides use --block-id <slide-block-type>!<xml-id>, for base use --block-id <table-id>!<record-id>!<view-id>", objType)
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but local comments only support docx, sheet, and slides; for sheet use --block-id <sheetId>!<cell>, for slides use --block-id <slide-block-type>!<xml-id>", objType)
|
||||
}
|
||||
if mode == commentModeFull && objType != "docx" && objType != "doc" {
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but comments only support doc/docx/file/sheet/slides/base(bitable)", objType)
|
||||
return resolvedCommentTarget{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "wiki resolved to %q, but comments only support doc/docx/file/sheet/slides", objType)
|
||||
}
|
||||
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Resolved wiki to %s: %s\n", objType, common.MaskToken(objToken))
|
||||
@@ -851,12 +787,6 @@ type sheetAnchor struct {
|
||||
Row int
|
||||
}
|
||||
|
||||
type baseAnchor struct {
|
||||
BlockID string
|
||||
BaseRecordID string
|
||||
BaseViewID string
|
||||
}
|
||||
|
||||
func buildCommentCreateV2Request(fileType, blockID, slideBlockType string, replyElements []map[string]interface{}, sheet *sheetAnchor) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"file_type": fileType,
|
||||
@@ -883,18 +813,6 @@ func buildCommentCreateV2Request(fileType, blockID, slideBlockType string, reply
|
||||
return body
|
||||
}
|
||||
|
||||
func buildBaseCommentCreateV2Request(replyElements []map[string]interface{}, anchor baseAnchor) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"file_type": "bitable",
|
||||
"reply_elements": replyElements,
|
||||
"anchor": map[string]interface{}{
|
||||
"block_id": anchor.BlockID,
|
||||
"base_record_id": anchor.BaseRecordID,
|
||||
"base_view_id": anchor.BaseViewID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func anchorBlockIDForDryRun(blockID string) string {
|
||||
if strings.TrimSpace(blockID) != "" {
|
||||
return strings.TrimSpace(blockID)
|
||||
@@ -902,26 +820,6 @@ func anchorBlockIDForDryRun(blockID string) string {
|
||||
return "<anchor_block_id>"
|
||||
}
|
||||
|
||||
func parseBaseCommentAnchor(runtime *common.RuntimeContext) (baseAnchor, error) {
|
||||
blockID := strings.TrimSpace(runtime.Str("block-id"))
|
||||
if blockID == "" {
|
||||
return baseAnchor{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "--block-id is required for base(bitable) record-local comments (format: <table-id>!<record-id>!<view-id>, e.g. tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R)").WithParam("--block-id")
|
||||
}
|
||||
return parseBaseBlockRef(blockID)
|
||||
}
|
||||
|
||||
func parseBaseBlockRef(blockID string) (baseAnchor, error) {
|
||||
parts := strings.Split(strings.TrimSpace(blockID), "!")
|
||||
if len(parts) != 3 || strings.TrimSpace(parts[0]) == "" || strings.TrimSpace(parts[1]) == "" || strings.TrimSpace(parts[2]) == "" {
|
||||
return baseAnchor{}, errs.NewValidationError(errs.SubtypeInvalidArgument, "base(bitable) record-local comments require --block-id in <table-id>!<record-id>!<view-id> format, e.g. tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R").WithParam("--block-id")
|
||||
}
|
||||
return baseAnchor{
|
||||
BlockID: strings.TrimSpace(parts[0]),
|
||||
BaseRecordID: strings.TrimSpace(parts[1]),
|
||||
BaseViewID: strings.TrimSpace(parts[2]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseSlidesBlockRef(blockID string) (string, string, error) {
|
||||
blockID = strings.TrimSpace(blockID)
|
||||
if blockID == "" {
|
||||
@@ -1132,53 +1030,6 @@ func executeSheetComment(runtime *common.RuntimeContext, docRef commentDocRef) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeBaseComment(runtime *common.RuntimeContext, target resolvedCommentTarget) error {
|
||||
replyElements, err := parseCommentReplyElements(runtime.Str("content"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
anchor, err := parseBaseCommentAnchor(runtime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestPath := fmt.Sprintf("/open-apis/drive/v1/files/%s/new_comments", validate.EncodePathSegment(target.FileToken))
|
||||
requestBody := buildBaseCommentCreateV2Request(replyElements, anchor)
|
||||
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Creating base(bitable) record-local comment in %s (table=%s, record=%s, view=%s)\n",
|
||||
common.MaskToken(target.FileToken), anchor.BlockID, anchor.BaseRecordID, anchor.BaseViewID)
|
||||
|
||||
data, err := runtime.CallAPITyped("POST", requestPath, nil, requestBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out := map[string]interface{}{
|
||||
"file_token": target.FileToken,
|
||||
"file_type": "bitable",
|
||||
"resolved_by": target.ResolvedBy,
|
||||
"comment_mode": "base_record",
|
||||
"base_block_id": anchor.BlockID,
|
||||
"base_record_id": anchor.BaseRecordID,
|
||||
"base_view_id": anchor.BaseViewID,
|
||||
}
|
||||
if commentID := data["comment_id"]; commentID != nil {
|
||||
out["comment_id"] = commentID
|
||||
}
|
||||
if replyID := data["reply_id"]; replyID != nil {
|
||||
out["reply_id"] = replyID
|
||||
}
|
||||
if createdAt := firstPresentValue(data, "created_at", "create_time"); createdAt != nil {
|
||||
out["created_at"] = createdAt
|
||||
}
|
||||
if target.WikiToken != "" {
|
||||
out["wiki_token"] = target.WikiToken
|
||||
}
|
||||
|
||||
runtime.Out(out, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeFileComment(runtime *common.RuntimeContext, target resolvedCommentTarget) error {
|
||||
replyElements, err := parseCommentReplyElements(runtime.Str("content"))
|
||||
if err != nil {
|
||||
|
||||
@@ -133,20 +133,6 @@ func TestParseCommentDocRef(t *testing.T) {
|
||||
wantKind: "file",
|
||||
wantToken: "fileToken",
|
||||
},
|
||||
{
|
||||
name: "raw token with type bitable",
|
||||
input: "baseToken",
|
||||
docType: "bitable",
|
||||
wantKind: "base",
|
||||
wantToken: "baseToken",
|
||||
},
|
||||
{
|
||||
name: "raw token with type base alias",
|
||||
input: "baseToken",
|
||||
docType: "base",
|
||||
wantKind: "base",
|
||||
wantToken: "baseToken",
|
||||
},
|
||||
{
|
||||
name: "raw token without type",
|
||||
input: "xxxxxx",
|
||||
@@ -170,18 +156,6 @@ func TestParseCommentDocRef(t *testing.T) {
|
||||
wantKind: "file",
|
||||
wantToken: "boxcn123",
|
||||
},
|
||||
{
|
||||
name: "base url",
|
||||
input: "https://example.larksuite.com/base/baseToken123?table=tbl1",
|
||||
wantKind: "base",
|
||||
wantToken: "baseToken123",
|
||||
},
|
||||
{
|
||||
name: "bitable url",
|
||||
input: "https://example.larksuite.com/bitable/baseToken456?table=tbl1",
|
||||
wantKind: "base",
|
||||
wantToken: "baseToken456",
|
||||
},
|
||||
{
|
||||
name: "unsupported url",
|
||||
input: "https://example.com/not-a-doc",
|
||||
@@ -752,35 +726,6 @@ func TestBuildCommentCreateV2RequestSheetOverridesBlockID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBaseCommentCreateV2Request(t *testing.T) {
|
||||
t.Parallel()
|
||||
replyElements := []map[string]interface{}{
|
||||
{"type": "text", "text": "base comment"},
|
||||
}
|
||||
got := buildBaseCommentCreateV2Request(replyElements, baseAnchor{
|
||||
BlockID: "tbl9mp6fj9kDKHQV",
|
||||
BaseRecordID: "recBIBgGmb",
|
||||
BaseViewID: "vewc46MG1R",
|
||||
})
|
||||
|
||||
if got["file_type"] != "bitable" {
|
||||
t.Fatalf("expected file_type bitable, got %#v", got["file_type"])
|
||||
}
|
||||
anchor, ok := got["anchor"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected anchor map, got %#v", got["anchor"])
|
||||
}
|
||||
if anchor["block_id"] != "tbl9mp6fj9kDKHQV" {
|
||||
t.Fatalf("expected block_id tbl9mp6fj9kDKHQV, got %#v", anchor["block_id"])
|
||||
}
|
||||
if anchor["base_record_id"] != "recBIBgGmb" {
|
||||
t.Fatalf("expected base_record_id recBIBgGmb, got %#v", anchor["base_record_id"])
|
||||
}
|
||||
if anchor["base_view_id"] != "vewc46MG1R" {
|
||||
t.Fatalf("expected base_view_id vewc46MG1R, got %#v", anchor["base_view_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// ── Sheet cell ref parsing tests ────────────────────────────────────────────
|
||||
|
||||
func TestParseSheetCellRef(t *testing.T) {
|
||||
@@ -1040,78 +985,6 @@ func TestFileCommentValidateRejectsSelectionWithEllipsis(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseCommentValidateMissingBlockID(t *testing.T) {
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/base/baseToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), "--block-id is required") {
|
||||
t.Fatalf("expected block-id required error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseCommentValidateMalformedBlockID(t *testing.T) {
|
||||
cases := []string{
|
||||
"tbl9mp6fj9kDKHQV",
|
||||
"tbl9mp6fj9kDKHQV!recBIBgGmb",
|
||||
"tbl9mp6fj9kDKHQV!!vewc46MG1R",
|
||||
}
|
||||
for _, blockID := range cases {
|
||||
t.Run(blockID, func(t *testing.T) {
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/base/baseToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--block-id", blockID,
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), "<table-id>!<record-id>!<view-id>") {
|
||||
t.Fatalf("expected block-id format error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseCommentValidateRejectsIncompatibleFlags(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "full comment",
|
||||
args: []string{"--full-comment"},
|
||||
wantErr: "--full-comment is not applicable for base(bitable) comments",
|
||||
},
|
||||
{
|
||||
name: "selection",
|
||||
args: []string{"--selection-with-ellipsis", "some text"},
|
||||
wantErr: "--selection-with-ellipsis is not applicable for base(bitable) comments",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
args := []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/base/baseToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--block-id", "tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R",
|
||||
"--as", "user",
|
||||
}
|
||||
args = append(args, tc.args...)
|
||||
err := mountAndRunDrive(t, DriveAddComment, args, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
|
||||
t.Fatalf("expected %q error, got: %v", tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── Slides comment execute tests ────────────────────────────────────────────
|
||||
|
||||
func TestSlidesCommentExecuteSuccess(t *testing.T) {
|
||||
@@ -1322,87 +1195,6 @@ func TestSheetCommentViaWikiMissingBlockID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseCommentExecuteSuccess(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
createStub := &httpmock.Stub{
|
||||
Method: "POST", URL: "/open-apis/drive/v1/files/baseToken/new_comments",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"comment_id": "baseComment123",
|
||||
"reply_id": "baseReply123",
|
||||
"created_at": 1700000000,
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(createStub)
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/base/baseToken",
|
||||
"--content", `[{"type":"text","text":"请看这条记录"}]`,
|
||||
"--block-id", "tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R",
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var requestBody map[string]interface{}
|
||||
if err := json.Unmarshal(createStub.CapturedBody, &requestBody); err != nil {
|
||||
t.Fatalf("failed to decode captured body: %v\nbody:\n%s", err, string(createStub.CapturedBody))
|
||||
}
|
||||
if got := mustStringField(t, requestBody, "file_type", "request.file_type"); got != "bitable" {
|
||||
t.Fatalf("request file_type = %q, want bitable", got)
|
||||
}
|
||||
anchor := mustMapValue(t, requestBody["anchor"], "request.anchor")
|
||||
if got := mustStringField(t, anchor, "block_id", "request.anchor.block_id"); got != "tbl9mp6fj9kDKHQV" {
|
||||
t.Fatalf("request block_id = %q, want tbl9mp6fj9kDKHQV", got)
|
||||
}
|
||||
if got := mustStringField(t, anchor, "base_record_id", "request.anchor.base_record_id"); got != "recBIBgGmb" {
|
||||
t.Fatalf("request base_record_id = %q, want recBIBgGmb", got)
|
||||
}
|
||||
if got := mustStringField(t, anchor, "base_view_id", "request.anchor.base_view_id"); got != "vewc46MG1R" {
|
||||
t.Fatalf("request base_view_id = %q, want vewc46MG1R", got)
|
||||
}
|
||||
|
||||
out := decodeJSONMap(t, stdout.String())
|
||||
data := mustMapValue(t, out["data"], "data")
|
||||
if got := mustStringField(t, data, "file_type", "data.file_type"); got != "bitable" {
|
||||
t.Fatalf("stdout file_type = %q, want bitable\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
if got := mustStringField(t, data, "comment_mode", "data.comment_mode"); got != "base_record" {
|
||||
t.Fatalf("stdout comment_mode = %q, want base_record\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
if got := mustStringField(t, data, "reply_id", "data.reply_id"); got != "baseReply123" {
|
||||
t.Fatalf("stdout reply_id = %q, want baseReply123\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseCommentExecuteBareToken(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST", URL: "/open-apis/drive/v1/files/baseBareToken/new_comments",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{"comment_id": "baseBareComment"},
|
||||
},
|
||||
})
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "baseBareToken",
|
||||
"--type", "bitable",
|
||||
"--content", `[{"type":"text","text":"ok"}]`,
|
||||
"--block-id", "tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R",
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(stdout.String(), "baseBareComment") {
|
||||
t.Fatalf("stdout missing comment_id: %s", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCommentExecuteSuccess(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
@@ -1641,40 +1433,6 @@ func TestDryRunSlidesDirectURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDryRunBaseDirectURL(t *testing.T) {
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/base/baseToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--block-id", "tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R",
|
||||
"--dry-run", "--as", "user",
|
||||
}, f, stdout)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(stdout.String(), "record-local comment") {
|
||||
t.Fatalf("dry-run output missing record-local comment: %s", stdout.String())
|
||||
}
|
||||
out := decodeJSONMap(t, stdout.String())
|
||||
api := mustSliceValue(t, out["api"], "api")
|
||||
call := mustMapValue(t, api[0], "api[0]")
|
||||
body := mustMapValue(t, call["body"], "api[0].body")
|
||||
anchor := mustMapValue(t, body["anchor"], "api[0].body.anchor")
|
||||
if got := mustStringField(t, body, "file_type", "api[0].body.file_type"); got != "bitable" {
|
||||
t.Fatalf("dry-run body.file_type = %q, want bitable\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
if got := mustStringField(t, anchor, "block_id", "api[0].body.anchor.block_id"); got != "tbl9mp6fj9kDKHQV" {
|
||||
t.Fatalf("dry-run body.anchor.block_id = %q, want tbl9mp6fj9kDKHQV\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
if got := mustStringField(t, anchor, "base_record_id", "api[0].body.anchor.base_record_id"); got != "recBIBgGmb" {
|
||||
t.Fatalf("dry-run body.anchor.base_record_id = %q, want recBIBgGmb\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
if got := mustStringField(t, anchor, "base_view_id", "api[0].body.anchor.base_view_id"); got != "vewc46MG1R" {
|
||||
t.Fatalf("dry-run body.anchor.base_view_id = %q, want vewc46MG1R\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDryRunWikiResolvesToSlides(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
@@ -1878,92 +1636,25 @@ func TestResolveWikiToDocxFullComment(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWikiToBaseComment(t *testing.T) {
|
||||
for _, objType := range []string{"bitable", "base"} {
|
||||
t.Run(objType, func(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET", URL: "/open-apis/wiki/v2/spaces/get_node",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"node": map[string]interface{}{"obj_type": objType, "obj_token": "bitToken"},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST", URL: "/open-apis/drive/v1/files/bitToken/new_comments",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{"comment_id": "wikiBaseComment", "reply_id": "wikiBaseReply"},
|
||||
},
|
||||
})
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/wiki/wikiToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--block-id", "tbl9mp6fj9kDKHQV!recBIBgGmb!vewc46MG1R",
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(stdout.String(), "wikiBaseComment") {
|
||||
t.Fatalf("stdout missing comment_id: %s", stdout.String())
|
||||
}
|
||||
out := decodeJSONMap(t, stdout.String())
|
||||
data := mustMapValue(t, out["data"], "data")
|
||||
if got := mustStringField(t, data, "file_type", "data.file_type"); got != "bitable" {
|
||||
t.Fatalf("stdout file_type = %q, want bitable\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
if got := mustStringField(t, data, "wiki_token", "data.wiki_token"); got != "wikiToken" {
|
||||
t.Fatalf("stdout wiki_token = %q, want wikiToken\nstdout:\n%s", got, stdout.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWikiToBaseRejectsIncompatibleFlags(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "full comment",
|
||||
args: []string{"--full-comment"},
|
||||
wantErr: "--full-comment is not applicable for base(bitable) comments",
|
||||
func TestResolveWikiToUnsupportedType(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET", URL: "/open-apis/wiki/v2/spaces/get_node",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"node": map[string]interface{}{"obj_type": "bitable", "obj_token": "bitToken"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "selection",
|
||||
args: []string{"--selection-with-ellipsis", "some text"},
|
||||
wantErr: "--selection-with-ellipsis is not applicable for base(bitable) comments",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET", URL: "/open-apis/wiki/v2/spaces/get_node",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"node": map[string]interface{}{"obj_type": "bitable", "obj_token": "bitToken"},
|
||||
},
|
||||
},
|
||||
})
|
||||
args := []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/wiki/wikiToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--as", "user",
|
||||
}
|
||||
args = append(args, tc.args...)
|
||||
err := mountAndRunDrive(t, DriveAddComment, args, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
|
||||
t.Fatalf("expected %q error, got: %v", tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
})
|
||||
err := mountAndRunDrive(t, DriveAddComment, []string{
|
||||
"+add-comment",
|
||||
"--doc", "https://example.larksuite.com/wiki/wikiToken",
|
||||
"--content", `[{"type":"text","text":"test"}]`,
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), "only support doc/docx/file/sheet/slides") {
|
||||
t.Fatalf("expected unsupported type error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2044,7 +1735,7 @@ func TestDocOldFormatLocalCommentRejected(t *testing.T) {
|
||||
"--block-id", "blk_123",
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), "only support docx, sheet, slides, and base(bitable)") {
|
||||
if err == nil || !strings.Contains(err.Error(), "only support docx, sheet, and slides") {
|
||||
t.Fatalf("expected local comment rejection for old doc, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,9 +149,6 @@ var DriveSearch = common.Shortcut{
|
||||
"page_token": data["page_token"],
|
||||
"results": normalizedItems,
|
||||
}
|
||||
if notice, _ := data["notice"].(string); notice != "" {
|
||||
resultData["notice"] = notice
|
||||
}
|
||||
|
||||
runtime.OutFormat(resultData, &output.Meta{Count: len(normalizedItems)}, func(w io.Writer) {
|
||||
renderDriveSearchTable(w, data, normalizedItems)
|
||||
|
||||
@@ -14,49 +14,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/errclass"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
)
|
||||
|
||||
// TestDriveSearchExecutePassesThroughNotice verifies drive +search preserves notices.
|
||||
func TestDriveSearchExecutePassesThroughNotice(t *testing.T) {
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/search/v2/doc_wiki/search",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"notice": notice,
|
||||
"res_units": []interface{}{},
|
||||
"total": 0,
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err := mountAndRunDrive(t, DriveSearch, []string{"+search", "--query", "incident", "--format", "json", "--as", "user"}, f, stdout); err != nil {
|
||||
t.Fatalf("DriveSearch.Execute() error = %v", err)
|
||||
}
|
||||
reg.Verify(t)
|
||||
|
||||
var env map[string]interface{}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &env); err != nil {
|
||||
t.Fatalf("json.Unmarshal(stdout) error = %v\nstdout=%s", err, stdout.String())
|
||||
}
|
||||
data, _ := env["data"].(map[string]interface{})
|
||||
if got, _ := data["notice"].(string); got != notice {
|
||||
t.Fatalf("data.notice = %q, want %q; data=%#v", got, notice, data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClampOpenedTimeWindow covers opened-time clamping and slice notices.
|
||||
// TestClampOpenedTimeWindow covers the 3-month / 1-year boundary logic that
|
||||
// narrows --opened-since / --opened-until and generates the multi-slice notice.
|
||||
func TestClampOpenedTimeWindow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -26,7 +26,9 @@ func mustMarshalDryRun(t *testing.T, v interface{}) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// newTestRuntimeContext builds a RuntimeContext with string and bool test flags.
|
||||
// newTestRuntimeContext builds a *common.RuntimeContext backed by a cobra
|
||||
// command whose flags are populated from the provided string and bool maps,
|
||||
// for unit-testing shortcut bodies, validators, and dry-run shapes.
|
||||
func newTestRuntimeContext(t *testing.T, stringFlags map[string]string, boolFlags map[string]bool) *common.RuntimeContext {
|
||||
t.Helper()
|
||||
|
||||
@@ -57,38 +59,9 @@ func newTestRuntimeContext(t *testing.T, stringFlags map[string]string, boolFlag
|
||||
return &common.RuntimeContext{Cmd: cmd}
|
||||
}
|
||||
|
||||
// newChatSearchTestRuntimeContext builds a chat-search RuntimeContext with typed flags.
|
||||
func newChatSearchTestRuntimeContext(t *testing.T, stringFlags map[string]string, boolFlags map[string]bool) *common.RuntimeContext {
|
||||
t.Helper()
|
||||
|
||||
cmd := &cobra.Command{Use: "test"}
|
||||
cmd.Flags().Int("page-size", 20, "")
|
||||
for name := range stringFlags {
|
||||
if name == "page-size" {
|
||||
continue
|
||||
}
|
||||
cmd.Flags().String(name, "", "")
|
||||
}
|
||||
for name := range boolFlags {
|
||||
cmd.Flags().Bool(name, false, "")
|
||||
}
|
||||
if err := cmd.ParseFlags(nil); err != nil {
|
||||
t.Fatalf("ParseFlags() error = %v", err)
|
||||
}
|
||||
for name, val := range stringFlags {
|
||||
if err := cmd.Flags().Set(name, val); err != nil {
|
||||
t.Fatalf("Flags().Set(%q) error = %v", name, err)
|
||||
}
|
||||
}
|
||||
for name, val := range boolFlags {
|
||||
if err := cmd.Flags().Set(name, map[bool]string{true: "true", false: "false"}[val]); err != nil {
|
||||
t.Fatalf("Flags().Set(%q) error = %v", name, err)
|
||||
}
|
||||
}
|
||||
return &common.RuntimeContext{Cmd: cmd}
|
||||
}
|
||||
|
||||
// newMessagesSearchTestRuntimeContext builds a messages-search RuntimeContext.
|
||||
// newMessagesSearchTestRuntimeContext is the messages-search variant of
|
||||
// newTestRuntimeContext: registers the search-specific --page-size flag
|
||||
// before applying caller-provided values.
|
||||
func newMessagesSearchTestRuntimeContext(t *testing.T, stringFlags map[string]string, boolFlags map[string]bool) *common.RuntimeContext {
|
||||
t.Helper()
|
||||
|
||||
@@ -258,7 +231,6 @@ func TestIsMediaKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShortcutValidateBranches covers direct shortcut validation branches.
|
||||
func TestShortcutValidateBranches(t *testing.T) {
|
||||
|
||||
t.Run("ImChatCreate valid", func(t *testing.T) {
|
||||
@@ -325,7 +297,7 @@ func TestShortcutValidateBranches(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ImChatSearch invalid page size", func(t *testing.T) {
|
||||
runtime := newChatSearchTestRuntimeContext(t, map[string]string{
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"query": "ok",
|
||||
"page-size": "0",
|
||||
}, nil)
|
||||
@@ -335,13 +307,12 @@ func TestShortcutValidateBranches(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImChatSearch allows long query for server-side notice", func(t *testing.T) {
|
||||
runtime := newChatSearchTestRuntimeContext(t, map[string]string{
|
||||
"query": strings.Repeat("q", 81),
|
||||
"page-size": "20",
|
||||
t.Run("ImChatSearch query too long", func(t *testing.T) {
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"query": strings.Repeat("q", 65),
|
||||
}, nil)
|
||||
err := ImChatSearch.Validate(context.Background(), runtime)
|
||||
if err != nil {
|
||||
if err == nil || !strings.Contains(err.Error(), "--query exceeds the maximum of 64 characters") {
|
||||
t.Fatalf("ImChatSearch.Validate() error = %v", err)
|
||||
}
|
||||
})
|
||||
@@ -469,29 +440,6 @@ func TestShortcutValidateBranches(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImMessagesSend audio rejects non-opus local file", func(t *testing.T) {
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"chat-id": "oc_123",
|
||||
"audio": "./voice.mp3",
|
||||
}, nil)
|
||||
err := ImMessagesSend.Validate(context.Background(), runtime)
|
||||
if err == nil || !strings.Contains(err.Error(), "--audio supports only Opus audio files") {
|
||||
t.Fatalf("ImMessagesSend.Validate() error = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImMessagesSend audio accepts opus and ogg local files", func(t *testing.T) {
|
||||
for _, audio := range []string{"./voice.opus", "./voice.ogg"} {
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"chat-id": "oc_123",
|
||||
"audio": audio,
|
||||
}, nil)
|
||||
if err := ImMessagesSend.Validate(context.Background(), runtime); err != nil {
|
||||
t.Fatalf("ImMessagesSend.Validate(%q) unexpected error = %v", audio, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImMessagesSend conflicting explicit msg-type", func(t *testing.T) {
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"chat-id": "oc_123",
|
||||
@@ -515,17 +463,6 @@ func TestShortcutValidateBranches(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImMessagesReply audio rejects non-opus local file", func(t *testing.T) {
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"message-id": "om_123",
|
||||
"audio": "./voice.mp3",
|
||||
}, nil)
|
||||
err := ImMessagesReply.Validate(context.Background(), runtime)
|
||||
if err == nil || !strings.Contains(err.Error(), "--audio supports only Opus audio files") {
|
||||
t.Fatalf("ImMessagesReply.Validate() error = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImThreadsMessagesList invalid thread", func(t *testing.T) {
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"thread": "bad_thread",
|
||||
@@ -670,7 +607,6 @@ func TestShortcutValidateBranches(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestMessagesSearchPaginationConfig verifies page-all and page-limit behavior.
|
||||
func TestMessagesSearchPaginationConfig(t *testing.T) {
|
||||
t.Run("default single page", func(t *testing.T) {
|
||||
runtime := newMessagesSearchTestRuntimeContext(t, nil, nil)
|
||||
@@ -714,7 +650,8 @@ func TestMessagesSearchPaginationConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestShortcutDryRunShapes verifies shortcut dry-run API paths and payloads.
|
||||
// TestShortcutDryRunShapes verifies that each shortcut's DryRun function
|
||||
// produces the expected API path, query parameters, and request body.
|
||||
func TestShortcutDryRunShapes(t *testing.T) {
|
||||
t.Run("ImChatCreate dry run includes params and body", func(t *testing.T) {
|
||||
cmd := &cobra.Command{Use: "test"}
|
||||
@@ -737,19 +674,19 @@ func TestShortcutDryRunShapes(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ImChatSearch dry run includes built params", func(t *testing.T) {
|
||||
runtime := newChatSearchTestRuntimeContext(t, map[string]string{
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"query": "team-alpha",
|
||||
"page-size": "50",
|
||||
"page-token": "next_page",
|
||||
}, nil)
|
||||
got := mustMarshalDryRun(t, ImChatSearch.DryRun(context.Background(), runtime))
|
||||
if !strings.Contains(got, `"/open-apis/im/v2/chats/search"`) || !strings.Contains(got, `"page_size":50`) || !strings.Contains(got, `"query":"\"team-alpha\""`) {
|
||||
if !strings.Contains(got, `"/open-apis/im/v2/chats/search"`) || !strings.Contains(got, `"page_size":20`) || !strings.Contains(got, `"query":"\"team-alpha\""`) {
|
||||
t.Fatalf("ImChatSearch.DryRun() = %s", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImChatSearch dry run still works with --exclude-muted set", func(t *testing.T) {
|
||||
runtime := newChatSearchTestRuntimeContext(t, map[string]string{
|
||||
runtime := newTestRuntimeContext(t, map[string]string{
|
||||
"query": "team-alpha",
|
||||
}, map[string]bool{
|
||||
"exclude-muted": true,
|
||||
|
||||
@@ -1048,42 +1048,6 @@ func detectIMFileType(filePath string) string {
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
audioMessageInputDesc = "audio file key (file_xxx), URL, or cwd-relative local path for a voice message (absolute paths and .. are rejected); local paths and URLs must be Opus (.opus or Ogg Opus .ogg). For mp3/wav, convert to .opus first, or use --file to send as an attachment"
|
||||
audioMessageHint = "Convert non-Opus audio to .opus and use --audio for a voice message, for example: ffmpeg -i input.mp3 -acodec libopus -ac 1 -ar 16000 output.opus. To send the original mp3/wav as an attachment, use --file instead."
|
||||
)
|
||||
|
||||
func validateAudioMessageInput(flagName, value string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || isMediaKey(value) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ext := audioInputExt(value)
|
||||
if ext == "" {
|
||||
return nil
|
||||
}
|
||||
if ext == ".opus" || ext == ".ogg" {
|
||||
return nil
|
||||
}
|
||||
return errs.NewValidationError(
|
||||
errs.SubtypeInvalidArgument,
|
||||
"%s supports only Opus audio files for audio messages, such as .opus files or Ogg Opus (.ogg) files",
|
||||
flagName,
|
||||
).WithParam(flagName).WithHint("%s", audioMessageHint)
|
||||
}
|
||||
|
||||
func audioInputExt(value string) string {
|
||||
if isURL(value) {
|
||||
parsed, err := url.Parse(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(path.Ext(parsed.Path))
|
||||
}
|
||||
return strings.ToLower(filepath.Ext(value))
|
||||
}
|
||||
|
||||
const maxImageUploadSize = 5 * 1024 * 1024 // 5MB — Lark API limit for images
|
||||
const maxFileUploadSize = 100 * 1024 * 1024 // 100MB — Lark API limit for files
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
@@ -51,56 +50,6 @@ func TestDetectIMFileType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAudioMessageInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "empty", value: ""},
|
||||
{name: "existing file key", value: "file_abc"},
|
||||
{name: "opus file", value: "./voice.opus"},
|
||||
{name: "ogg opus file", value: "./voice.ogg"},
|
||||
{name: "uppercase opus", value: "./VOICE.OPUS"},
|
||||
{name: "mp3 local file", value: "./voice.mp3", wantErr: true},
|
||||
{name: "wav local file", value: "./voice.wav", wantErr: true},
|
||||
{name: "extensionless local path", value: "./voice"},
|
||||
{name: "opus url", value: "https://example.com/voice.opus?download=1"},
|
||||
{name: "ogg url", value: "https://example.com/voice.ogg?download=1"},
|
||||
{name: "mp3 url", value: "https://example.com/voice.mp3?download=1", wantErr: true},
|
||||
{name: "extensionless url", value: "https://example.com/download?id=1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateAudioMessageInput("--audio", tt.value)
|
||||
if tt.wantErr {
|
||||
if err == nil || !strings.Contains(err.Error(), "--audio supports only Opus audio files") {
|
||||
t.Fatalf("validateAudioMessageInput(%q) error = %v", tt.value, err)
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("validateAudioMessageInput(%q) error is not typed: %v", tt.value, err)
|
||||
}
|
||||
if p.Category != errs.CategoryValidation || p.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Fatalf("ProblemOf(%q) = category %q subtype %q", tt.value, p.Category, p.Subtype)
|
||||
}
|
||||
validationErr, ok := err.(*errs.ValidationError)
|
||||
if !ok || validationErr.Param != "--audio" {
|
||||
t.Fatalf("validateAudioMessageInput(%q) param = %q, want --audio", tt.value, validationErr.Param)
|
||||
}
|
||||
if !strings.Contains(p.Hint, "use --file") || !strings.Contains(p.Hint, "ffmpeg") {
|
||||
t.Fatalf("validateAudioMessageInput(%q) hint = %q, want --file and ffmpeg guidance", tt.value, p.Hint)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("validateAudioMessageInput(%q) unexpected error = %v", tt.value, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitCSV covers the shared helper that replaced the three identical functions
|
||||
func TestSplitCSV(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -29,7 +29,7 @@ var ImChatSearch = common.Shortcut{
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
HasFormat: true,
|
||||
Flags: []common.Flag{
|
||||
{Name: "query", Desc: "search keyword (server may return data.notice for overly long input)"},
|
||||
{Name: "query", Desc: "search keyword (max 64 chars)"},
|
||||
{Name: "search-types", Desc: "chat types, comma-separated (private, external, public_joined, public_not_joined)"},
|
||||
{Name: "chat-modes", Desc: "filter by chat mode, comma-separated (group, topic)"},
|
||||
{Name: "member-ids", Desc: "filter by member open_ids, comma-separated"},
|
||||
@@ -50,7 +50,7 @@ var ImChatSearch = common.Shortcut{
|
||||
Params(params).
|
||||
Body(body)
|
||||
},
|
||||
// Validate enforces query/member-ids presence, search-types
|
||||
// Validate enforces query/member-ids presence, --query rune cap, search-types
|
||||
// enum, --member-ids count and format, and --page-size bounds.
|
||||
Validate: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
query := runtime.Str("query")
|
||||
@@ -58,6 +58,9 @@ var ImChatSearch = common.Shortcut{
|
||||
if query == "" && memberIDs == "" {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--query and --member-ids cannot both be empty; provide at least one (e.g. --query \"team-name\" or --member-ids \"ou_xxx\")")
|
||||
}
|
||||
if query != "" && len([]rune(query)) > 64 {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--query exceeds the maximum of 64 characters (got %d)", len([]rune(query))).WithParam("--query")
|
||||
}
|
||||
if st := runtime.Str("search-types"); st != "" {
|
||||
allowed := map[string]struct{}{
|
||||
"private": {},
|
||||
@@ -148,9 +151,6 @@ var ImChatSearch = common.Shortcut{
|
||||
"has_more": hasMore,
|
||||
"page_token": pageToken,
|
||||
}
|
||||
if notice, _ := resData["notice"].(string); notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
if mfOut.Meta.Applied != "" {
|
||||
outData["filter"] = MuteFilterMetaToMap(mfOut.Meta)
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ var ImMessagesReply = common.Shortcut{
|
||||
{Name: "file", Desc: "file key (file_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected)"},
|
||||
{Name: "video", Desc: "video file key (file_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected); must be used together with --video-cover"},
|
||||
{Name: "video-cover", Desc: "video cover image key (img_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected); required when using --video"},
|
||||
{Name: "audio", Desc: audioMessageInputDesc},
|
||||
{Name: "audio", Desc: "audio file key (file_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected)"},
|
||||
{Name: "reply-in-thread", Type: "bool", Desc: "reply in thread (message appears in thread stream instead of main chat)"},
|
||||
{Name: "idempotency-key", Desc: "idempotency key (prevents duplicate sends)"},
|
||||
},
|
||||
@@ -100,9 +100,6 @@ var ImMessagesReply = common.Shortcut{
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := validateAudioMessageInput("--audio", audioKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if messageId == "" {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--message-id is required (om_xxx)").WithParam("--message-id")
|
||||
|
||||
@@ -91,7 +91,7 @@ var ImMessagesSearch = common.Shortcut{
|
||||
return err
|
||||
}
|
||||
|
||||
rawItems, hasMore, nextPageToken, truncatedByLimit, pageLimit, notice, err := searchMessages(runtime, req)
|
||||
rawItems, hasMore, nextPageToken, truncatedByLimit, pageLimit, err := searchMessages(runtime, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -103,9 +103,6 @@ var ImMessagesSearch = common.Shortcut{
|
||||
"has_more": hasMore,
|
||||
"page_token": nextPageToken,
|
||||
}
|
||||
if notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
runtime.OutFormat(outData, nil, func(w io.Writer) {
|
||||
fmt.Fprintln(w, "No matching messages found.")
|
||||
})
|
||||
@@ -134,9 +131,6 @@ var ImMessagesSearch = common.Shortcut{
|
||||
"page_token": nextPageToken,
|
||||
"note": "failed to fetch message details, returning ID list only",
|
||||
}
|
||||
if notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
runtime.OutFormat(outData, nil, func(w io.Writer) {
|
||||
fmt.Fprintf(w, "Found %d messages (failed to fetch details):\n", len(messageIds))
|
||||
for _, id := range messageIds {
|
||||
@@ -212,9 +206,6 @@ var ImMessagesSearch = common.Shortcut{
|
||||
"has_more": hasMore,
|
||||
"page_token": nextPageToken,
|
||||
}
|
||||
if notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
runtime.OutFormat(outData, nil, func(w io.Writer) {
|
||||
if len(enriched) == 0 {
|
||||
fmt.Fprintln(w, "No matching messages found.")
|
||||
@@ -386,7 +377,6 @@ func buildMessagesSearchRequest(runtime *common.RuntimeContext) (*messagesSearch
|
||||
}, nil
|
||||
}
|
||||
|
||||
// messagesSearchPaginationConfig derives auto-pagination mode and page limit.
|
||||
func messagesSearchPaginationConfig(runtime *common.RuntimeContext) (autoPaginate bool, pageLimit int) {
|
||||
autoPaginate = runtime.Bool("page-all")
|
||||
if runtime.Cmd != nil && runtime.Cmd.Flags().Changed("page-limit") {
|
||||
@@ -402,8 +392,7 @@ func messagesSearchPaginationConfig(runtime *common.RuntimeContext) (autoPaginat
|
||||
return autoPaginate, pageLimit
|
||||
}
|
||||
|
||||
// searchMessages fetches message search pages and returns the first server notice.
|
||||
func searchMessages(runtime *common.RuntimeContext, req *messagesSearchRequest) ([]interface{}, bool, string, bool, int, string, error) {
|
||||
func searchMessages(runtime *common.RuntimeContext, req *messagesSearchRequest) ([]interface{}, bool, string, bool, int, error) {
|
||||
autoPaginate, pageLimit := messagesSearchPaginationConfig(runtime)
|
||||
pageToken := ""
|
||||
if tokens := req.params["page_token"]; len(tokens) > 0 {
|
||||
@@ -421,7 +410,6 @@ func searchMessages(runtime *common.RuntimeContext, req *messagesSearchRequest)
|
||||
lastPageToken string
|
||||
truncatedByLimit bool
|
||||
pageCount int
|
||||
notice string
|
||||
)
|
||||
|
||||
for {
|
||||
@@ -435,12 +423,9 @@ func searchMessages(runtime *common.RuntimeContext, req *messagesSearchRequest)
|
||||
|
||||
searchData, err := runtime.DoAPIJSONTyped(http.MethodPost, "/open-apis/im/v1/messages/search", params, req.body)
|
||||
if err != nil {
|
||||
return nil, false, "", false, pageLimit, "", err
|
||||
return nil, false, "", false, pageLimit, err
|
||||
}
|
||||
|
||||
if notice == "" {
|
||||
notice, _ = searchData["notice"].(string)
|
||||
}
|
||||
items, _ := searchData["items"].([]interface{})
|
||||
allItems = append(allItems, items...)
|
||||
lastHasMore, lastPageToken = common.PaginationMeta(searchData)
|
||||
@@ -456,10 +441,9 @@ func searchMessages(runtime *common.RuntimeContext, req *messagesSearchRequest)
|
||||
pageToken = lastPageToken
|
||||
}
|
||||
|
||||
return allItems, lastHasMore, lastPageToken, truncatedByLimit, pageLimit, notice, nil
|
||||
return allItems, lastHasMore, lastPageToken, truncatedByLimit, pageLimit, nil
|
||||
}
|
||||
|
||||
// batchMGetMessages fetches message details in API-sized batches.
|
||||
func batchMGetMessages(runtime *common.RuntimeContext, messageIds []string) ([]interface{}, error) {
|
||||
var items []interface{}
|
||||
for _, batch := range chunkStrings(messageIds, messagesSearchMGetBatchSize) {
|
||||
@@ -473,7 +457,6 @@ func batchMGetMessages(runtime *common.RuntimeContext, messageIds []string) ([]i
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// batchQueryChatContexts fetches chat metadata best-effort for message rows.
|
||||
func batchQueryChatContexts(runtime *common.RuntimeContext, chatIds []string) map[string]map[string]interface{} {
|
||||
chatContexts := map[string]map[string]interface{}{}
|
||||
// Best-effort: a failed chunk only loses its own entries.
|
||||
@@ -483,7 +466,6 @@ func batchQueryChatContexts(runtime *common.RuntimeContext, chatIds []string) ma
|
||||
return chatContexts
|
||||
}
|
||||
|
||||
// chunkStrings splits a string slice into fixed-size batches.
|
||||
func chunkStrings(items []string, chunkSize int) [][]string {
|
||||
if len(items) == 0 || chunkSize <= 0 {
|
||||
return nil
|
||||
|
||||
@@ -37,7 +37,7 @@ var ImMessagesSend = common.Shortcut{
|
||||
{Name: "file", Desc: "file key (file_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected)"},
|
||||
{Name: "video", Desc: "video file key (file_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected); must be used together with --video-cover"},
|
||||
{Name: "video-cover", Desc: "video cover image key (img_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected); required when using --video"},
|
||||
{Name: "audio", Desc: audioMessageInputDesc},
|
||||
{Name: "audio", Desc: "audio file key (file_xxx), URL, or cwd-relative local path (absolute paths and .. are rejected)"},
|
||||
},
|
||||
DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
chatFlag := runtime.Str("chat-id")
|
||||
@@ -112,9 +112,6 @@ var ImMessagesSend = common.Shortcut{
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := validateAudioMessageInput("--audio", audioKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := common.ExactlyOneTyped(runtime, "chat-id", "user-id"); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package im
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TestImChatSearchExecutePassesThroughNotice verifies chat search notice output.
|
||||
func TestImChatSearchExecutePassesThroughNotice(t *testing.T) {
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
longQuery := strings.Repeat("q", 81)
|
||||
|
||||
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if !strings.Contains(req.URL.Path, "/open-apis/im/v2/chats/search") {
|
||||
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||
return nil, fmt.Errorf("decode request body: %w", err)
|
||||
}
|
||||
if got, _ := body["query"].(string); got != longQuery {
|
||||
return nil, fmt.Errorf("body.query = %q, want %q", got, longQuery)
|
||||
}
|
||||
return shortcutJSONResponse(200, map[string]interface{}{
|
||||
"code": 0,
|
||||
"data": map[string]interface{}{
|
||||
"notice": notice,
|
||||
"items": []interface{}{},
|
||||
"total": 0,
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
},
|
||||
}), nil
|
||||
}))
|
||||
runtime.Cmd = newChatSearchNoticeTestCommand(t, longQuery)
|
||||
runtime.Format = "json"
|
||||
|
||||
if err := ImChatSearch.Execute(context.Background(), runtime); err != nil {
|
||||
t.Fatalf("ImChatSearch.Execute() error = %v", err)
|
||||
}
|
||||
|
||||
data := decodeShortcutData(t, runtime)
|
||||
if got, _ := data["notice"].(string); got != notice {
|
||||
t.Fatalf("data.notice = %q, want %q; data=%#v", got, notice, data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestImMessagesSearchExecutePassesThroughNotice verifies message search notice output.
|
||||
func TestImMessagesSearchExecutePassesThroughNotice(t *testing.T) {
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
|
||||
runtime := newMessagesSearchRuntime(t, map[string]string{
|
||||
"query": "incident",
|
||||
}, nil, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if !strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/search") {
|
||||
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
|
||||
}
|
||||
return shortcutJSONResponse(200, map[string]interface{}{
|
||||
"code": 0,
|
||||
"data": map[string]interface{}{
|
||||
"notice": notice,
|
||||
"items": []interface{}{},
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
},
|
||||
}), nil
|
||||
}))
|
||||
runtime.Format = "json"
|
||||
|
||||
if err := ImMessagesSearch.Execute(context.Background(), runtime); err != nil {
|
||||
t.Fatalf("ImMessagesSearch.Execute() error = %v", err)
|
||||
}
|
||||
|
||||
data := decodeShortcutData(t, runtime)
|
||||
if got, _ := data["notice"].(string); got != notice {
|
||||
t.Fatalf("data.notice = %q, want %q; data=%#v", got, notice, data)
|
||||
}
|
||||
}
|
||||
|
||||
// newChatSearchNoticeTestCommand builds a typed chat-search command for notice tests.
|
||||
func newChatSearchNoticeTestCommand(t *testing.T, query string) *cobra.Command {
|
||||
t.Helper()
|
||||
|
||||
cmd := &cobra.Command{Use: "test"}
|
||||
for _, name := range []string{"query", "search-types", "member-ids", "sort-by", "page-token"} {
|
||||
cmd.Flags().String(name, "", "")
|
||||
}
|
||||
for _, name := range []string{"is-manager", "disable-search-by-user", "exclude-muted"} {
|
||||
cmd.Flags().Bool(name, false, "")
|
||||
}
|
||||
cmd.Flags().Int("page-size", 20, "")
|
||||
if err := cmd.ParseFlags(nil); err != nil {
|
||||
t.Fatalf("ParseFlags() error = %v", err)
|
||||
}
|
||||
if err := cmd.Flags().Set("query", query); err != nil {
|
||||
t.Fatalf("Flags().Set(query) error = %v", err)
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// decodeShortcutData extracts the JSON envelope data object from shortcut output.
|
||||
func decodeShortcutData(t *testing.T, runtime *common.RuntimeContext) map[string]interface{} {
|
||||
t.Helper()
|
||||
|
||||
out, ok := runtime.Factory.IOStreams.Out.(*bytes.Buffer)
|
||||
if !ok {
|
||||
t.Fatalf("stdout buffer has type %T", runtime.Factory.IOStreams.Out)
|
||||
}
|
||||
var env map[string]interface{}
|
||||
if err := json.Unmarshal(out.Bytes(), &env); err != nil {
|
||||
t.Fatalf("json.Unmarshal(stdout) error = %v\nstdout=%s", err, out.String())
|
||||
}
|
||||
data, ok := env["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("envelope data missing or wrong type: %#v", env)
|
||||
}
|
||||
return data
|
||||
}
|
||||
@@ -159,7 +159,6 @@ var MailTriage = common.Shortcut{
|
||||
var messages []map[string]interface{}
|
||||
var hasMore bool
|
||||
var nextPageToken string
|
||||
var notice string
|
||||
|
||||
useSearch, err := resolveTriagePath(parsed, query, filter)
|
||||
if err != nil {
|
||||
@@ -190,9 +189,6 @@ var MailTriage = common.Shortcut{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if notice == "" {
|
||||
notice, _ = searchData["notice"].(string)
|
||||
}
|
||||
pageMessages := buildTriageMessagesFromSearchItems(searchData["items"])
|
||||
messages = append(messages, pageMessages...)
|
||||
pageHasMore, _ := searchData["has_more"].(bool)
|
||||
@@ -286,14 +282,8 @@ var MailTriage = common.Shortcut{
|
||||
"has_more": hasMore,
|
||||
"page_token": nextPageToken,
|
||||
}
|
||||
if notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
output.PrintJson(runtime.IO().Out, outData)
|
||||
default: // "table"
|
||||
if notice != "" {
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "notice: %s\n", notice)
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
fmt.Fprintln(runtime.IO().ErrOut, "No messages found.")
|
||||
return nil
|
||||
@@ -770,7 +760,13 @@ func buildListParams(runtime *common.RuntimeContext, mailboxID string, f triageF
|
||||
params["folder_id"] = folderIDFromFilter
|
||||
}
|
||||
} else {
|
||||
params["folder_id"] = folderIDFromFilter
|
||||
resolved, err := resolveFolderID(runtime, mailboxID, folderIDFromFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resolved != "" {
|
||||
params["folder_id"] = resolved
|
||||
}
|
||||
}
|
||||
} else if folderFromFilter != "" {
|
||||
if dryRun {
|
||||
@@ -780,7 +776,13 @@ func buildListParams(runtime *common.RuntimeContext, mailboxID string, f triageF
|
||||
params["folder_id"] = folderFromFilter
|
||||
}
|
||||
} else {
|
||||
params["folder_id"] = folderFromFilter
|
||||
resolved, err := resolveFolderName(runtime, mailboxID, folderFromFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resolved != "" {
|
||||
params["folder_id"] = resolved
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -799,7 +801,13 @@ func buildListParams(runtime *common.RuntimeContext, mailboxID string, f triageF
|
||||
params["label_id"] = labelIDFromFilter
|
||||
}
|
||||
} else {
|
||||
params["label_id"] = labelIDFromFilter
|
||||
resolved, err := resolveLabelID(runtime, mailboxID, labelIDFromFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resolved != "" {
|
||||
params["label_id"] = resolved
|
||||
}
|
||||
}
|
||||
} else if labelFromFilter != "" {
|
||||
if dryRun {
|
||||
@@ -809,7 +817,13 @@ func buildListParams(runtime *common.RuntimeContext, mailboxID string, f triageF
|
||||
params["label_id"] = labelFromFilter
|
||||
}
|
||||
} else {
|
||||
params["label_id"] = labelFromFilter
|
||||
resolved, err := resolveLabelName(runtime, mailboxID, labelFromFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resolved != "" {
|
||||
params["label_id"] = resolved
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/auth"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
@@ -975,11 +974,7 @@ func TestBuildListParamsDryRunOnlyUnread(t *testing.T) {
|
||||
func TestBuildListParamsDryRunFolderAlias(t *testing.T) {
|
||||
rt := runtimeForMailTriageTest(t, nil)
|
||||
f := triageFilter{Folder: "sent"}
|
||||
resolved, err := resolveListFilter(rt, "me", f, true)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveListFilter: %v", err)
|
||||
}
|
||||
got, err := buildListParams(rt, "me", resolved, 20, "", true)
|
||||
got, err := buildListParams(rt, "me", f, 20, "", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -988,30 +983,10 @@ func TestBuildListParamsDryRunFolderAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildListParamsDryRunCustomFolderPreservesInput(t *testing.T) {
|
||||
rt := runtimeForMailTriageTest(t, nil)
|
||||
f := triageFilter{Folder: "team-folder"}
|
||||
resolved, err := resolveListFilter(rt, "me", f, true)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveListFilter: %v", err)
|
||||
}
|
||||
got, err := buildListParams(rt, "me", resolved, 20, "", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got["folder_id"] != "team-folder" {
|
||||
t.Fatalf("expected dry-run folder_id=team-folder, got %v", got["folder_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildListParamsDryRunLabelAlias(t *testing.T) {
|
||||
rt := runtimeForMailTriageTest(t, nil)
|
||||
f := triageFilter{Label: "flagged"}
|
||||
resolved, err := resolveListFilter(rt, "me", f, true)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveListFilter: %v", err)
|
||||
}
|
||||
got, err := buildListParams(rt, "me", resolved, 10, "", true)
|
||||
got, err := buildListParams(rt, "me", f, 10, "", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1020,25 +995,6 @@ func TestBuildListParamsDryRunLabelAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildListParamsDryRunCustomLabelPreservesInput(t *testing.T) {
|
||||
rt := runtimeForMailTriageTest(t, nil)
|
||||
f := triageFilter{Label: "custom-label"}
|
||||
resolved, err := resolveListFilter(rt, "me", f, true)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveListFilter: %v", err)
|
||||
}
|
||||
got, err := buildListParams(rt, "me", resolved, 10, "", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := got["folder_id"]; ok {
|
||||
t.Fatalf("folder_id should not be set when label is specified, got %v", got["folder_id"])
|
||||
}
|
||||
if got["label_id"] != "custom-label" {
|
||||
t.Fatalf("expected dry-run label_id=custom-label, got %v", got["label_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildSearchParams additional coverage ---
|
||||
|
||||
func TestBuildSearchParamsAllFilterFields(t *testing.T) {
|
||||
@@ -1522,16 +1478,14 @@ func boolPtr(v bool) *bool { return &v }
|
||||
|
||||
// --- mailbox_id preservation tests ---
|
||||
|
||||
// TestMailTriageStructuredOutputPreservesMailboxID verifies mailbox and notice metadata.
|
||||
func TestMailTriageStructuredOutputPreservesMailboxID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mailbox string
|
||||
format string
|
||||
args []string
|
||||
register func(*httpmock.Registry, string)
|
||||
wantCount int
|
||||
wantNotice string
|
||||
name string
|
||||
mailbox string
|
||||
format string
|
||||
args []string
|
||||
register func(*httpmock.Registry, string)
|
||||
wantCount int
|
||||
}{
|
||||
{
|
||||
name: "list json default mailbox",
|
||||
@@ -1568,10 +1522,9 @@ func TestMailTriageStructuredOutputPreservesMailboxID(t *testing.T) {
|
||||
register: func(reg *httpmock.Registry, mailbox string) {
|
||||
registerMailTriageSearchStub(reg, mailbox, []interface{}{
|
||||
mailTriageSearchItem("search_pub_001", "Shared search"),
|
||||
}, false, "", "The query is too long and has been truncated to the first 50 characters for search.")
|
||||
}, false, "")
|
||||
},
|
||||
wantCount: 1,
|
||||
wantNotice: "The query is too long and has been truncated to the first 50 characters for search.",
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "empty list json keeps top-level mailbox",
|
||||
@@ -1606,9 +1559,6 @@ func TestMailTriageStructuredOutputPreservesMailboxID(t *testing.T) {
|
||||
if data["mailbox_id"] != tt.mailbox {
|
||||
t.Fatalf("top-level mailbox_id mismatch: got %v, want %q", data["mailbox_id"], tt.mailbox)
|
||||
}
|
||||
if tt.wantNotice != "" && data["notice"] != tt.wantNotice {
|
||||
t.Fatalf("notice mismatch: got %v, want %q", data["notice"], tt.wantNotice)
|
||||
}
|
||||
messages := mailTriageMessagesFromOutput(t, data)
|
||||
if len(messages) != tt.wantCount {
|
||||
t.Fatalf("message count mismatch: got %d, want %d", len(messages), tt.wantCount)
|
||||
@@ -1622,7 +1572,6 @@ func TestMailTriageStructuredOutputPreservesMailboxID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMailTriageMissingMessageMetadataStillGetsMailboxID verifies fallback rows keep mailbox IDs.
|
||||
func TestMailTriageMissingMessageMetadataStillGetsMailboxID(t *testing.T) {
|
||||
f, stdout, _, reg := mailShortcutTestFactory(t)
|
||||
defer reg.Verify(t)
|
||||
@@ -1655,7 +1604,6 @@ func TestMailTriageMissingMessageMetadataStillGetsMailboxID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMailTriageTableOutputPreservesMailboxContext verifies public mailbox table hints.
|
||||
func TestMailTriageTableOutputPreservesMailboxContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1706,33 +1654,6 @@ func TestMailTriageTableOutputPreservesMailboxContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMailTriageDefaultTableOutputPrintsSearchNoticeToStderr verifies stderr notices.
|
||||
func TestMailTriageDefaultTableOutputPrintsSearchNoticeToStderr(t *testing.T) {
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
|
||||
f, stdout, stderr, reg := mailShortcutTestFactory(t)
|
||||
defer reg.Verify(t)
|
||||
|
||||
registerMailTriageSearchStub(reg, "me", []interface{}{
|
||||
mailTriageSearchItem("msg_search_notice", "Search notice result"),
|
||||
}, false, "", notice)
|
||||
|
||||
if err := runMountedMailShortcut(t, MailTriage, []string{
|
||||
"+triage",
|
||||
"--query", strings.Repeat("q", 81),
|
||||
}, f, stdout); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if out := stdout.String(); !strings.Contains(out, "msg_search_notice") {
|
||||
t.Fatalf("stdout should contain table row, got:\n%s", out)
|
||||
}
|
||||
if errOut := stderr.String(); !strings.Contains(errOut, "notice: "+notice) {
|
||||
t.Fatalf("stderr should contain search notice, got:\n%s", errOut)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeMailTriageJSONOutput decodes structured triage output for assertions.
|
||||
func decodeMailTriageJSONOutput(t *testing.T, stdout interface{ Bytes() []byte }) map[string]interface{} {
|
||||
t.Helper()
|
||||
var data map[string]interface{}
|
||||
@@ -1742,7 +1663,6 @@ func decodeMailTriageJSONOutput(t *testing.T, stdout interface{ Bytes() []byte }
|
||||
return data
|
||||
}
|
||||
|
||||
// mailTriageMessagesFromOutput extracts triage messages as object maps.
|
||||
func mailTriageMessagesFromOutput(t *testing.T, data map[string]interface{}) []map[string]interface{} {
|
||||
t.Helper()
|
||||
rawMessages, ok := data["messages"].([]interface{})
|
||||
@@ -1795,8 +1715,7 @@ func registerMailTriageBatchStub(reg *httpmock.Registry, mailbox string, message
|
||||
})
|
||||
}
|
||||
|
||||
// registerMailTriageSearchStub registers a mailbox search response for triage tests.
|
||||
func registerMailTriageSearchStub(reg *httpmock.Registry, mailbox string, items []interface{}, hasMore bool, pageToken string, notices ...string) {
|
||||
func registerMailTriageSearchStub(reg *httpmock.Registry, mailbox string, items []interface{}, hasMore bool, pageToken string) {
|
||||
data := map[string]interface{}{
|
||||
"items": items,
|
||||
"has_more": hasMore,
|
||||
@@ -1804,9 +1723,6 @@ func registerMailTriageSearchStub(reg *httpmock.Registry, mailbox string, items
|
||||
if pageToken != "" {
|
||||
data["page_token"] = pageToken
|
||||
}
|
||||
if len(notices) > 0 && notices[0] != "" {
|
||||
data["notice"] = notices[0]
|
||||
}
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: mailboxPath(mailbox, "search"),
|
||||
@@ -1835,137 +1751,3 @@ func mailTriageSearchItem(messageID, subject string) map[string]interface{} {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// registerMailTriageFoldersListStub registers a NON-reusable stub for the
|
||||
// mailbox folders list API. Because it is non-reusable, any second hit returns
|
||||
// "httpmock: no stub for GET .../folders" — which is exactly the assertion we
|
||||
// use to prove resolveListFilter runs once and buildListParams does NOT
|
||||
// re-resolve. folderID/folderName is the single custom folder the API reports.
|
||||
func registerMailTriageFoldersListStub(reg *httpmock.Registry, mailbox, folderID, folderName string) {
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: mailboxPath(mailbox, "folders"),
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"data": map[string]interface{}{
|
||||
"items": []interface{}{
|
||||
map[string]interface{}{
|
||||
"id": folderID,
|
||||
"name": folderName,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// registerMailTriageListPageStub registers one page of the messages list API,
|
||||
// disambiguated from sibling pages by a URL substring unique to that page
|
||||
// (e.g. "page_size=5" for page 1 vs "page_size=2" for page 2). The substring
|
||||
// must NOT depend on query-param ordering: map iteration makes param order
|
||||
// nondeterministic, so prefer a value-only token like "page_size=N" (the N
|
||||
// differs per page because pageSize = maxCount - fetched_so_far). Non-reusable
|
||||
// so reg.Verify catches under- or over-consumption.
|
||||
func registerMailTriageListPageStub(reg *httpmock.Registry, urlSubstring string, items []string, hasMore bool, pageToken string) {
|
||||
data := map[string]interface{}{
|
||||
"items": items,
|
||||
"has_more": hasMore,
|
||||
}
|
||||
if pageToken != "" {
|
||||
data["page_token"] = pageToken
|
||||
}
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: urlSubstring,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"data": data,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// TestMailTriageCustomFolderResolvesOnceAcrossListPages is the regression test
|
||||
// for the bug where buildListParams re-called resolveFolderID on every list
|
||||
// page, turning "resolve once" into "1 + page_count" folder-list API calls and
|
||||
// easily tripping rate limits.
|
||||
//
|
||||
// Setup: a custom folder filter that forces resolveListFilter to hit the
|
||||
// folders list API once (to map folder name "team-folder" to folder_id), then two
|
||||
// messages-list pages. The folders list stub is non-reusable, so if
|
||||
// buildListParams re-resolves, the second hit fails with "no stub". The
|
||||
// messages-list stubs are page-specific (disambiguated by page_size in the
|
||||
// URL), so both pages are served and Verify asserts each fired exactly once.
|
||||
func TestMailTriageCustomFolderResolvesOnceAcrossListPages(t *testing.T) {
|
||||
f, stdout, _, reg := mailShortcutTestFactory(t)
|
||||
defer reg.Verify(t)
|
||||
|
||||
// listMailboxFolders (called once by resolveListFilter) gates on the
|
||||
// mail:user_mailbox.folder:read scope, which the default test token does
|
||||
// not carry. Re-store the token with that scope appended so the folders
|
||||
// API call is actually exercised (and thus the non-reusable folders stub
|
||||
// is the load-bearing "exactly once" assertion).
|
||||
const folderScope = "mail:user_mailbox.folder:read"
|
||||
cfg := mailTestConfig()
|
||||
if stored := auth.GetStoredToken(cfg.AppID, cfg.UserOpenId); stored != nil {
|
||||
if !strings.Contains(stored.Scope, folderScope) {
|
||||
stored.Scope = stored.Scope + " " + folderScope
|
||||
if err := auth.SetStoredToken(stored); err != nil {
|
||||
t.Fatalf("re-store token with folder scope: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
mailbox = "me"
|
||||
folderName = "team-folder"
|
||||
folderID = "fld_custom_team"
|
||||
page2Token = "tok_page2"
|
||||
)
|
||||
// --max 5 with listPageMax=20 → pageSize = 5-0 = 5 on page 1, then 5-3 = 2
|
||||
// on page 2. The page_size query value disambiguates the two list stubs.
|
||||
page1IDs := []string{"msg_a", "msg_b", "msg_c"}
|
||||
page2IDs := []string{"msg_d", "msg_e"}
|
||||
|
||||
// Folders list: registered exactly once, non-reusable. Any second folder
|
||||
// lookup (the bug) fails the test with "no stub for GET .../folders".
|
||||
registerMailTriageFoldersListStub(reg, mailbox, folderID, folderName)
|
||||
// Messages list, page 1: 3 ids, has_more, hands off a page-2 token. The
|
||||
// page_size value (5 = maxCount - 0) is unique to page 1; page 2 uses 2.
|
||||
registerMailTriageListPageStub(reg, "page_size=5", page1IDs, true, page2Token)
|
||||
// Messages list, page 2: 2 ids, terminal.
|
||||
registerMailTriageListPageStub(reg, "page_size=2", page2IDs, false, "")
|
||||
// Batch metadata fetch for all 5 ids.
|
||||
registerMailTriageBatchStub(reg, mailbox, []map[string]interface{}{
|
||||
mailTriageBatchMessage("msg_a", "Subject A"),
|
||||
mailTriageBatchMessage("msg_b", "Subject B"),
|
||||
mailTriageBatchMessage("msg_c", "Subject C"),
|
||||
mailTriageBatchMessage("msg_d", "Subject D"),
|
||||
mailTriageBatchMessage("msg_e", "Subject E"),
|
||||
})
|
||||
|
||||
args := []string{
|
||||
"+triage",
|
||||
"--as", "user",
|
||||
"--mailbox", mailbox,
|
||||
"--filter", `{"folder":"` + folderName + `"}`,
|
||||
"--max", "5",
|
||||
"--format", "json",
|
||||
}
|
||||
if err := runMountedMailShortcut(t, MailTriage, args, f, stdout); err != nil {
|
||||
t.Fatalf("unexpected error running +triage (likely a second folders API call — the bug): %v", err)
|
||||
}
|
||||
|
||||
data := decodeMailTriageJSONOutput(t, stdout)
|
||||
messages := mailTriageMessagesFromOutput(t, data)
|
||||
if len(messages) != 5 {
|
||||
t.Fatalf("expected 5 messages across 2 pages, got %d (stdout=%s)", len(messages), stdout.String())
|
||||
}
|
||||
if got := data["has_more"]; got != false {
|
||||
t.Fatalf("expected has_more=false after exhausting pages, got %v", got)
|
||||
}
|
||||
// All registered stubs (1 folders + 2 list pages + 1 batch_get) are
|
||||
// non-reusable; reg.Verify (deferred above) asserts each was matched
|
||||
// exactly once. Combined with the non-reusable folders stub, this is the
|
||||
// proof that the folders list API was called exactly once across both
|
||||
// pages — the core invariant the fix restores.
|
||||
}
|
||||
|
||||
@@ -308,9 +308,6 @@ var MinutesSearch = common.Shortcut{
|
||||
"has_more": data["has_more"],
|
||||
"page_token": data["page_token"],
|
||||
}
|
||||
if notice, _ := data["notice"].(string); notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
|
||||
runtime.OutFormat(outData, &output.Meta{Count: len(rows)}, func(w io.Writer) {
|
||||
if len(rows) == 0 {
|
||||
|
||||
@@ -609,8 +609,6 @@ func TestMinutesSearchExecuteShowsPaginationHintForTableFormat(t *testing.T) {
|
||||
func TestMinutesSearchExecuteJSONCountUsesRenderedRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
@@ -619,7 +617,6 @@ func TestMinutesSearchExecuteJSONCountUsesRenderedRows(t *testing.T) {
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"notice": notice,
|
||||
"items": []interface{}{
|
||||
nil,
|
||||
map[string]interface{}{
|
||||
@@ -644,9 +641,6 @@ func TestMinutesSearchExecuteJSONCountUsesRenderedRows(t *testing.T) {
|
||||
reg.Verify(t)
|
||||
|
||||
var envelope struct {
|
||||
Data struct {
|
||||
Notice string `json:"notice"`
|
||||
} `json:"data"`
|
||||
Meta struct {
|
||||
Count int `json:"count"`
|
||||
} `json:"meta"`
|
||||
@@ -657,9 +651,6 @@ func TestMinutesSearchExecuteJSONCountUsesRenderedRows(t *testing.T) {
|
||||
if envelope.Meta.Count != 1 {
|
||||
t.Fatalf("meta.count = %d, want 1", envelope.Meta.Count)
|
||||
}
|
||||
if envelope.Data.Notice != notice {
|
||||
t.Fatalf("data.notice = %q, want %q", envelope.Data.Notice, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMinuteSearchFieldExtractors verifies field extractors read populated metadata correctly.
|
||||
|
||||
@@ -242,6 +242,7 @@ func Shortcuts() []common.Shortcut {
|
||||
GetMyTasks,
|
||||
GetRelatedTasks,
|
||||
SearchTask,
|
||||
SubscribeTaskEvent,
|
||||
UploadAttachmentTask,
|
||||
CreateTasklist,
|
||||
SearchTasklist,
|
||||
|
||||
@@ -73,16 +73,12 @@ var SearchTask = common.Shortcut{
|
||||
var rawItems []interface{}
|
||||
var lastPageToken string
|
||||
var lastHasMore bool
|
||||
var notice string
|
||||
currentBody := body
|
||||
for page := 0; page < pageLimit; page++ {
|
||||
data, err := callTaskAPITyped(runtime, http.MethodPost, "/open-apis/task/v2/tasks/search", nil, currentBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if notice == "" {
|
||||
notice, _ = data["notice"].(string)
|
||||
}
|
||||
items, _ := data["items"].([]interface{})
|
||||
rawItems = append(rawItems, items...)
|
||||
lastHasMore, _ = data["has_more"].(bool)
|
||||
@@ -119,9 +115,6 @@ var SearchTask = common.Shortcut{
|
||||
"page_token": lastPageToken,
|
||||
"has_more": lastHasMore,
|
||||
}
|
||||
if notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
runtime.OutFormat(outData, &output.Meta{Count: len(enriched)}, func(w io.Writer) {
|
||||
if len(enriched) == 0 {
|
||||
fmt.Fprintln(w, "No tasks found.")
|
||||
|
||||
@@ -153,7 +153,6 @@ func TestSearchTask_DryRun(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchTask_Execute verifies task search output, enrichment, and notices.
|
||||
func TestSearchTask_Execute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -172,7 +171,6 @@ func TestSearchTask_Execute(t *testing.T) {
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"notice": "The query is too long and has been truncated to the first 50 characters for search.",
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
"items": []interface{}{
|
||||
@@ -193,7 +191,7 @@ func TestSearchTask_Execute(t *testing.T) {
|
||||
},
|
||||
})
|
||||
},
|
||||
wantParts: []string{`"guid": "task-123"`, `"summary": "Search Result"`, `"notice": "The query is too long and has been truncated to the first 50 characters for search."`},
|
||||
wantParts: []string{`"guid": "task-123"`, `"summary": "Search Result"`},
|
||||
},
|
||||
{
|
||||
name: "fallback to app link",
|
||||
|
||||
40
shortcuts/task/task_subscribe_event.go
Normal file
40
shortcuts/task/task_subscribe_event.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
var SubscribeTaskEvent = common.Shortcut{
|
||||
Service: "task",
|
||||
Command: "+subscribe-event",
|
||||
Description: "subscribe to task events",
|
||||
Risk: "write",
|
||||
Scopes: []string{"task:task:read"},
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
HasFormat: true,
|
||||
DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
return common.NewDryRunAPI().
|
||||
POST("/open-apis/task/v2/task_v2/task_subscription").
|
||||
Params(map[string]interface{}{"user_id_type": "open_id"})
|
||||
},
|
||||
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
params := map[string]interface{}{"user_id_type": "open_id"}
|
||||
if _, err := callTaskAPITyped(runtime, http.MethodPost, "/open-apis/task/v2/task_v2/task_subscription", params, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
outData := map[string]interface{}{"ok": true}
|
||||
runtime.OutFormat(outData, nil, func(w io.Writer) {
|
||||
fmt.Fprintln(w, "✅ Task event subscription created successfully!")
|
||||
})
|
||||
return nil
|
||||
},
|
||||
}
|
||||
163
shortcuts/task/task_subscribe_event_test.go
Normal file
163
shortcuts/task/task_subscribe_event_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package task
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
func TestSubscribeTaskEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
args []string
|
||||
register func(*httpmock.Registry)
|
||||
wantErr bool
|
||||
wantParts []string
|
||||
}{
|
||||
{
|
||||
name: "execute json (user identity)",
|
||||
mode: "execute",
|
||||
args: []string{"+subscribe-event", "--as", "user", "--format", "json"},
|
||||
register: func(reg *httpmock.Registry) {
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/task/v2/task_v2/task_subscription",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
},
|
||||
wantParts: []string{`"ok": true`},
|
||||
},
|
||||
{
|
||||
name: "execute json (bot identity)",
|
||||
mode: "execute",
|
||||
args: []string{"+subscribe-event", "--as", "bot", "--format", "json"},
|
||||
register: func(reg *httpmock.Registry) {
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/task/v2/task_v2/task_subscription",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
},
|
||||
wantParts: []string{`"ok": true`},
|
||||
},
|
||||
{
|
||||
name: "execute api error",
|
||||
mode: "execute",
|
||||
args: []string{"+subscribe-event", "--as", "bot", "--format", "json"},
|
||||
register: func(reg *httpmock.Registry) {
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/task/v2/task_v2/task_subscription",
|
||||
Body: map[string]interface{}{
|
||||
"code": 401,
|
||||
"msg": "Unauthorized",
|
||||
"error": map[string]interface{}{
|
||||
"log_id": "test-log-id",
|
||||
},
|
||||
},
|
||||
})
|
||||
},
|
||||
wantErr: true,
|
||||
wantParts: []string{"Unauthorized"},
|
||||
},
|
||||
{
|
||||
name: "dry run",
|
||||
mode: "dryrun",
|
||||
wantParts: []string{"POST /open-apis/task/v2/task_v2/task_subscription"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
switch tt.mode {
|
||||
case "execute":
|
||||
f, stdout, _, reg := taskShortcutTestFactory(t)
|
||||
warmTenantToken(t, f, reg)
|
||||
if tt.register != nil {
|
||||
tt.register(reg)
|
||||
}
|
||||
|
||||
err := runMountedTaskShortcut(t, SubscribeTaskEvent, tt.args, f, stdout)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
out := err.Error()
|
||||
for _, want := range tt.wantParts {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Fatalf("error missing %q: %s", want, out)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("runMountedTaskShortcut() error = %v", err)
|
||||
}
|
||||
|
||||
out := stdout.String()
|
||||
outNorm := strings.ReplaceAll(out, `":"`, `": "`)
|
||||
for _, want := range tt.wantParts {
|
||||
if !strings.Contains(out, want) && !strings.Contains(outNorm, want) {
|
||||
t.Fatalf("output missing %q: %s", want, out)
|
||||
}
|
||||
}
|
||||
case "dryrun":
|
||||
runtime := common.TestNewRuntimeContextWithIdentity(&cobra.Command{Use: "test"}, taskTestConfig(t), "user")
|
||||
out := SubscribeTaskEvent.DryRun(nil, runtime).Format()
|
||||
for _, want := range tt.wantParts {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Fatalf("dry run output missing %q: %s", want, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscribeTaskEvent_MalformedResponse covers the parse-response arm: a 200
|
||||
// with an unparseable body surfaces a typed internal invalid_response error
|
||||
// (exit 5).
|
||||
func TestSubscribeTaskEvent_MalformedResponse(t *testing.T) {
|
||||
f, stdout, _, reg := taskShortcutTestFactory(t)
|
||||
warmTenantToken(t, f, reg)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/task/v2/task_v2/task_subscription",
|
||||
Status: 200,
|
||||
RawBody: []byte("{not-json"),
|
||||
})
|
||||
|
||||
args := []string{"+subscribe-event", "--as", "bot", "--format", "json"}
|
||||
err := runMountedTaskShortcut(t, SubscribeTaskEvent, args, f, stdout)
|
||||
|
||||
var ie *errs.InternalError
|
||||
if !errors.As(err, &ie) {
|
||||
t.Fatalf("err = %T, want *errs.InternalError; err = %v", err, err)
|
||||
}
|
||||
if ie.Subtype != errs.SubtypeInvalidResponse {
|
||||
t.Errorf("subtype = %q, want %q", ie.Subtype, errs.SubtypeInvalidResponse)
|
||||
}
|
||||
if got := output.ExitCodeOf(err); got != output.ExitInternal {
|
||||
t.Errorf("exit code = %d, want %d", got, output.ExitInternal)
|
||||
}
|
||||
}
|
||||
@@ -70,16 +70,12 @@ var SearchTasklist = common.Shortcut{
|
||||
var rawItems []interface{}
|
||||
var lastPageToken string
|
||||
var lastHasMore bool
|
||||
var notice string
|
||||
currentBody := body
|
||||
for page := 0; page < pageLimit; page++ {
|
||||
data, err := callTaskAPITyped(runtime, http.MethodPost, "/open-apis/task/v2/tasklists/search", nil, currentBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if notice == "" {
|
||||
notice, _ = data["notice"].(string)
|
||||
}
|
||||
items, _ := data["items"].([]interface{})
|
||||
rawItems = append(rawItems, items...)
|
||||
lastHasMore, _ = data["has_more"].(bool)
|
||||
@@ -122,9 +118,6 @@ var SearchTasklist = common.Shortcut{
|
||||
"page_token": lastPageToken,
|
||||
"has_more": lastHasMore,
|
||||
}
|
||||
if notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
runtime.OutFormat(outData, &output.Meta{Count: len(tasklists)}, func(w io.Writer) {
|
||||
if len(tasklists) == 0 {
|
||||
fmt.Fprintln(w, "No tasklists found.")
|
||||
|
||||
@@ -126,7 +126,6 @@ func TestSearchTasklist_DryRun(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchTasklist_Execute verifies tasklist search output, enrichment, and notices.
|
||||
func TestSearchTasklist_Execute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -145,7 +144,6 @@ func TestSearchTasklist_Execute(t *testing.T) {
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"notice": "The query is too long and has been truncated to the first 50 characters for search.",
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
"items": []interface{}{map[string]interface{}{"id": "tl-123"}},
|
||||
@@ -164,7 +162,7 @@ func TestSearchTasklist_Execute(t *testing.T) {
|
||||
},
|
||||
})
|
||||
},
|
||||
wantParts: []string{`"guid": "tl-123"`, `"name": "Q2 Plan"`, `"notice": "The query is too long and has been truncated to the first 50 characters for search."`},
|
||||
wantParts: []string{`"guid": "tl-123"`, `"name": "Q2 Plan"`},
|
||||
},
|
||||
{
|
||||
name: "fallback on detail error",
|
||||
|
||||
@@ -236,9 +236,6 @@ var VCSearch = common.Shortcut{
|
||||
"has_more": data["has_more"],
|
||||
"page_token": data["page_token"],
|
||||
}
|
||||
if notice, _ := data["notice"].(string); notice != "" {
|
||||
outData["notice"] = notice
|
||||
}
|
||||
hasMore, _ := data["has_more"].(bool)
|
||||
runtime.OutFormat(outData, &output.Meta{Count: len(items)}, func(w io.Writer) {
|
||||
if len(items) == 0 {
|
||||
|
||||
@@ -5,7 +5,6 @@ package vc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
@@ -255,7 +253,6 @@ func TestSearch_Validation_InvalidPageSize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearch_DryRun verifies meeting search dry-run includes the API path.
|
||||
func TestSearch_DryRun(t *testing.T) {
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, defaultConfig())
|
||||
err := mountAndRun(t, VCSearch, []string{"+search", "--query", "test", "--dry-run", "--as", "user"}, f, stdout)
|
||||
@@ -267,43 +264,6 @@ func TestSearch_DryRun(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearch_ExecutePassesThroughNotice verifies meeting search notice output.
|
||||
func TestSearch_ExecutePassesThroughNotice(t *testing.T) {
|
||||
const notice = "The query is too long and has been truncated to the first 50 characters for search."
|
||||
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/vc/v1/meetings/search",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"notice": notice,
|
||||
"items": []interface{}{},
|
||||
"total": 0,
|
||||
"has_more": false,
|
||||
"page_token": "",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err := mountAndRun(t, VCSearch, []string{"+search", "--query", "incident", "--format", "json", "--as", "user"}, f, stdout); err != nil {
|
||||
t.Fatalf("VCSearch.Execute() error = %v", err)
|
||||
}
|
||||
reg.Verify(t)
|
||||
|
||||
var env map[string]interface{}
|
||||
if err := json.Unmarshal(stdout.Bytes(), &env); err != nil {
|
||||
t.Fatalf("json.Unmarshal(stdout) error = %v\nstdout=%s", err, stdout.String())
|
||||
}
|
||||
data, _ := env["data"].(map[string]interface{})
|
||||
if got, _ := data["notice"].(string); got != notice {
|
||||
t.Fatalf("data.notice = %q, want %q; data=%#v", got, notice, data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearch_InvalidTimeRange verifies invalid meeting search time input fails.
|
||||
func TestSearch_InvalidTimeRange(t *testing.T) {
|
||||
f, _, _, _ := cmdutil.TestFactory(t, defaultConfig())
|
||||
err := mountAndRun(t, VCSearch, []string{"+search", "--start", "bad-time", "--as", "user"}, f, nil)
|
||||
|
||||
@@ -5,7 +5,6 @@ package whiteboard
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -22,54 +21,40 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// WhiteboardQueryAsImage exports a whiteboard preview image.
|
||||
WhiteboardQueryAsImage = "image"
|
||||
// WhiteboardQueryAsSvg exports a whiteboard as SVG.
|
||||
WhiteboardQueryAsSvg = "svg"
|
||||
// WhiteboardQueryAsCode exports Mermaid or PlantUML source extracted from the whiteboard.
|
||||
WhiteboardQueryAsCode = "code"
|
||||
// WhiteboardQueryAsRaw exports the raw whiteboard node payload.
|
||||
WhiteboardQueryAsRaw = "raw"
|
||||
WhiteboardQueryAsCode = "code"
|
||||
WhiteboardQueryAsRaw = "raw"
|
||||
)
|
||||
|
||||
// SyntaxType identifies the diagram syntax extracted from whiteboard code blocks.
|
||||
type SyntaxType int
|
||||
|
||||
const (
|
||||
// SyntaxTypePlantUML marks PlantUML code blocks.
|
||||
SyntaxTypePlantUML SyntaxType = 1
|
||||
// SyntaxTypeMermaid marks Mermaid code blocks.
|
||||
SyntaxTypeMermaid SyntaxType = 2
|
||||
SyntaxTypeMermaid SyntaxType = 2
|
||||
)
|
||||
|
||||
// SyntaxTypeNameMap maps whiteboard syntax types to their CLI output names.
|
||||
var SyntaxTypeNameMap = map[SyntaxType]string{
|
||||
SyntaxTypePlantUML: "plantuml",
|
||||
SyntaxTypeMermaid: "mermaid",
|
||||
}
|
||||
|
||||
// SyntaxTypeExtensionMap maps whiteboard syntax types to their default file extensions.
|
||||
var SyntaxTypeExtensionMap = map[SyntaxType]string{
|
||||
SyntaxTypePlantUML: ".puml",
|
||||
SyntaxTypeMermaid: ".mmd",
|
||||
}
|
||||
|
||||
// String returns the CLI-facing name for the syntax type.
|
||||
func (s SyntaxType) String() string {
|
||||
return SyntaxTypeNameMap[s]
|
||||
}
|
||||
|
||||
// ExtensionName returns the default file extension for the syntax type.
|
||||
func (s SyntaxType) ExtensionName() string {
|
||||
return SyntaxTypeExtensionMap[s]
|
||||
}
|
||||
|
||||
// IsValid reports whether the syntax type is one of the supported whiteboard code syntaxes.
|
||||
func (s SyntaxType) IsValid() bool {
|
||||
return s == SyntaxTypePlantUML || s == SyntaxTypeMermaid
|
||||
}
|
||||
|
||||
// WhiteboardQuery registers the `whiteboard +query` shortcut.
|
||||
var WhiteboardQuery = common.Shortcut{
|
||||
Service: "whiteboard",
|
||||
Command: "+query",
|
||||
@@ -79,8 +64,8 @@ var WhiteboardQuery = common.Shortcut{
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
Flags: []common.Flag{
|
||||
{Name: "whiteboard-token", Desc: "whiteboard token of the whiteboard. You will need read permission to download preview image.", Required: true},
|
||||
{Name: "output_as", Desc: "output whiteboard as: image | svg | code | raw.", Required: true},
|
||||
{Name: "output", Desc: "output directory. It is required when output as image. If not specified when --output_as svg/code/raw, it will output directly.", Required: false},
|
||||
{Name: "output_as", Desc: "output whiteboard as: image | code | raw.", Required: true},
|
||||
{Name: "output", Desc: "output directory. It is required when output as image. If not specified when --output_as code/raw, it will output directly.", Required: false},
|
||||
{Name: "overwrite", Desc: "overwrite existing file if it exists", Required: false, Type: "bool"},
|
||||
},
|
||||
HasFormat: true,
|
||||
@@ -101,8 +86,8 @@ var WhiteboardQuery = common.Shortcut{
|
||||
}
|
||||
|
||||
as := runtime.Str("output_as")
|
||||
if as != WhiteboardQueryAsImage && as != WhiteboardQueryAsSvg && as != WhiteboardQueryAsCode && as != WhiteboardQueryAsRaw {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--output_as flag must be one of: image | svg | code | raw").WithParam("--output_as")
|
||||
if as != WhiteboardQueryAsImage && as != WhiteboardQueryAsCode && as != WhiteboardQueryAsRaw {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--output_as flag must be one of: image | code | raw").WithParam("--output_as")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -122,13 +107,8 @@ var WhiteboardQuery = common.Shortcut{
|
||||
return common.NewDryRunAPI().
|
||||
GET(fmt.Sprintf("/open-apis/board/v1/whiteboards/%s/nodes", common.MaskToken(url.PathEscape(token)))).
|
||||
Desc("Extract raw nodes structure from given whiteboard")
|
||||
case WhiteboardQueryAsSvg:
|
||||
return common.NewDryRunAPI().
|
||||
POST(fmt.Sprintf("/open-apis/board/v1/whiteboards/%s/export", common.MaskToken(url.PathEscape(token)))).
|
||||
Body(map[string]string{"export_type": "svg"}).
|
||||
Desc("Export SVG of given whiteboard")
|
||||
default:
|
||||
return common.NewDryRunAPI().Desc("invalid --output_as flag, must be one of: image | svg | code | raw")
|
||||
return common.NewDryRunAPI().Desc("invalid --output_as flag, must be one of: image | code | raw")
|
||||
}
|
||||
},
|
||||
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
@@ -139,110 +119,17 @@ var WhiteboardQuery = common.Shortcut{
|
||||
switch as {
|
||||
case WhiteboardQueryAsImage:
|
||||
return exportWhiteboardPreview(ctx, runtime, token, outDir)
|
||||
case WhiteboardQueryAsSvg:
|
||||
return exportWhiteboardSvg(runtime, token, outDir)
|
||||
case WhiteboardQueryAsCode:
|
||||
return exportWhiteboardCode(runtime, token, outDir)
|
||||
case WhiteboardQueryAsRaw:
|
||||
return exportWhiteboardRaw(runtime, token, outDir)
|
||||
default:
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--output_as flag must be one of: image | svg | code | raw").WithParam("--output_as")
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--output_as flag must be one of: image | code | raw").WithParam("--output_as")
|
||||
}
|
||||
|
||||
},
|
||||
}
|
||||
|
||||
// exportReq defines the request body for whiteboard export APIs.
|
||||
type exportReq struct {
|
||||
ExportType string `json:"export_type"`
|
||||
}
|
||||
|
||||
// exportResp models the whiteboard export response envelope.
|
||||
type exportResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data struct {
|
||||
Content string `json:"content"`
|
||||
MimeType string `json:"mime_type"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// exportWhiteboardSvg exports a whiteboard as SVG and writes the result to stdout or a file.
|
||||
// It requests the SVG export for the given whiteboard token and saves the decoded content when an output path is provided.
|
||||
func exportWhiteboardSvg(runtime *common.RuntimeContext, wbToken, outDir string) error {
|
||||
reqBody := exportReq{ExportType: "svg"}
|
||||
req := &larkcore.ApiReq{
|
||||
HttpMethod: http.MethodPost,
|
||||
ApiPath: fmt.Sprintf("/open-apis/board/v1/whiteboards/%s/export", url.PathEscape(wbToken)),
|
||||
Body: reqBody,
|
||||
}
|
||||
|
||||
resp, err := runtime.DoAPI(req)
|
||||
if err != nil {
|
||||
return wrapWbNetworkErr(err, "export whiteboard svg failed: %v", err)
|
||||
}
|
||||
|
||||
var exportData exportResp
|
||||
if err := json.Unmarshal(resp.RawBody, &exportData); err == nil {
|
||||
if exportData.Code != 0 {
|
||||
subtype := errs.SubtypeUnknown
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
subtype = errs.SubtypeNotFound
|
||||
}
|
||||
return errs.NewAPIError(subtype, "export whiteboard svg failed: %s", exportData.Msg).WithCode(exportData.Code)
|
||||
}
|
||||
} else if resp.StatusCode == http.StatusOK {
|
||||
return errs.NewInternalError(errs.SubtypeInvalidResponse, "parse export response failed: %v", err).WithCause(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body := common.TruncateStr(strings.TrimSpace(string(resp.RawBody)), 500)
|
||||
if resp.StatusCode >= 500 {
|
||||
return errs.NewNetworkError(errs.SubtypeNetworkServer, "export whiteboard svg failed: HTTP %d: %s", resp.StatusCode, body).
|
||||
WithCode(resp.StatusCode).
|
||||
WithRetryable()
|
||||
}
|
||||
subtype := errs.SubtypeUnknown
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
subtype = errs.SubtypeNotFound
|
||||
}
|
||||
return errs.NewAPIError(subtype, "export whiteboard svg failed: HTTP %d: %s", resp.StatusCode, body).
|
||||
WithCode(resp.StatusCode)
|
||||
}
|
||||
|
||||
svgBytes, err := base64.StdEncoding.DecodeString(exportData.Data.Content)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.SubtypeInvalidResponse, "decode svg base64 failed: %v", err).WithCause(err)
|
||||
}
|
||||
|
||||
if outDir == "" {
|
||||
runtime.OutFormat(map[string]interface{}{
|
||||
"svg_content": string(svgBytes),
|
||||
}, nil, func(w io.Writer) {
|
||||
fmt.Fprintf(w, "%s\n", string(svgBytes))
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
finalPath, size, err := saveOutputFile(outDir, ".svg", wbToken, runtime, bytes.NewReader(svgBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runtime.OutFormat(map[string]interface{}{
|
||||
"svg_path": finalPath,
|
||||
"size_bytes": size,
|
||||
}, nil, func(w io.Writer) {
|
||||
fmt.Fprintf(w, "SVG saved to %s\n", finalPath)
|
||||
fmt.Fprintf(w, "File size: %d bytes", size)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// exportWhiteboardPreview downloads a whiteboard preview image and saves it as a PNG file.
|
||||
//
|
||||
// It reports the saved file path and image size on success.
|
||||
// Returns an error if the API request fails, the response is rejected, or the file cannot be saved.
|
||||
func exportWhiteboardPreview(ctx context.Context, runtime *common.RuntimeContext, wbToken, outDir string) error {
|
||||
req := &larkcore.ApiReq{
|
||||
HttpMethod: http.MethodGet,
|
||||
@@ -444,9 +331,6 @@ func exportWhiteboardRaw(runtime *common.RuntimeContext, wbToken, outDir string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveOutputFile writes exported content to a file or directory and returns the final path and written size.
|
||||
// If outPath is a directory, it creates a file named whiteboard_<token><ext>. If outPath is a file path,
|
||||
// it adjusts the file extension to ext, validates the path, and respects the overwrite flag.
|
||||
func saveOutputFile(outPath, ext, token string, runtime *common.RuntimeContext, data io.Reader) (string, int64, error) {
|
||||
// Step 1: Get final output path
|
||||
info, err := runtime.FileIO().Stat(outPath)
|
||||
@@ -483,8 +367,6 @@ func saveOutputFile(outPath, ext, token string, runtime *common.RuntimeContext,
|
||||
switch ext {
|
||||
case ".png":
|
||||
contentType = "image/png"
|
||||
case ".svg":
|
||||
contentType = "image/svg+xml"
|
||||
case ".json":
|
||||
contentType = "application/json"
|
||||
case ".mmd", ".puml":
|
||||
|
||||
@@ -6,8 +6,6 @@ package whiteboard
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -15,7 +13,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/errs"
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
@@ -23,7 +20,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TestSyntaxType verifies syntax names, extensions, and validity checks.
|
||||
func TestSyntaxType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -79,7 +75,6 @@ func TestSyntaxType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_Validate verifies query flag validation for supported output modes.
|
||||
func TestWhiteboardQuery_Validate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
chdirTemp(t)
|
||||
@@ -204,9 +199,6 @@ func TestWhiteboardQuery_Validate_TypedErrors(t *testing.T) {
|
||||
if p.Category != errs.CategoryValidation {
|
||||
t.Errorf("Category = %q, want %q", p.Category, errs.CategoryValidation)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Errorf("Problem subtype = %q, want %q", p.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -240,7 +232,6 @@ func TestExportWhiteboardPreview_HTTPError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardPreview_HTTPNotFoundIsAPIError verifies 404 preview downloads surface as typed API errors.
|
||||
func TestExportWhiteboardPreview_HTTPNotFoundIsAPIError(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
chdirTemp(t)
|
||||
@@ -264,7 +255,6 @@ func TestExportWhiteboardPreview_HTTPNotFoundIsAPIError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_DryRun verifies dry-run output for the supported query modes.
|
||||
func TestWhiteboardQuery_DryRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -317,64 +307,6 @@ func TestWhiteboardQuery_DryRun(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_DryRun_InvalidOutputAs verifies dry-run guidance for unsupported output modes.
|
||||
func TestWhiteboardQuery_DryRun_InvalidOutputAs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
rt := newTestRuntime(map[string]string{
|
||||
"whiteboard-token": "test-token-123",
|
||||
"output_as": "invalid",
|
||||
}, nil)
|
||||
|
||||
dryRun := WhiteboardQuery.DryRun(ctx, rt)
|
||||
if dryRun == nil {
|
||||
t.Fatal("WhiteboardQuery.DryRun() returned nil")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(dryRun)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
if !strings.Contains(string(data), "image | svg | code | raw") {
|
||||
t.Fatalf("dry run desc = %s, want invalid output_as guidance", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_Execute_InvalidOutputAs_TypedError verifies invalid output modes return typed validation errors.
|
||||
func TestWhiteboardQuery_Execute_InvalidOutputAs_TypedError(t *testing.T) {
|
||||
rt := newTestRuntime(map[string]string{
|
||||
"whiteboard-token": "test-token-123",
|
||||
"output_as": "invalid",
|
||||
}, nil)
|
||||
|
||||
err := WhiteboardQuery.Execute(context.Background(), rt)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
var ve *errs.ValidationError
|
||||
if !errors.As(err, &ve) {
|
||||
t.Fatalf("error is not *errs.ValidationError: %T (%v)", err, err)
|
||||
}
|
||||
if ve.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Errorf("Subtype = %q, want %q", ve.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
if ve.Param != "--output_as" {
|
||||
t.Errorf("Param = %q, want %q", ve.Param, "--output_as")
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("errs.ProblemOf returned false")
|
||||
}
|
||||
if p.Category != errs.CategoryValidation {
|
||||
t.Errorf("Category = %q, want %q", p.Category, errs.CategoryValidation)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Errorf("Problem subtype = %q, want %q", p.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_ShortcutRegistration verifies the whiteboard query shortcut metadata.
|
||||
func TestWhiteboardQuery_ShortcutRegistration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -393,7 +325,6 @@ func TestWhiteboardQuery_ShortcutRegistration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveOutputFile verifies output saving, overwrite handling, and extension-specific paths.
|
||||
func TestSaveOutputFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -545,7 +476,6 @@ func TestSaveOutputFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveOutputFile_InvalidFinalPathTypedError verifies invalid save paths return typed validation errors.
|
||||
func TestSaveOutputFile_InvalidFinalPathTypedError(t *testing.T) {
|
||||
chdirTemp(t)
|
||||
|
||||
@@ -561,19 +491,6 @@ func TestSaveOutputFile_InvalidFinalPathTypedError(t *testing.T) {
|
||||
if ve.Subtype != errs.SubtypeInvalidArgument || ve.Param != "--output" {
|
||||
t.Fatalf("validation details = subtype %q param %q, want %q --output", ve.Subtype, ve.Param, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("errs.ProblemOf returned false")
|
||||
}
|
||||
if p.Category != errs.CategoryValidation {
|
||||
t.Errorf("Category = %q, want %q", p.Category, errs.CategoryValidation)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Errorf("Problem subtype = %q, want %q", p.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
if !errors.Is(err, fileio.ErrPathValidation) {
|
||||
t.Fatalf("expected path-validation cause to be preserved, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newExecuteFactory(t *testing.T) (*cmdutil.Factory, *bytes.Buffer, *httpmock.Registry) {
|
||||
@@ -608,7 +525,6 @@ func runShortcut(t *testing.T, shortcut common.Shortcut, args []string, factory
|
||||
return err
|
||||
}
|
||||
|
||||
// TestWhiteboardQueryExecute_AsRaw verifies raw query execution emits the raw node payload.
|
||||
func TestWhiteboardQueryExecute_AsRaw(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -637,7 +553,6 @@ func TestWhiteboardQueryExecute_AsRaw(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQueryExecute_AsCode verifies code query execution emits extracted diagram source.
|
||||
func TestWhiteboardQueryExecute_AsCode(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
chdirTemp(t)
|
||||
@@ -668,7 +583,6 @@ func TestWhiteboardQueryExecute_AsCode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardCode_EmptyNodes verifies code export handles empty whiteboards.
|
||||
func TestExportWhiteboardCode_EmptyNodes(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -691,7 +605,6 @@ func TestExportWhiteboardCode_EmptyNodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardCode_NoCodeBlocks verifies code export reports whiteboards without code blocks.
|
||||
func TestExportWhiteboardCode_NoCodeBlocks(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -716,7 +629,6 @@ func TestExportWhiteboardCode_NoCodeBlocks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardCode_InvalidSyntaxType verifies unknown syntax types are rejected.
|
||||
func TestExportWhiteboardCode_InvalidSyntaxType(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -746,7 +658,6 @@ func TestExportWhiteboardCode_InvalidSyntaxType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardCode_MultipleCodeBlocks verifies multiple code blocks are exported together.
|
||||
func TestExportWhiteboardCode_MultipleCodeBlocks(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -786,7 +697,6 @@ func TestExportWhiteboardCode_MultipleCodeBlocks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardCode_SingleBlock_PlantUML_DirectOutput verifies direct PlantUML output for a single code block.
|
||||
func TestExportWhiteboardCode_SingleBlock_PlantUML_DirectOutput(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -820,7 +730,6 @@ func TestExportWhiteboardCode_SingleBlock_PlantUML_DirectOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardCode_SingleBlock_Mermaid_DirectOutput verifies direct Mermaid output for a single code block.
|
||||
func TestExportWhiteboardCode_SingleBlock_Mermaid_DirectOutput(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -854,7 +763,6 @@ func TestExportWhiteboardCode_SingleBlock_Mermaid_DirectOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardPreview verifies preview downloads can be written to disk.
|
||||
func TestExportWhiteboardPreview(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -883,7 +791,6 @@ func TestExportWhiteboardPreview(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardRaw_EmptyNodes verifies raw export reports empty whiteboards.
|
||||
func TestExportWhiteboardRaw_EmptyNodes(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -906,7 +813,6 @@ func TestExportWhiteboardRaw_EmptyNodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchWhiteboardNodes_APIError verifies node fetch failures preserve typed API errors.
|
||||
func TestFetchWhiteboardNodes_APIError(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
@@ -936,7 +842,6 @@ func TestFetchWhiteboardNodes_APIError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchWhiteboardNodes_InvalidResponseTypedError verifies malformed node responses become typed invalid-response errors.
|
||||
func TestFetchWhiteboardNodes_InvalidResponseTypedError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -996,474 +901,6 @@ func TestFetchWhiteboardNodes_MissingNodesIsEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_DirectOutput verifies SVG export is printed when no output path is provided.
|
||||
func TestExportWhiteboardSvg_DirectOutput(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
svgContent := `<svg xmlns="http://www.w3.org/2000/svg"><rect width="100" height="100"/></svg>`
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"content": base64.StdEncoding.EncodeToString([]byte(svgContent)),
|
||||
"mime_type": "image/svg+xml",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg", "--output_as", "svg"}
|
||||
if err := runShortcut(t, WhiteboardQuery, args, factory, stdout); err != nil {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(stdout.String(), "svg_content") {
|
||||
t.Fatalf("stdout missing svg_content key: %s", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_SaveToFile verifies SVG export is written to the requested file.
|
||||
func TestExportWhiteboardSvg_SaveToFile(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
chdirTemp(t)
|
||||
|
||||
svgContent := `<svg xmlns="http://www.w3.org/2000/svg"><circle cx="50" cy="50" r="40"/></svg>`
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-file/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"content": base64.StdEncoding.EncodeToString([]byte(svgContent)),
|
||||
"mime_type": "image/svg+xml",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-file", "--output_as", "svg", "--output", "output", "--overwrite"}
|
||||
if err := runShortcut(t, WhiteboardQuery, args, factory, stdout); err != nil {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile("output.svg")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error: %v", err)
|
||||
}
|
||||
if string(data) != svgContent {
|
||||
t.Fatalf("svg content = %q, want %q", string(data), svgContent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_PrettyOutput verifies pretty output includes inline SVG content.
|
||||
func TestExportWhiteboardSvg_PrettyOutput(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
svgContent := `<svg xmlns="http://www.w3.org/2000/svg"><path d="M0 0L10 10"/></svg>`
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-pretty/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"content": base64.StdEncoding.EncodeToString([]byte(svgContent)),
|
||||
"mime_type": "image/svg+xml",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-pretty", "--output_as", "svg", "--format", "pretty"}
|
||||
if err := runShortcut(t, WhiteboardQuery, args, factory, stdout); err != nil {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
|
||||
if got := stdout.String(); !strings.Contains(got, svgContent) {
|
||||
t.Fatalf("stdout = %q, want svg content", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_SaveToFile_PrettyOutput verifies pretty output reports the saved SVG path and size.
|
||||
func TestExportWhiteboardSvg_SaveToFile_PrettyOutput(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
chdirTemp(t)
|
||||
|
||||
svgContent := `<svg xmlns="http://www.w3.org/2000/svg"><ellipse cx="60" cy="40" rx="50" ry="30"/></svg>`
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-file-pretty/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"content": base64.StdEncoding.EncodeToString([]byte(svgContent)),
|
||||
"mime_type": "image/svg+xml",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-file-pretty", "--output_as", "svg", "--output", "output", "--overwrite", "--format", "pretty"}
|
||||
if err := runShortcut(t, WhiteboardQuery, args, factory, stdout); err != nil {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
|
||||
if got := stdout.String(); !strings.Contains(got, "SVG saved to output.svg") || !strings.Contains(got, "File size:") {
|
||||
t.Fatalf("stdout = %q, want save summary", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_SaveToFile_ExistingFileWithoutOverwrite verifies existing SVG outputs require --overwrite.
|
||||
func TestExportWhiteboardSvg_SaveToFile_ExistingFileWithoutOverwrite(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
chdirTemp(t)
|
||||
|
||||
if err := os.WriteFile("output.svg", []byte("existing content"), 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
svgContent := `<svg xmlns="http://www.w3.org/2000/svg"><line x1="0" y1="0" x2="1" y2="1"/></svg>`
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-existing/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"content": base64.StdEncoding.EncodeToString([]byte(svgContent)),
|
||||
"mime_type": "image/svg+xml",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-existing", "--output_as", "svg", "--output", "output"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for existing output without overwrite")
|
||||
}
|
||||
|
||||
var ve *errs.ValidationError
|
||||
if !errors.As(err, &ve) {
|
||||
t.Fatalf("error is not *errs.ValidationError: %T (%v)", err, err)
|
||||
}
|
||||
if ve.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Errorf("Subtype = %q, want %q", ve.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
if ve.Param != "--overwrite" {
|
||||
t.Errorf("Param = %q, want %q", ve.Param, "--overwrite")
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_HTTP5xx verifies plain HTTP 5xx failures are classified as retryable network errors.
|
||||
func TestExportWhiteboardSvg_HTTP5xx(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-5xx/export",
|
||||
Status: 502,
|
||||
RawBody: []byte("bad gateway"),
|
||||
ContentType: "text/plain",
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-5xx", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 502")
|
||||
}
|
||||
var ne *errs.NetworkError
|
||||
if !errors.As(err, &ne) {
|
||||
t.Fatalf("error is not *errs.NetworkError: %T (%v)", err, err)
|
||||
}
|
||||
if ne.Subtype != errs.SubtypeNetworkServer {
|
||||
t.Errorf("Subtype = %q, want %q", ne.Subtype, errs.SubtypeNetworkServer)
|
||||
}
|
||||
if ne.Code != 502 {
|
||||
t.Errorf("Code = %d, want 502", ne.Code)
|
||||
}
|
||||
if !ne.Retryable {
|
||||
t.Error("expected Retryable = true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_HTTP5xxJSONEnvelopeReturnsAPIError verifies API envelopes take precedence over generic 5xx handling.
|
||||
func TestExportWhiteboardSvg_HTTP5xxJSONEnvelopeReturnsAPIError(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-5xx-json/export",
|
||||
Status: 502,
|
||||
ContentType: "application/json",
|
||||
RawBody: []byte(`{"code":99002,"msg":"export task failed"}`),
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-5xx-json", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 502 JSON envelope")
|
||||
}
|
||||
var apiErr *errs.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("error is not *errs.APIError: %T (%v)", err, err)
|
||||
}
|
||||
var ne *errs.NetworkError
|
||||
if errors.As(err, &ne) {
|
||||
t.Fatalf("expected JSON envelope to win over HTTP 5xx fallback, got *errs.NetworkError: %v", err)
|
||||
}
|
||||
if apiErr.Subtype != errs.SubtypeUnknown {
|
||||
t.Errorf("Subtype = %q, want %q", apiErr.Subtype, errs.SubtypeUnknown)
|
||||
}
|
||||
if apiErr.Code != 99002 {
|
||||
t.Errorf("Code = %d, want 99002", apiErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_HTTP4xx verifies plain HTTP 4xx failures are surfaced as API errors.
|
||||
func TestExportWhiteboardSvg_HTTP4xx(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-403/export",
|
||||
Status: 403,
|
||||
RawBody: []byte("forbidden"),
|
||||
ContentType: "text/plain",
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-403", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 403")
|
||||
}
|
||||
var apiErr *errs.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("error is not *errs.APIError: %T (%v)", err, err)
|
||||
}
|
||||
if apiErr.Subtype != errs.SubtypeUnknown {
|
||||
t.Errorf("Subtype = %q, want %q", apiErr.Subtype, errs.SubtypeUnknown)
|
||||
}
|
||||
if apiErr.Code != 403 {
|
||||
t.Errorf("Code = %d, want 403", apiErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_HTTPNotFoundJSONEnvelopeIsAPIError verifies not-found envelopes preserve the typed API error classification.
|
||||
func TestExportWhiteboardSvg_HTTPNotFoundJSONEnvelopeIsAPIError(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/missing-token-svg/export",
|
||||
Status: 404,
|
||||
ContentType: "application/json",
|
||||
RawBody: []byte(`{"code":99001,"msg":"whiteboard not found"}`),
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "missing-token-svg", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 404 JSON envelope")
|
||||
}
|
||||
var apiErr *errs.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("error is not *errs.APIError: %T (%v)", err, err)
|
||||
}
|
||||
if apiErr.Subtype != errs.SubtypeNotFound {
|
||||
t.Errorf("Subtype = %q, want %q", apiErr.Subtype, errs.SubtypeNotFound)
|
||||
}
|
||||
if apiErr.Code != 99001 {
|
||||
t.Errorf("Code = %d, want 99001", apiErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_HTTPNotFoundPlainText verifies plain-text 404 responses surface as not-found API errors.
|
||||
func TestExportWhiteboardSvg_HTTPNotFoundPlainText(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/missing-token-svg-plain/export",
|
||||
Status: 404,
|
||||
ContentType: "text/plain",
|
||||
RawBody: []byte("whiteboard not found"),
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "missing-token-svg-plain", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 404 plain text response")
|
||||
}
|
||||
|
||||
var apiErr *errs.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("error is not *errs.APIError: %T (%v)", err, err)
|
||||
}
|
||||
if apiErr.Subtype != errs.SubtypeNotFound {
|
||||
t.Errorf("Subtype = %q, want %q", apiErr.Subtype, errs.SubtypeNotFound)
|
||||
}
|
||||
if apiErr.Code != 404 {
|
||||
t.Errorf("Code = %d, want 404", apiErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_InvalidJSON verifies malformed success responses are rejected as invalid responses.
|
||||
func TestExportWhiteboardSvg_InvalidJSON(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-badjson/export",
|
||||
Status: 200,
|
||||
RawBody: []byte("not json at all"),
|
||||
ContentType: "application/json",
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-badjson", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
assertInvalidResponse(t, err)
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_InvalidBody200PlainText verifies plain-text 200 responses are rejected as invalid export responses.
|
||||
func TestExportWhiteboardSvg_InvalidBody200PlainText(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-plain-200/export",
|
||||
Status: 200,
|
||||
RawBody: []byte("not json at all"),
|
||||
ContentType: "text/plain",
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-plain-200", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for plain text success response")
|
||||
}
|
||||
assertInvalidResponse(t, err)
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_NonZeroCode verifies non-zero API codes are returned as typed API errors.
|
||||
func TestExportWhiteboardSvg_NonZeroCode(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-apierr/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 99001,
|
||||
"msg": "whiteboard not found",
|
||||
"data": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-apierr", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-zero code")
|
||||
}
|
||||
var apiErr *errs.APIError
|
||||
if !errors.As(err, &apiErr) {
|
||||
t.Fatalf("error is not *errs.APIError: %T (%v)", err, err)
|
||||
}
|
||||
if apiErr.Code != 99001 {
|
||||
t.Errorf("Code = %d, want 99001", apiErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExportWhiteboardSvg_InvalidBase64 verifies invalid SVG payload encoding is rejected.
|
||||
func TestExportWhiteboardSvg_InvalidBase64(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg-badbase64/export",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"content": "!!!not-valid-base64!!!",
|
||||
"mime_type": "image/svg+xml",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
args := []string{"+query", "--whiteboard-token", "test-token-svg-badbase64", "--output_as", "svg"}
|
||||
err := runShortcut(t, WhiteboardQuery, args, factory, stdout)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid base64")
|
||||
}
|
||||
assertInvalidResponse(t, err)
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_Validate_SvgValid verifies svg is accepted as a valid query output format.
|
||||
func TestWhiteboardQuery_Validate_SvgValid(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
chdirTemp(t)
|
||||
|
||||
rt := newTestRuntime(map[string]string{
|
||||
"whiteboard-token": "test-token-123",
|
||||
"output_as": "svg",
|
||||
}, nil)
|
||||
if err := WhiteboardQuery.Validate(ctx, rt); err != nil {
|
||||
t.Fatalf("expected svg to be valid, got err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardQuery_DryRun_Svg verifies the svg dry-run request uses the export endpoint and body.
|
||||
func TestWhiteboardQuery_DryRun_Svg(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
rt := newTestRuntime(map[string]string{
|
||||
"whiteboard-token": "test-token-123",
|
||||
"output_as": "svg",
|
||||
}, nil)
|
||||
dryRun := WhiteboardQuery.DryRun(ctx, rt)
|
||||
if dryRun == nil {
|
||||
t.Fatal("DryRun() returned nil for svg")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(dryRun)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var got struct {
|
||||
API []struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
Body map[string]interface{} `json:"body"`
|
||||
} `json:"api"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if len(got.API) != 1 {
|
||||
t.Fatalf("len(api) = %d, want 1", len(got.API))
|
||||
}
|
||||
if got.API[0].Method != "POST" {
|
||||
t.Fatalf("method = %q, want POST", got.API[0].Method)
|
||||
}
|
||||
if got.API[0].URL != "/open-apis/board/v1/whiteboards/test...-123/export" {
|
||||
t.Fatalf("url = %q", got.API[0].URL)
|
||||
}
|
||||
if got.API[0].Body["export_type"] != "svg" {
|
||||
t.Fatalf("body = %#v, want export_type=svg", got.API[0].Body)
|
||||
}
|
||||
if _, ok := got.API[0].Params["export_type"]; ok {
|
||||
t.Fatalf("params should not include export_type, got %#v", got.API[0].Params)
|
||||
}
|
||||
}
|
||||
|
||||
// assertInvalidResponse verifies an error is classified as a typed invalid-response failure.
|
||||
func assertInvalidResponse(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
|
||||
@@ -17,21 +17,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// FormatRaw sends raw whiteboard node JSON to the create-nodes API.
|
||||
FormatRaw = "raw"
|
||||
// FormatPlantUML sends PlantUML source through the diagram import API.
|
||||
FormatRaw = "raw"
|
||||
FormatPlantUML = "plantuml"
|
||||
// FormatMermaid sends Mermaid source through the diagram import API.
|
||||
FormatMermaid = "mermaid"
|
||||
// FormatSVG sends SVG source through the diagram import API.
|
||||
FormatSVG = "svg"
|
||||
FormatMermaid = "mermaid"
|
||||
)
|
||||
|
||||
var formatCodeMap = map[string]int{
|
||||
FormatRaw: 0,
|
||||
FormatPlantUML: 1,
|
||||
FormatMermaid: 2,
|
||||
FormatSVG: 3,
|
||||
}
|
||||
|
||||
var wbUpdateScopes = []string{"board:whiteboard:node:create"}
|
||||
@@ -41,14 +35,9 @@ var wbUpdateFlags = []common.Flag{
|
||||
{Name: "whiteboard-token", Desc: "whiteboard token of the whiteboard to update. You will need edit permission to update the whiteboard.", Required: true},
|
||||
{Name: "overwrite", Desc: "overwrite the whiteboard content, delete all existing content before update. Default is false.", Required: false, Type: "bool"},
|
||||
{Name: "source", Desc: "Input whiteboard data.", Required: true, Input: []string{common.Stdin, common.File}},
|
||||
{Name: "input_format", Desc: "format of input data: raw | plantuml | mermaid | svg. Default is raw.", Required: false},
|
||||
{Name: "input_format", Desc: "format of input data: raw | plantuml | mermaid. Default is raw.", Required: false},
|
||||
}
|
||||
|
||||
// wbUpdateValidate validates the whiteboard update command arguments.
|
||||
//
|
||||
// It checks the whiteboard token and idempotent token for dangerous control
|
||||
// characters, enforces a minimum length for a non-empty idempotent token, and
|
||||
// ensures the input format is one of raw, plantuml, mermaid, or svg.
|
||||
func wbUpdateValidate(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
// 检查 token 是否包含控制字符(空字符串下自动跳过了)
|
||||
if err := common.RejectDangerousCharsTyped("--whiteboard-token", runtime.Str("whiteboard-token")); err != nil {
|
||||
@@ -64,8 +53,8 @@ func wbUpdateValidate(ctx context.Context, runtime *common.RuntimeContext) error
|
||||
|
||||
// 检查 --input_format 标志
|
||||
format := getFormat(runtime)
|
||||
if format != FormatRaw && format != FormatPlantUML && format != FormatMermaid && format != FormatSVG {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--input_format must be one of: raw | plantuml | mermaid | svg").WithParam("--input_format")
|
||||
if format != FormatRaw && format != FormatPlantUML && format != FormatMermaid {
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--input_format must be one of: raw | plantuml | mermaid").WithParam("--input_format")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -79,8 +68,6 @@ func getFormat(runtime *common.RuntimeContext) string {
|
||||
return format
|
||||
}
|
||||
|
||||
// wbUpdateDryRun describes the HTTP request used to update a whiteboard.
|
||||
// It returns a failure description when source is missing or cannot be parsed.
|
||||
func wbUpdateDryRun(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
// 读取输入内容
|
||||
input := runtime.Str("source")
|
||||
@@ -104,7 +91,7 @@ func wbUpdateDryRun(ctx context.Context, runtime *common.RuntimeContext) *common
|
||||
Overwrite: overwrite,
|
||||
}
|
||||
desc.POST(fmt.Sprintf("/open-apis/board/v1/whiteboards/%s/nodes", common.MaskToken(url.PathEscape(token)))).Body(reqBody).Desc("create all nodes of the whiteboard.")
|
||||
case FormatPlantUML, FormatMermaid, FormatSVG:
|
||||
case FormatPlantUML, FormatMermaid:
|
||||
syntaxType := formatCodeMap[format]
|
||||
reqBody := plantumlCreateReq{
|
||||
PlantUmlCode: input,
|
||||
@@ -119,10 +106,6 @@ func wbUpdateDryRun(ctx context.Context, runtime *common.RuntimeContext) *common
|
||||
return desc
|
||||
}
|
||||
|
||||
// wbUpdateExecute updates a whiteboard from the supplied source input.
|
||||
// It requires --source and dispatches to the raw node update path for raw input
|
||||
// or the diagram import path for PlantUML, Mermaid, and SVG input.
|
||||
// It returns an error if the source is missing or the input format is unsupported.
|
||||
func wbUpdateExecute(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
token := runtime.Str("whiteboard-token")
|
||||
overwrite := runtime.Bool("overwrite")
|
||||
@@ -137,17 +120,15 @@ func wbUpdateExecute(ctx context.Context, runtime *common.RuntimeContext) error
|
||||
switch format {
|
||||
case FormatRaw:
|
||||
return updateWhiteboardByRawNodes(ctx, runtime, token, []byte(input), overwrite, idempotentToken)
|
||||
case FormatPlantUML, FormatMermaid, FormatSVG:
|
||||
case FormatPlantUML, FormatMermaid:
|
||||
return updateWhiteboardByCode(ctx, runtime, token, []byte(input), format, overwrite, idempotentToken)
|
||||
default:
|
||||
return errs.NewValidationError(errs.SubtypeInvalidArgument, "unsupported format: %s", format).WithParam("--input_format")
|
||||
}
|
||||
}
|
||||
|
||||
// WhiteboardUpdateDescription describes the whiteboard update shortcut.
|
||||
const WhiteboardUpdateDescription = "Update an existing whiteboard in lark document with mermaid, plantuml or whiteboard dsl. refer to lark-whiteboard skill for more details."
|
||||
|
||||
// WhiteboardUpdate registers the `whiteboard +update` shortcut.
|
||||
var WhiteboardUpdate = common.Shortcut{
|
||||
Service: "whiteboard",
|
||||
Command: "+update",
|
||||
|
||||
@@ -6,7 +6,6 @@ package whiteboard
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TestWhiteboardUpdate_Validate verifies update flag validation for supported input formats.
|
||||
func TestWhiteboardUpdate_Validate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -55,15 +53,6 @@ func TestWhiteboardUpdate_Validate(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid: svg format",
|
||||
flags: map[string]string{
|
||||
"whiteboard-token": "test-token-123",
|
||||
"input_format": "svg",
|
||||
"source": "<svg/>",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid: with idempotent-token",
|
||||
flags: map[string]string{
|
||||
@@ -128,26 +117,25 @@ func TestWhiteboardUpdate_Validate_TypedErrors(t *testing.T) {
|
||||
"idempotent-token": "short",
|
||||
"source": "{}",
|
||||
}, nil)
|
||||
assertValidationParam(t, wbUpdateValidate(ctx, rt), "--idempotent-token", false)
|
||||
assertValidationParam(t, wbUpdateValidate(ctx, rt), "--idempotent-token")
|
||||
})
|
||||
|
||||
t.Run("bad input_format", func(t *testing.T) {
|
||||
rt := newTestRuntime(map[string]string{
|
||||
"whiteboard-token": "t",
|
||||
"input_format": "png",
|
||||
"input_format": "svg",
|
||||
"source": "{}",
|
||||
}, nil)
|
||||
assertValidationParam(t, wbUpdateValidate(ctx, rt), "--input_format", false)
|
||||
assertValidationParam(t, wbUpdateValidate(ctx, rt), "--input_format")
|
||||
})
|
||||
|
||||
t.Run("malformed source json", func(t *testing.T) {
|
||||
_, err, _ := parseWBcliNodes([]byte("not-json"))
|
||||
assertValidationParam(t, err, "--source", true)
|
||||
assertValidationParam(t, err, "--source")
|
||||
})
|
||||
}
|
||||
|
||||
// assertValidationParam verifies a validation error carries the expected flag param.
|
||||
func assertValidationParam(t *testing.T, err error, wantParam string, wantJSONCause bool) {
|
||||
func assertValidationParam(t *testing.T, err error, wantParam string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
@@ -162,25 +150,8 @@ func assertValidationParam(t *testing.T, err error, wantParam string, wantJSONCa
|
||||
if ve.Param != wantParam {
|
||||
t.Errorf("Param = %q, want %q", ve.Param, wantParam)
|
||||
}
|
||||
p, ok := errs.ProblemOf(err)
|
||||
if !ok {
|
||||
t.Fatalf("errs.ProblemOf returned false")
|
||||
}
|
||||
if p.Category != errs.CategoryValidation {
|
||||
t.Errorf("Category = %q, want %q", p.Category, errs.CategoryValidation)
|
||||
}
|
||||
if p.Subtype != errs.SubtypeInvalidArgument {
|
||||
t.Errorf("Problem subtype = %q, want %q", p.Subtype, errs.SubtypeInvalidArgument)
|
||||
}
|
||||
if wantJSONCause {
|
||||
var syntaxErr *json.SyntaxError
|
||||
if !errors.As(err, &syntaxErr) {
|
||||
t.Fatalf("expected json syntax cause to be preserved, err=%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetFormat verifies input format defaults and explicit format selection.
|
||||
func TestGetFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -209,11 +180,6 @@ func TestGetFormat(t *testing.T) {
|
||||
flagVal: FormatMermaid,
|
||||
expected: FormatMermaid,
|
||||
},
|
||||
{
|
||||
name: "svg returns svg",
|
||||
flagVal: FormatSVG,
|
||||
expected: FormatSVG,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -227,7 +193,6 @@ func TestGetFormat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdate_ShortcutRegistration verifies the shortcut metadata for update commands.
|
||||
func TestWhiteboardUpdate_ShortcutRegistration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -248,7 +213,6 @@ func TestWhiteboardUpdate_ShortcutRegistration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShortcutsIncludesExpectedCommands verifies the whiteboard shortcut registry includes query and update.
|
||||
func TestShortcutsIncludesExpectedCommands(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -273,7 +237,6 @@ func TestShortcutsIncludesExpectedCommands(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseWBcliNodes verifies whiteboard CLI output parsing for raw and wrapped node payloads.
|
||||
func TestParseWBcliNodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -322,7 +285,6 @@ func TestParseWBcliNodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWBUpdateDryRun verifies dry-run requests for the supported whiteboard update formats.
|
||||
func TestWBUpdateDryRun(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -355,14 +317,6 @@ func TestWBUpdateDryRun(t *testing.T) {
|
||||
"source": "graph TD\nA-->B",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dry run svg format",
|
||||
flags: map[string]string{
|
||||
"whiteboard-token": "test-token-123",
|
||||
"input_format": "svg",
|
||||
"source": "<svg/>",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -408,7 +362,6 @@ func runUpdateShortcut(t *testing.T, shortcut common.Shortcut, args []string, fa
|
||||
return err
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_RawFormat verifies raw node updates call the raw nodes endpoint.
|
||||
func TestWhiteboardUpdateExecute_RawFormat(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -432,7 +385,6 @@ func TestWhiteboardUpdateExecute_RawFormat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_PlantUMLFormat verifies PlantUML updates use the diagram import endpoint.
|
||||
func TestWhiteboardUpdateExecute_PlantUMLFormat(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -458,7 +410,6 @@ Bob -> Alice : hello
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_PlantUMLInvalidResponse verifies missing node IDs are treated as invalid responses.
|
||||
func TestWhiteboardUpdateExecute_PlantUMLInvalidResponse(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -480,7 +431,6 @@ Bob -> Alice : hello
|
||||
assertInvalidResponse(t, err)
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_MermaidFormat verifies Mermaid updates use the diagram import endpoint.
|
||||
func TestWhiteboardUpdateExecute_MermaidFormat(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -505,44 +455,6 @@ A-->B`
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_SVGFormat verifies svg update requests use syntax_type=3 and send the source payload.
|
||||
func TestWhiteboardUpdateExecute_SVGFormat(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
// SVG shares the /nodes/plantuml endpoint with plantuml/mermaid via syntax_type=3.
|
||||
stub := &httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/board/v1/whiteboards/test-token-svg/nodes/plantuml",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": map[string]interface{}{
|
||||
"node_id": "node1",
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(stub)
|
||||
|
||||
source := `<svg xmlns="http://www.w3.org/2000/svg"/>`
|
||||
args := []string{"+update", "--whiteboard-token", "test-token-svg", "--input_format", "svg", "--source", source}
|
||||
if err := runUpdateShortcut(t, WhiteboardUpdate, args, factory, stdout); err != nil {
|
||||
t.Fatalf("err=%v", err)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(stub.CapturedBody, &body); err != nil {
|
||||
t.Fatalf("unmarshal captured body: %v\nraw=%s", err, string(stub.CapturedBody))
|
||||
}
|
||||
|
||||
if got := body["syntax_type"]; got != float64(3) {
|
||||
t.Fatalf("syntax_type = %#v, want 3; body=%s", got, string(stub.CapturedBody))
|
||||
}
|
||||
if got := body["plant_uml_code"]; got != source {
|
||||
t.Fatalf("plant_uml_code = %#v, want %q; body=%s", got, source, string(stub.CapturedBody))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_RawInvalidResponse verifies malformed raw update responses are rejected.
|
||||
func TestWhiteboardUpdateExecute_RawInvalidResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -582,7 +494,6 @@ func TestWhiteboardUpdateExecute_RawInvalidResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_RawWithIdempotent verifies raw updates pass through the idempotency token.
|
||||
func TestWhiteboardUpdateExecute_RawWithIdempotent(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -607,7 +518,6 @@ func TestWhiteboardUpdateExecute_RawWithIdempotent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_RawFormatWithRawNodes verifies raw-node payloads are forwarded without DSL wrapping.
|
||||
func TestWhiteboardUpdateExecute_RawFormatWithRawNodes(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -631,7 +541,6 @@ func TestWhiteboardUpdateExecute_RawFormatWithRawNodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_RawAPIError verifies raw update API failures preserve typed error metadata and hints.
|
||||
func TestWhiteboardUpdateExecute_RawAPIError(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -668,7 +577,6 @@ func TestWhiteboardUpdateExecute_RawAPIError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_PlantUMLAPIError verifies diagram update API failures preserve typed error metadata.
|
||||
func TestWhiteboardUpdateExecute_PlantUMLAPIError(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -699,7 +607,6 @@ invalid
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_WithOverwrite verifies diagram updates send overwrite=true when requested.
|
||||
func TestWhiteboardUpdateExecute_WithOverwrite(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
@@ -724,7 +631,6 @@ A-->B`
|
||||
}
|
||||
}
|
||||
|
||||
// TestWhiteboardUpdateExecute_RawWithOverwrite verifies raw updates send overwrite=true when requested.
|
||||
func TestWhiteboardUpdateExecute_RawWithOverwrite(t *testing.T) {
|
||||
factory, stdout, reg := newUpdateExecuteFactory(t)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user