mirror of
https://github.com/larksuite/cli.git
synced 2026-07-03 22:24:31 +08:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
776ee686ff | ||
|
|
4da6d610e2 | ||
|
|
3f4352d50c | ||
|
|
543a8365d6 | ||
|
|
0192cee859 | ||
|
|
18e227f281 | ||
|
|
7e9beec422 | ||
|
|
462d38e8f7 | ||
|
|
e4d263948c | ||
|
|
11191df703 | ||
|
|
e23b3a8dc6 | ||
|
|
f3699298aa | ||
|
|
018eeb6414 | ||
|
|
3e5dc3262f | ||
|
|
c13644a247 | ||
|
|
cb301a3d1a | ||
|
|
04e3a28529 | ||
|
|
e02c442aea | ||
|
|
fbed6beac3 | ||
|
|
e15aef922e | ||
|
|
ccc27ce417 | ||
|
|
24e0bb38eb | ||
|
|
9057299430 | ||
|
|
9e891b758e | ||
|
|
293a9f896f | ||
|
|
0a0cdc8879 | ||
|
|
67e51ec8d7 | ||
|
|
5943a20e2b | ||
|
|
cd666422ac | ||
|
|
9acd121259 | ||
|
|
1262aac480 | ||
|
|
abb02cd46c | ||
|
|
db7d3cb64d | ||
|
|
5134719da9 | ||
|
|
5a0e1d3dd9 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -36,3 +36,5 @@ tests/mail/reports/
|
||||
internal/registry/meta_data.json
|
||||
cmd/api/download.bin
|
||||
app.log
|
||||
/sidecar-server-demo
|
||||
/server-demo
|
||||
|
||||
63
CHANGELOG.md
63
CHANGELOG.md
@@ -2,6 +2,66 @@
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
## [v1.0.17] - 2026-04-22
|
||||
|
||||
### Features
|
||||
|
||||
- **im**: Use `Content-Disposition` filename when downloading message resources (#536)
|
||||
- **drive**: Add `+apply-permission` to request doc access (#588)
|
||||
- Support record share link (#466)
|
||||
- **whiteboard**: Add image support to `whiteboard-cli` skill (#553)
|
||||
- **cmdutil**: Add `X-Cli-Build` header for CLI build classification (#596)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- **base**: Add default-table follow-up hint to `base-create` (#600)
|
||||
- Skip flag-completion registration outside completion path (#598)
|
||||
- Add `record-share-link-create` in `SKILL.md` (#597)
|
||||
- **mail**: Remove leftover conflict marker in skill docs (#594)
|
||||
|
||||
### Documentation
|
||||
|
||||
- **drive**: Clarify that comment listing defaults to unresolved comments only (#609)
|
||||
- **doc**: Fix `--markdown` examples that teach literal `\n` (#602)
|
||||
- **mail**: Remove `get_signatures` from skill reference, exposed via `+signature` instead (#545)
|
||||
|
||||
## [v1.0.16] - 2026-04-21
|
||||
|
||||
### Features
|
||||
|
||||
- **mail**: Support large email attachments (#537)
|
||||
- **mail**: Add draft preview URL to draft operations (#438)
|
||||
- **doc**: Add pre-write semantic warnings to `docs +update` (#569)
|
||||
- **doc**: Add `--selection-with-ellipsis` position flag to `+media-insert` (#335)
|
||||
- **calendar**: Support event share link and error details (#583)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- **doc**: Preserve round-trip formatting in `+fetch` output (#469)
|
||||
- **docs**: Validate `--selection-by-title` format early (#256)
|
||||
- **whiteboard**: Register `+media-upload` shortcut and add whiteboard parent type
|
||||
|
||||
### Refactor
|
||||
|
||||
- Split `Execute` into `Build` + `Execute` with explicit IO and keychain injection (#371)
|
||||
- **auth**: Simplify scope reporting in login flow (#582)
|
||||
|
||||
## [v1.0.15] - 2026-04-20
|
||||
|
||||
### Features
|
||||
|
||||
- **sheets**: Add float image shortcuts (#494)
|
||||
- **approval**: Document `remind` and `initiated` methods in skill (#554)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- **base**: Preserve attachment metadata on base uploads (#563)
|
||||
- **base**: Fix role view and record default permission on edit (#530)
|
||||
- **sheets**: Normalize single-cell range in `+set-style` and `+batch-set-style` (#548)
|
||||
- **im**: Cap `basic_batch` user_ids at 10 per API limit (#551)
|
||||
- **install**: Refine install wizard messages (#529)
|
||||
- **whiteboard**: Deprecate old `lark-whiteboard-cli` skill (#547)
|
||||
|
||||
## [v1.0.14] - 2026-04-17
|
||||
|
||||
### Features
|
||||
@@ -404,6 +464,9 @@ Bundled AI agent skills for intelligent assistance:
|
||||
- Bilingual documentation (English & Chinese).
|
||||
- CI/CD pipelines: linting, testing, coverage reporting, and automated releases.
|
||||
|
||||
[v1.0.17]: https://github.com/larksuite/cli/releases/tag/v1.0.17
|
||||
[v1.0.16]: https://github.com/larksuite/cli/releases/tag/v1.0.16
|
||||
[v1.0.15]: https://github.com/larksuite/cli/releases/tag/v1.0.15
|
||||
[v1.0.14]: https://github.com/larksuite/cli/releases/tag/v1.0.14
|
||||
[v1.0.13]: https://github.com/larksuite/cli/releases/tag/v1.0.13
|
||||
[v1.0.12]: https://github.com/larksuite/cli/releases/tag/v1.0.12
|
||||
|
||||
@@ -57,6 +57,10 @@ func normalisePath(raw string) string {
|
||||
|
||||
// NewCmdApi creates the api command. If runF is non-nil it is called instead of apiRun (test hook).
|
||||
func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command {
|
||||
return NewCmdApiWithContext(context.Background(), f, runF)
|
||||
}
|
||||
|
||||
func NewCmdApiWithContext(ctx context.Context, f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command {
|
||||
opts := &APIOptions{Factory: f}
|
||||
var asStr string
|
||||
|
||||
@@ -79,7 +83,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command
|
||||
|
||||
cmd.Flags().StringVar(&opts.Params, "params", "", "query parameters JSON (supports - for stdin)")
|
||||
cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)")
|
||||
cmd.Flags().StringVar(&asStr, "as", "auto", "identity type: user | bot | auto (default)")
|
||||
cmdutil.AddAPIIdentityFlag(ctx, cmd, f, &asStr)
|
||||
cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses")
|
||||
cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages")
|
||||
cmd.Flags().IntVar(&opts.PageSize, "page-size", 0, "page size (0 = use API default)")
|
||||
@@ -96,10 +100,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command
|
||||
}
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
_ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return []string{"user", "bot"}, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
_ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
cmdutil.RegisterFlagCompletion(cmd, "format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
|
||||
|
||||
@@ -180,6 +180,24 @@ func TestApiValidArgsFunction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCmdApi_StrictModeHidesAsFlag(t *testing.T) {
|
||||
f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{
|
||||
AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2,
|
||||
})
|
||||
|
||||
cmd := NewCmdApi(f, nil)
|
||||
flag := cmd.Flags().Lookup("as")
|
||||
if flag == nil {
|
||||
t.Fatal("expected --as flag to be registered")
|
||||
}
|
||||
if !flag.Hidden {
|
||||
t.Fatal("expected --as flag to be hidden in strict mode")
|
||||
}
|
||||
if got := flag.DefValue; got != "bot" {
|
||||
t.Fatalf("default value = %q, want %q", got, "bot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiCmd_PageLimitDefault(t *testing.T) {
|
||||
f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{
|
||||
AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu,
|
||||
|
||||
@@ -72,7 +72,7 @@ browser. Run it in the background and retrieve the verification URL from its out
|
||||
cmd.Flags().BoolVar(&opts.NoWait, "no-wait", false, "initiate device authorization and return immediately; use --device-code to complete")
|
||||
cmd.Flags().StringVar(&opts.DeviceCode, "device-code", "", "poll and complete authorization with a device code from a previous --no-wait call")
|
||||
|
||||
_ = cmd.RegisterFlagCompletionFunc("domain", func(_ *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
cmdutil.RegisterFlagCompletion(cmd, "domain", func(_ *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
return completeDomain(toComplete), cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ type loginMsg struct {
|
||||
ScopeHint string
|
||||
RequestedScopes string
|
||||
NewlyGrantedScopes string
|
||||
MissingScopes string
|
||||
NoScopes string
|
||||
StatusHint string
|
||||
|
||||
@@ -59,14 +58,13 @@ var loginMsgZh = &loginMsg{
|
||||
|
||||
OpenURL: "在浏览器中打开以下链接进行认证:\n\n",
|
||||
WaitingAuth: "等待用户授权...",
|
||||
AuthSuccess: "授权已完成,正在获取用户信息并校验授权结果...",
|
||||
AuthSuccess: "已收到授权确认,正在获取用户信息并校验授权结果...",
|
||||
LoginSuccess: "授权成功! 用户: %s (%s)",
|
||||
AuthorizedUser: "当前授权账号: %s (%s)",
|
||||
ScopeMismatch: "授权结果异常:以下请求 scopes 未被授予: %s",
|
||||
ScopeMismatch: "授权结果异常: 以下请求 scopes 未被授予: %s",
|
||||
ScopeHint: "以上结果是本次授权请求用户最终确认后的结果,请勿持续重试;Scopes 未授予的原因是多样的,如 scope 被禁用;具体原因已通过授权页提示用户。可执行 `lark-cli auth status` 查看账号当前已授予的全部 scopes;",
|
||||
RequestedScopes: " 本次请求 scopes: %s\n",
|
||||
NewlyGrantedScopes: " 本次新授予 scopes: %s\n",
|
||||
MissingScopes: " 本次未授予 scopes: %s\n",
|
||||
NoScopes: "(空)",
|
||||
StatusHint: "可执行 `lark-cli auth status` 查看账号当前已授予的全部 scopes;",
|
||||
|
||||
@@ -95,14 +93,13 @@ var loginMsgEn = &loginMsg{
|
||||
|
||||
OpenURL: "Open this URL in your browser to authenticate:\n\n",
|
||||
WaitingAuth: "Waiting for user authorization...",
|
||||
AuthSuccess: "Authorization completed, fetching user info and validating granted scopes...",
|
||||
AuthSuccess: "Authorization confirmed, fetching user info and validating granted scopes...",
|
||||
LoginSuccess: "Authorization successful! User: %s (%s)",
|
||||
AuthorizedUser: "Authorized account: %s (%s)",
|
||||
ScopeMismatch: "authorization result is abnormal: these requested scopes were not granted: %s",
|
||||
ScopeHint: "The result above is the user's final confirmation for this authorization request. Do not retry continuously. Scopes may be not granted for various reasons, such as a scope being disabled. The specific reason has already been shown to the user on the authorization page. Run `lark-cli auth status` to inspect all scopes currently granted to the account.",
|
||||
RequestedScopes: " Requested scopes: %s\n",
|
||||
NewlyGrantedScopes: " Newly granted scopes: %s\n",
|
||||
MissingScopes: " Not granted scopes: %s\n",
|
||||
NoScopes: "(none)",
|
||||
StatusHint: "Run `lark-cli auth status` to inspect all scopes currently granted to the account.",
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ func emptyIfNil(s []string) []string {
|
||||
return s
|
||||
}
|
||||
|
||||
// writeLoginScopeBreakdown renders the requested/newly granted/missing scope
|
||||
// writeLoginScopeBreakdown renders the requested/newly granted scope
|
||||
// breakdown to stderr.
|
||||
func writeLoginScopeBreakdown(errOut *cmdutil.IOStreams, msg *loginMsg, summary *loginScopeSummary) {
|
||||
if summary == nil {
|
||||
@@ -136,7 +136,6 @@ func writeLoginScopeBreakdown(errOut *cmdutil.IOStreams, msg *loginMsg, summary
|
||||
}
|
||||
fmt.Fprintf(errOut.ErrOut, msg.RequestedScopes, formatScopeList(summary.Requested, msg.NoScopes))
|
||||
fmt.Fprintf(errOut.ErrOut, msg.NewlyGrantedScopes, formatScopeList(summary.NewlyGranted, msg.NoScopes))
|
||||
fmt.Fprintf(errOut.ErrOut, msg.MissingScopes, formatScopeList(summary.Missing, msg.NoScopes))
|
||||
}
|
||||
|
||||
// writeLoginSuccess emits the successful login payload in either JSON or text
|
||||
|
||||
@@ -363,7 +363,7 @@ func TestWriteLoginSuccess_JSONIncludesScopeDiff(t *testing.T) {
|
||||
func TestHandleLoginScopeIssue_NonJSONAlignsWithLoginSuccess(t *testing.T) {
|
||||
f, _, stderr, _ := cmdutil.TestFactory(t, nil)
|
||||
err := handleLoginScopeIssue(&LoginOptions{}, getLoginMsg("zh"), f, &loginScopeIssue{
|
||||
Message: "授权结果异常:以下请求 scopes 未被授予: im:message:send",
|
||||
Message: "授权结果异常: 以下请求 scopes 未被授予: im:message:send",
|
||||
Hint: "以上结果是本次授权请求用户最终确认后的结果,请勿持续重试;Scopes 未授予的原因是多样的,如 scope 被禁用;具体原因已通过授权页提示用户。可执行 `lark-cli auth status` 查看账号当前已授予的全部 scopes;",
|
||||
Summary: &loginScopeSummary{
|
||||
Requested: []string{"im:message:send"},
|
||||
@@ -376,11 +376,10 @@ func TestHandleLoginScopeIssue_NonJSONAlignsWithLoginSuccess(t *testing.T) {
|
||||
}
|
||||
got := stderr.String()
|
||||
for _, want := range []string{
|
||||
"授权结果异常:以下请求 scopes 未被授予: im:message:send",
|
||||
"授权结果异常: 以下请求 scopes 未被授予: im:message:send",
|
||||
"当前授权账号: tester (ou_user)",
|
||||
"本次请求 scopes: im:message:send",
|
||||
"本次新授予 scopes: (空)",
|
||||
"本次未授予 scopes: im:message:send",
|
||||
"以上结果是本次授权请求用户最终确认后的结果,请勿持续重试",
|
||||
"scope 被禁用",
|
||||
"lark-cli auth status",
|
||||
@@ -395,6 +394,9 @@ func TestHandleLoginScopeIssue_NonJSONAlignsWithLoginSuccess(t *testing.T) {
|
||||
if strings.Contains(got, "授权成功") {
|
||||
t.Fatalf("stderr should not contain success wording, got:\n%s", got)
|
||||
}
|
||||
if strings.Contains(got, "本次未授予 scopes:") {
|
||||
t.Fatalf("stderr should not duplicate missing scopes, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleLoginScopeIssue_JSONAlignsWithLoginSuccess(t *testing.T) {
|
||||
@@ -472,10 +474,10 @@ func TestWriteLoginSuccess_TextOutputScenarios(t *testing.T) {
|
||||
"授权成功! 用户: tester (ou_user)",
|
||||
"本次请求 scopes: im:message:send im:message:reply",
|
||||
"本次新授予 scopes: im:message:send",
|
||||
"本次未授予 scopes: (空)",
|
||||
"可执行 `lark-cli auth status` 查看账号当前已授予的全部 scopes;",
|
||||
},
|
||||
expectedAbsent: []string{
|
||||
"本次未授予 scopes:",
|
||||
"最终已授权 scopes:",
|
||||
"已有 scopes:",
|
||||
},
|
||||
@@ -490,10 +492,10 @@ func TestWriteLoginSuccess_TextOutputScenarios(t *testing.T) {
|
||||
expectedPresent: []string{
|
||||
"本次请求 scopes: im:message:send",
|
||||
"本次新授予 scopes: (空)",
|
||||
"本次未授予 scopes: (空)",
|
||||
"可执行 `lark-cli auth status` 查看账号当前已授予的全部 scopes;",
|
||||
},
|
||||
expectedAbsent: []string{
|
||||
"本次未授予 scopes:",
|
||||
"最终已授权 scopes:",
|
||||
"已有 scopes:",
|
||||
},
|
||||
@@ -508,9 +510,9 @@ func TestWriteLoginSuccess_TextOutputScenarios(t *testing.T) {
|
||||
expectedPresent: []string{
|
||||
"本次请求 scopes: im:message:send im:message:reply",
|
||||
"本次新授予 scopes: (空)",
|
||||
"本次未授予 scopes: im:message:send",
|
||||
},
|
||||
expectedAbsent: []string{
|
||||
"本次未授予 scopes:",
|
||||
"已有 scopes:",
|
||||
"最终已授权 scopes:",
|
||||
"可执行 `lark-cli auth status` 查看账号当前已授予的全部 scopes;",
|
||||
@@ -619,10 +621,9 @@ func TestAuthLoginRun_MissingRequestedScopeAlignsWithLoginSuccess(t *testing.T)
|
||||
}
|
||||
got := stderr.String()
|
||||
for _, want := range []string{
|
||||
"授权结果异常:以下请求 scopes 未被授予: im:message:send",
|
||||
"授权结果异常: 以下请求 scopes 未被授予: im:message:send",
|
||||
"当前授权账号: tester (ou_user)",
|
||||
"本次请求 scopes: im:message:send",
|
||||
"本次未授予 scopes: im:message:send",
|
||||
"以上结果是本次授权请求用户最终确认后的结果,请勿持续重试",
|
||||
"scope 被禁用",
|
||||
"lark-cli auth status",
|
||||
@@ -637,6 +638,9 @@ func TestAuthLoginRun_MissingRequestedScopeAlignsWithLoginSuccess(t *testing.T)
|
||||
if strings.Contains(got, "OK: 授权成功") {
|
||||
t.Fatalf("stderr should not contain success prefix when scopes are missing, got:\n%s", got)
|
||||
}
|
||||
if strings.Contains(got, "本次未授予 scopes:") {
|
||||
t.Fatalf("stderr should not duplicate missing scopes, got:\n%s", got)
|
||||
}
|
||||
if strings.Contains(got, "ERROR:") {
|
||||
t.Fatalf("stderr should not contain error prefix, got:\n%s", got)
|
||||
}
|
||||
@@ -777,13 +781,15 @@ func TestWriteLoginSuccess_TextOutputEnglishIncludesStatusHintWhenNoMissingScope
|
||||
"Authorization successful! User: tester (ou_user)",
|
||||
"Requested scopes: im:message:send",
|
||||
"Newly granted scopes: im:message:send",
|
||||
"Not granted scopes: (none)",
|
||||
"Run `lark-cli auth status` to inspect all scopes currently granted to the account.",
|
||||
} {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Fatalf("stderr missing %q, got:\n%s", want, got)
|
||||
}
|
||||
}
|
||||
if strings.Contains(got, "Not granted scopes:") {
|
||||
t.Fatalf("stderr should not contain not granted scopes, got:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthLoginRun_DeviceCodeTokenNilCleansScopeCache(t *testing.T) {
|
||||
|
||||
129
cmd/build.go
Normal file
129
cmd/build.go
Normal file
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/larksuite/cli/cmd/api"
|
||||
"github.com/larksuite/cli/cmd/auth"
|
||||
"github.com/larksuite/cli/cmd/completion"
|
||||
cmdconfig "github.com/larksuite/cli/cmd/config"
|
||||
"github.com/larksuite/cli/cmd/doctor"
|
||||
"github.com/larksuite/cli/cmd/profile"
|
||||
"github.com/larksuite/cli/cmd/schema"
|
||||
"github.com/larksuite/cli/cmd/service"
|
||||
cmdupdate "github.com/larksuite/cli/cmd/update"
|
||||
"github.com/larksuite/cli/internal/build"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/keychain"
|
||||
"github.com/larksuite/cli/shortcuts"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// BuildOption configures optional aspects of the command tree construction.
|
||||
type BuildOption func(*buildConfig)
|
||||
|
||||
type buildConfig struct {
|
||||
streams *cmdutil.IOStreams
|
||||
keychain keychain.KeychainAccess
|
||||
globals GlobalOptions
|
||||
}
|
||||
|
||||
// WithIO sets the IO streams for the CLI by wrapping raw reader/writers.
|
||||
// Terminal detection is delegated to cmdutil.NewIOStreams.
|
||||
func WithIO(in io.Reader, out, errOut io.Writer) BuildOption {
|
||||
return func(c *buildConfig) {
|
||||
c.streams = cmdutil.NewIOStreams(in, out, errOut)
|
||||
}
|
||||
}
|
||||
|
||||
// WithKeychain sets the secret storage backend. If not provided, the platform keychain is used.
|
||||
func WithKeychain(kc keychain.KeychainAccess) BuildOption {
|
||||
return func(c *buildConfig) {
|
||||
c.keychain = kc
|
||||
}
|
||||
}
|
||||
|
||||
// HideProfile sets the visibility policy for the root-level --profile flag.
|
||||
// When hide is true the flag stays registered (so existing invocations still
|
||||
// parse) but is omitted from help and shell completion. Typically called as
|
||||
// HideProfile(isSingleAppMode()).
|
||||
func HideProfile(hide bool) BuildOption {
|
||||
return func(c *buildConfig) {
|
||||
c.globals.HideProfile = hide
|
||||
}
|
||||
}
|
||||
|
||||
// Build constructs the full command tree without executing.
|
||||
// Returns only the cobra.Command; Factory is internal.
|
||||
// Use Execute for the standard production entry point.
|
||||
func Build(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) *cobra.Command {
|
||||
_, rootCmd := buildInternal(ctx, inv, opts...)
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
// buildInternal is a pure assembly function: it wires the command tree from
|
||||
// inv and BuildOptions alone. Any state-dependent decision (disk, network,
|
||||
// env) belongs in the caller and must be threaded in via BuildOption.
|
||||
func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command) {
|
||||
// cfg.globals.Profile is left zero here; it's bound to the --profile
|
||||
// flag in RegisterGlobalFlags and filled by cobra's parse step.
|
||||
cfg := &buildConfig{}
|
||||
for _, o := range opts {
|
||||
if o != nil {
|
||||
o(cfg)
|
||||
}
|
||||
}
|
||||
// Default streams when WithIO is not supplied so the root command's
|
||||
// SetIn/Out/Err calls below don't deref nil. NewDefault also normalizes
|
||||
// partial streams internally; keep both in sync so cfg.streams reflects
|
||||
// the same values the Factory ends up using.
|
||||
if cfg.streams == nil {
|
||||
cfg.streams = cmdutil.SystemIO()
|
||||
}
|
||||
|
||||
f := cmdutil.NewDefault(cfg.streams, inv)
|
||||
if cfg.keychain != nil {
|
||||
f.Keychain = cfg.keychain
|
||||
}
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "lark-cli",
|
||||
Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls",
|
||||
Long: rootLong,
|
||||
Version: build.Version,
|
||||
}
|
||||
|
||||
rootCmd.SetContext(ctx)
|
||||
rootCmd.SetIn(cfg.streams.In)
|
||||
rootCmd.SetOut(cfg.streams.Out)
|
||||
rootCmd.SetErr(cfg.streams.ErrOut)
|
||||
|
||||
installTipsHelpFunc(rootCmd)
|
||||
rootCmd.SilenceErrors = true
|
||||
|
||||
RegisterGlobalFlags(rootCmd.PersistentFlags(), &cfg.globals)
|
||||
rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
|
||||
cmd.SilenceUsage = true
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(cmdconfig.NewCmdConfig(f))
|
||||
rootCmd.AddCommand(auth.NewCmdAuth(f))
|
||||
rootCmd.AddCommand(profile.NewCmdProfile(f))
|
||||
rootCmd.AddCommand(doctor.NewCmdDoctor(f))
|
||||
rootCmd.AddCommand(api.NewCmdApiWithContext(ctx, f, nil))
|
||||
rootCmd.AddCommand(schema.NewCmdSchema(f, nil))
|
||||
rootCmd.AddCommand(completion.NewCmdCompletion(f))
|
||||
rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f))
|
||||
service.RegisterServiceCommandsWithContext(ctx, rootCmd, f)
|
||||
shortcuts.RegisterShortcutsWithContext(ctx, rootCmd, f)
|
||||
|
||||
// Prune commands incompatible with strict mode.
|
||||
if mode := f.ResolveStrictMode(ctx); mode.IsActive() {
|
||||
pruneForStrictMode(rootCmd, mode)
|
||||
}
|
||||
|
||||
return f, rootCmd
|
||||
}
|
||||
63
cmd/build_api_test.go
Normal file
63
cmd/build_api_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
)
|
||||
|
||||
// noopKeychain is a zero-side-effect KeychainAccess for exercising
|
||||
// WithKeychain without touching the platform keychain.
|
||||
type noopKeychain struct{}
|
||||
|
||||
func (noopKeychain) Get(service, account string) (string, error) { return "", nil }
|
||||
func (noopKeychain) Set(service, account, value string) error { return nil }
|
||||
func (noopKeychain) Remove(service, account string) error { return nil }
|
||||
|
||||
// TestBuild_ExternalAPI asserts the library surface that external consumers
|
||||
// (e.g. cli-server) depend on: Build composes a root command from an
|
||||
// InvocationContext plus BuildOptions (WithIO, WithKeychain, HideProfile),
|
||||
// and SetDefaultFS swaps the global VFS. This test is the contract guard.
|
||||
func TestBuild_ExternalAPI(t *testing.T) {
|
||||
// Exercise SetDefaultFS both directions. Passing nil restores the OS FS.
|
||||
SetDefaultFS(vfs.OsFs{})
|
||||
SetDefaultFS(nil)
|
||||
|
||||
var in, out, errOut bytes.Buffer
|
||||
rootCmd := Build(
|
||||
context.Background(),
|
||||
cmdutil.InvocationContext{},
|
||||
WithIO(&in, &out, &errOut),
|
||||
WithKeychain(noopKeychain{}),
|
||||
HideProfile(true),
|
||||
)
|
||||
|
||||
if rootCmd == nil {
|
||||
t.Fatal("Build returned nil root command")
|
||||
}
|
||||
if rootCmd.Use != "lark-cli" {
|
||||
t.Errorf("rootCmd.Use = %q, want %q", rootCmd.Use, "lark-cli")
|
||||
}
|
||||
if len(rootCmd.Commands()) == 0 {
|
||||
t.Error("Build produced a root command with no subcommands")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuild_NoOptions guards against regression of the nil-streams panic:
|
||||
// calling Build without WithIO must fall back to SystemIO rather than
|
||||
// deref nil at rootCmd.SetIn/Out/Err.
|
||||
func TestBuild_NoOptions(t *testing.T) {
|
||||
rootCmd := Build(context.Background(), cmdutil.InvocationContext{})
|
||||
if rootCmd == nil {
|
||||
t.Fatal("Build returned nil root command")
|
||||
}
|
||||
if rootCmd.Use != "lark-cli" {
|
||||
t.Errorf("rootCmd.Use = %q, want %q", rootCmd.Use, "lark-cli")
|
||||
}
|
||||
}
|
||||
@@ -3,15 +3,38 @@
|
||||
|
||||
package cmd
|
||||
|
||||
import "github.com/spf13/pflag"
|
||||
import (
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// GlobalOptions are the root-level flags shared by bootstrap parsing and the
|
||||
// actual Cobra command tree.
|
||||
// actual Cobra command tree. Profile is the parsed --profile value; HideProfile
|
||||
// is a build-time policy — when true, --profile stays parseable but is marked
|
||||
// hidden from help and shell completion.
|
||||
type GlobalOptions struct {
|
||||
Profile string
|
||||
Profile string
|
||||
HideProfile bool
|
||||
}
|
||||
|
||||
// RegisterGlobalFlags registers the root-level persistent flags.
|
||||
// RegisterGlobalFlags registers the root-level persistent flags on fs and
|
||||
// applies any visibility policy encoded in opts. Pure function: no disk,
|
||||
// network, or environment reads — the caller decides HideProfile.
|
||||
func RegisterGlobalFlags(fs *pflag.FlagSet, opts *GlobalOptions) {
|
||||
fs.StringVar(&opts.Profile, "profile", "", "use a specific profile")
|
||||
if opts.HideProfile {
|
||||
_ = fs.MarkHidden("profile")
|
||||
}
|
||||
}
|
||||
|
||||
// isSingleAppMode reports whether the on-disk config has at most one app.
|
||||
// Missing configs are treated as single-app since --profile is meaningless
|
||||
// until at least two profiles exist. Intended for the Execute entry point —
|
||||
// buildInternal must not call this directly to stay state-free.
|
||||
func isSingleAppMode() bool {
|
||||
raw, err := core.LoadMultiAppConfig()
|
||||
if err != nil || raw == nil {
|
||||
return true
|
||||
}
|
||||
return len(raw.Apps) <= 1
|
||||
}
|
||||
|
||||
110
cmd/global_flags_test.go
Normal file
110
cmd/global_flags_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
func testStreams() BuildOption { return WithIO(os.Stdin, os.Stdout, os.Stderr) }
|
||||
|
||||
func TestRegisterGlobalFlags_PolicyVisible(t *testing.T) {
|
||||
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
opts := &GlobalOptions{}
|
||||
RegisterGlobalFlags(fs, opts)
|
||||
|
||||
flag := fs.Lookup("profile")
|
||||
if flag == nil {
|
||||
t.Fatal("profile flag should be registered")
|
||||
}
|
||||
if flag.Hidden {
|
||||
t.Fatal("profile flag should be visible when HideProfile is false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterGlobalFlags_PolicyHidden(t *testing.T) {
|
||||
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
opts := &GlobalOptions{HideProfile: true}
|
||||
RegisterGlobalFlags(fs, opts)
|
||||
|
||||
flag := fs.Lookup("profile")
|
||||
if flag == nil {
|
||||
t.Fatal("profile flag should be registered")
|
||||
}
|
||||
if !flag.Hidden {
|
||||
t.Fatal("profile flag should be hidden when HideProfile is true")
|
||||
}
|
||||
if err := fs.Parse([]string{"--profile", "x"}); err != nil {
|
||||
t.Fatalf("Parse() error = %v; hidden flag should still parse", err)
|
||||
}
|
||||
if opts.Profile != "x" {
|
||||
t.Fatalf("opts.Profile = %q, want %q", opts.Profile, "x")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSingleAppMode_NoConfig(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
if !isSingleAppMode() {
|
||||
t.Fatal("isSingleAppMode() = false, want true when no config exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSingleAppMode_SingleApp(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
saveAppsForTest(t, []core.AppConfig{
|
||||
{Name: "default", AppId: "cli_a", AppSecret: core.PlainSecret("x"), Brand: core.BrandFeishu},
|
||||
})
|
||||
if !isSingleAppMode() {
|
||||
t.Fatal("isSingleAppMode() = false, want true for single-app config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSingleAppMode_MultiApp(t *testing.T) {
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
saveAppsForTest(t, []core.AppConfig{
|
||||
{Name: "a", AppId: "cli_a", AppSecret: core.PlainSecret("x"), Brand: core.BrandFeishu},
|
||||
{Name: "b", AppId: "cli_b", AppSecret: core.PlainSecret("y"), Brand: core.BrandFeishu},
|
||||
})
|
||||
if isSingleAppMode() {
|
||||
t.Fatal("isSingleAppMode() = true, want false for multi-app config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInternal_HideProfileOption(t *testing.T) {
|
||||
_, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams(), HideProfile(true))
|
||||
|
||||
flag := root.PersistentFlags().Lookup("profile")
|
||||
if flag == nil {
|
||||
t.Fatal("profile flag should be registered")
|
||||
}
|
||||
if !flag.Hidden {
|
||||
t.Fatal("profile flag should be hidden when HideProfile(true) is applied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInternal_DefaultShowsProfileFlag(t *testing.T) {
|
||||
_, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams())
|
||||
|
||||
flag := root.PersistentFlags().Lookup("profile")
|
||||
if flag == nil {
|
||||
t.Fatal("profile flag should be registered by default")
|
||||
}
|
||||
if flag.Hidden {
|
||||
t.Fatal("profile flag should be visible by default")
|
||||
}
|
||||
}
|
||||
|
||||
func saveAppsForTest(t *testing.T, apps []core.AppConfig) {
|
||||
t.Helper()
|
||||
multi := &core.MultiAppConfig{CurrentApp: apps[0].Name, Apps: apps}
|
||||
if err := core.SaveMultiAppConfig(multi); err != nil {
|
||||
t.Fatalf("SaveMultiAppConfig() error = %v", err)
|
||||
}
|
||||
}
|
||||
18
cmd/init.go
Normal file
18
cmd/init.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmd
|
||||
|
||||
import "github.com/larksuite/cli/internal/vfs"
|
||||
|
||||
// SetDefaultFS replaces the global filesystem implementation used by internal
|
||||
// packages. The provided fs must implement the vfs.FS interface. If fs is nil,
|
||||
// the default OS filesystem is restored.
|
||||
//
|
||||
// Call this before Build or Execute to take effect.
|
||||
func SetDefaultFS(fs vfs.FS) {
|
||||
if fs == nil {
|
||||
fs = vfs.OsFs{}
|
||||
}
|
||||
vfs.DefaultFS = fs
|
||||
}
|
||||
64
cmd/root.go
64
cmd/root.go
@@ -14,15 +14,6 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/larksuite/cli/cmd/api"
|
||||
"github.com/larksuite/cli/cmd/auth"
|
||||
"github.com/larksuite/cli/cmd/completion"
|
||||
cmdconfig "github.com/larksuite/cli/cmd/config"
|
||||
"github.com/larksuite/cli/cmd/doctor"
|
||||
"github.com/larksuite/cli/cmd/profile"
|
||||
"github.com/larksuite/cli/cmd/schema"
|
||||
"github.com/larksuite/cli/cmd/service"
|
||||
cmdupdate "github.com/larksuite/cli/cmd/update"
|
||||
internalauth "github.com/larksuite/cli/internal/auth"
|
||||
"github.com/larksuite/cli/internal/build"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
@@ -30,7 +21,6 @@ import (
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/registry"
|
||||
"github.com/larksuite/cli/internal/update"
|
||||
"github.com/larksuite/cli/shortcuts"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -95,38 +85,13 @@ func Execute() int {
|
||||
fmt.Fprintln(os.Stderr, "Error:", err)
|
||||
return 1
|
||||
}
|
||||
f := cmdutil.NewDefault(inv)
|
||||
configureFlagCompletions(os.Args)
|
||||
|
||||
globals := &GlobalOptions{Profile: inv.Profile}
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "lark-cli",
|
||||
Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls",
|
||||
Long: rootLong,
|
||||
Version: build.Version,
|
||||
}
|
||||
installTipsHelpFunc(rootCmd)
|
||||
rootCmd.SilenceErrors = true
|
||||
|
||||
RegisterGlobalFlags(rootCmd.PersistentFlags(), globals)
|
||||
rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
|
||||
cmd.SilenceUsage = true
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(cmdconfig.NewCmdConfig(f))
|
||||
rootCmd.AddCommand(auth.NewCmdAuth(f))
|
||||
rootCmd.AddCommand(profile.NewCmdProfile(f))
|
||||
rootCmd.AddCommand(doctor.NewCmdDoctor(f))
|
||||
rootCmd.AddCommand(api.NewCmdApi(f, nil))
|
||||
rootCmd.AddCommand(schema.NewCmdSchema(f, nil))
|
||||
rootCmd.AddCommand(completion.NewCmdCompletion(f))
|
||||
rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f))
|
||||
service.RegisterServiceCommands(rootCmd, f)
|
||||
shortcuts.RegisterShortcuts(rootCmd, f)
|
||||
|
||||
// Prune commands incompatible with strict mode.
|
||||
if mode := f.ResolveStrictMode(context.Background()); mode.IsActive() {
|
||||
pruneForStrictMode(rootCmd, mode)
|
||||
}
|
||||
f, rootCmd := buildInternal(
|
||||
context.Background(), inv,
|
||||
WithIO(os.Stdin, os.Stdout, os.Stderr),
|
||||
HideProfile(isSingleAppMode()),
|
||||
)
|
||||
|
||||
// --- Update check (non-blocking) ---
|
||||
if !isCompletionCommand(os.Args) {
|
||||
@@ -190,6 +155,12 @@ func isCompletionCommand(args []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// configureFlagCompletions enables cmdutil.RegisterFlagCompletion only when
|
||||
// the invocation will actually serve a __complete request.
|
||||
func configureFlagCompletions(args []string) {
|
||||
cmdutil.SetFlagCompletionsDisabled(!isCompletionCommand(args))
|
||||
}
|
||||
|
||||
// handleRootError dispatches a command error to the appropriate handler
|
||||
// and returns the process exit code.
|
||||
func handleRootError(f *cmdutil.Factory, err error) int {
|
||||
@@ -277,10 +248,19 @@ func writeSecurityPolicyError(w io.Writer, spErr *internalauth.SecurityPolicyErr
|
||||
}
|
||||
|
||||
// installTipsHelpFunc wraps the default help function to append a TIPS section
|
||||
// when a command has tips set via cmdutil.SetTips.
|
||||
// when a command has tips set via cmdutil.SetTips. It also force-shows global
|
||||
// flags that are normally hidden in single-app mode (currently --profile)
|
||||
// when rendering the root command's own help, so users discovering the CLI
|
||||
// still see them at `lark-cli --help`.
|
||||
func installTipsHelpFunc(root *cobra.Command) {
|
||||
defaultHelp := root.HelpFunc()
|
||||
root.SetHelpFunc(func(cmd *cobra.Command, args []string) {
|
||||
if cmd == root {
|
||||
if f := root.PersistentFlags().Lookup("profile"); f != nil && f.Hidden {
|
||||
f.Hidden = false
|
||||
defer func() { f.Hidden = true }()
|
||||
}
|
||||
}
|
||||
defaultHelp(cmd, args)
|
||||
tips := cmdutil.GetTips(cmd)
|
||||
if len(tips) == 0 {
|
||||
|
||||
@@ -135,10 +135,12 @@ func newStrictModeDefaultFactory(t *testing.T, profile string, mode core.StrictM
|
||||
t.Fatalf("SaveMultiAppConfig() error = %v", err)
|
||||
}
|
||||
|
||||
f := cmdutil.NewDefault(cmdutil.InvocationContext{Profile: profile})
|
||||
stdout := &bytes.Buffer{}
|
||||
stderr := &bytes.Buffer{}
|
||||
f.IOStreams = &cmdutil.IOStreams{In: nil, Out: stdout, ErrOut: stderr}
|
||||
f := cmdutil.NewDefault(
|
||||
cmdutil.NewIOStreams(&bytes.Buffer{}, stdout, stderr),
|
||||
cmdutil.InvocationContext{Profile: profile},
|
||||
)
|
||||
return f, stdout, stderr
|
||||
}
|
||||
|
||||
|
||||
@@ -196,3 +196,28 @@ func TestRootLong_AgentSkillsLinkTargetsReadmeSection(t *testing.T) {
|
||||
t.Fatalf("root help should not reference the removed install-ai-agent-skills anchor, got:\n%s", rootLong)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureFlagCompletions(t *testing.T) {
|
||||
t.Cleanup(func() { cmdutil.SetFlagCompletionsDisabled(false) })
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantDisabled bool
|
||||
}{
|
||||
{"plain command", []string{"im", "+send"}, true},
|
||||
{"help flag", []string{"im", "--help"}, true},
|
||||
{"no args", []string{}, true},
|
||||
{"__complete request", []string{"__complete", "im", "+send", ""}, false},
|
||||
{"completion subcommand", []string{"completion", "bash"}, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cmdutil.SetFlagCompletionsDisabled(!tc.wantDisabled)
|
||||
configureFlagCompletions(tc.args)
|
||||
if got := cmdutil.FlagCompletionsDisabled(); got != tc.wantDisabled {
|
||||
t.Fatalf("FlagCompletionsDisabled() = %v, want %v", got, tc.wantDisabled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/registry"
|
||||
"github.com/larksuite/cli/internal/util"
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
// SchemaOptions holds all inputs for the schema command.
|
||||
type SchemaOptions struct {
|
||||
Factory *cmdutil.Factory
|
||||
Ctx context.Context
|
||||
|
||||
// Positional args
|
||||
Path string
|
||||
@@ -41,7 +44,7 @@ func printServices(w io.Writer) {
|
||||
fmt.Fprintf(w, "\n%sUsage: lark-cli schema <service>.<resource>.<method>%s\n", output.Dim, output.Reset)
|
||||
}
|
||||
|
||||
func printResourceList(w io.Writer, spec map[string]interface{}) {
|
||||
func printResourceList(w io.Writer, spec map[string]interface{}, mode core.StrictMode) {
|
||||
name := registry.GetStrFromMap(spec, "name")
|
||||
version := registry.GetStrFromMap(spec, "version")
|
||||
title := registry.GetStrFromMap(spec, "title")
|
||||
@@ -55,9 +58,13 @@ func printResourceList(w io.Writer, spec map[string]interface{}) {
|
||||
|
||||
resources, _ := spec["resources"].(map[string]interface{})
|
||||
for _, resName := range sortedKeys(resources) {
|
||||
fmt.Fprintf(w, " %s%s%s\n", output.Cyan, resName, output.Reset)
|
||||
resMap, _ := resources[resName].(map[string]interface{})
|
||||
methods, _ := resMap["methods"].(map[string]interface{})
|
||||
methods = filterMethodsByStrictMode(methods, mode)
|
||||
if len(methods) == 0 {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(w, " %s%s%s\n", output.Cyan, resName, output.Reset)
|
||||
for _, methodName := range sortedKeys(methods) {
|
||||
m, _ := methods[methodName].(map[string]interface{})
|
||||
httpMethod := registry.GetStrFromMap(m, "httpMethod")
|
||||
@@ -359,6 +366,7 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co
|
||||
if len(args) > 0 {
|
||||
opts.Path = args[0]
|
||||
}
|
||||
opts.Ctx = cmd.Context()
|
||||
if runF != nil {
|
||||
return runF(opts)
|
||||
}
|
||||
@@ -367,9 +375,9 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co
|
||||
}
|
||||
cmdutil.DisableAuthCheck(cmd)
|
||||
|
||||
cmd.ValidArgsFunction = completeSchemaPath
|
||||
cmd.ValidArgsFunction = completeSchemaPath(f)
|
||||
cmd.Flags().StringVar(&opts.Format, "format", "json", "output format: json (default) | pretty")
|
||||
_ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
cmdutil.RegisterFlagCompletion(cmd, "format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return []string{"json", "pretty"}, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
|
||||
@@ -379,78 +387,86 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co
|
||||
// completeSchemaPath provides tab-completion for the schema path argument.
|
||||
// It handles dotted resource names (e.g. app.table.fields) by iterating all
|
||||
// resources and classifying each as a prefix-match or fully-matched.
|
||||
func completeSchemaPath(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
if len(args) > 0 {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
func completeSchemaPath(f *cmdutil.Factory) func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
if len(args) > 0 {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
parts := strings.Split(toComplete, ".")
|
||||
parts := strings.Split(toComplete, ".")
|
||||
|
||||
// Level 1: complete service names
|
||||
if len(parts) <= 1 {
|
||||
var completions []string
|
||||
for _, s := range registry.ListFromMetaProjects() {
|
||||
if strings.HasPrefix(s, toComplete) {
|
||||
completions = append(completions, s+".")
|
||||
// Level 1: complete service names
|
||||
if len(parts) <= 1 {
|
||||
var completions []string
|
||||
for _, s := range registry.ListFromMetaProjects() {
|
||||
if strings.HasPrefix(s, toComplete) {
|
||||
completions = append(completions, s+".")
|
||||
}
|
||||
}
|
||||
return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace
|
||||
}
|
||||
|
||||
serviceName := parts[0]
|
||||
spec := registry.LoadFromMeta(serviceName)
|
||||
if spec == nil {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
mode := f.ResolveStrictMode(cmd.Context())
|
||||
spec = filterSpecByStrictMode(spec, mode)
|
||||
resources, _ := spec["resources"].(map[string]interface{})
|
||||
if resources == nil {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
afterService := strings.Join(parts[1:], ".")
|
||||
completions := completeSchemaPathForSpec(serviceName, resources, afterService)
|
||||
|
||||
allTrailingDot := len(completions) > 0
|
||||
for _, c := range completions {
|
||||
if !strings.HasSuffix(c, ".") {
|
||||
allTrailingDot = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace
|
||||
directive := cobra.ShellCompDirectiveNoFileComp
|
||||
if allTrailingDot {
|
||||
directive |= cobra.ShellCompDirectiveNoSpace
|
||||
}
|
||||
return completions, directive
|
||||
}
|
||||
}
|
||||
|
||||
serviceName := parts[0]
|
||||
spec := registry.LoadFromMeta(serviceName)
|
||||
if spec == nil {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
resources, _ := spec["resources"].(map[string]interface{})
|
||||
if resources == nil {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
||||
// afterService = everything user typed after "serviceName."
|
||||
afterService := strings.Join(parts[1:], ".")
|
||||
|
||||
func completeSchemaPathForSpec(serviceName string, resources map[string]interface{}, afterService string) []string {
|
||||
var completions []string
|
||||
|
||||
for resName, resVal := range resources {
|
||||
if strings.HasPrefix(resName, afterService) {
|
||||
// afterService is a prefix of this resource name → resource candidate
|
||||
completions = append(completions, serviceName+"."+resName+".")
|
||||
} else if strings.HasPrefix(afterService, resName+".") {
|
||||
// This resource is fully matched; remainder is method prefix
|
||||
methodPrefix := afterService[len(resName)+1:]
|
||||
resMap, _ := resVal.(map[string]interface{})
|
||||
if resMap == nil {
|
||||
continue
|
||||
}
|
||||
methods, _ := resMap["methods"].(map[string]interface{})
|
||||
for methodName := range methods {
|
||||
if strings.HasPrefix(methodName, methodPrefix) {
|
||||
completions = append(completions, serviceName+"."+resName+"."+methodName)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(afterService, resName+".") {
|
||||
continue
|
||||
}
|
||||
methodPrefix := afterService[len(resName)+1:]
|
||||
resMap, _ := resVal.(map[string]interface{})
|
||||
if resMap == nil {
|
||||
continue
|
||||
}
|
||||
methods, _ := resMap["methods"].(map[string]interface{})
|
||||
for methodName := range methods {
|
||||
if strings.HasPrefix(methodName, methodPrefix) {
|
||||
completions = append(completions, serviceName+"."+resName+"."+methodName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(completions)
|
||||
|
||||
// If all completions end with ".", user is still navigating resources → NoSpace
|
||||
allTrailingDot := len(completions) > 0
|
||||
for _, c := range completions {
|
||||
if !strings.HasSuffix(c, ".") {
|
||||
allTrailingDot = false
|
||||
break
|
||||
}
|
||||
}
|
||||
directive := cobra.ShellCompDirectiveNoFileComp
|
||||
if allTrailingDot {
|
||||
directive |= cobra.ShellCompDirectiveNoSpace
|
||||
}
|
||||
return completions, directive
|
||||
return completions
|
||||
}
|
||||
|
||||
func schemaRun(opts *SchemaOptions) error {
|
||||
out := opts.Factory.IOStreams.Out
|
||||
mode := opts.Factory.ResolveStrictMode(opts.Ctx)
|
||||
|
||||
if opts.Path == "" {
|
||||
printServices(out)
|
||||
@@ -469,9 +485,9 @@ func schemaRun(opts *SchemaOptions) error {
|
||||
|
||||
if len(parts) == 1 {
|
||||
if opts.Format == "pretty" {
|
||||
printResourceList(out, spec)
|
||||
printResourceList(out, spec, mode)
|
||||
} else {
|
||||
output.PrintJson(out, spec)
|
||||
output.PrintJson(out, filterSpecByStrictMode(spec, mode))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -492,6 +508,7 @@ func schemaRun(opts *SchemaOptions) error {
|
||||
if opts.Format == "pretty" {
|
||||
fmt.Fprintf(out, "%s%s.%s%s\n\n", output.Bold, serviceName, resName, output.Reset)
|
||||
methods, _ := resource["methods"].(map[string]interface{})
|
||||
methods = filterMethodsByStrictMode(methods, mode)
|
||||
for _, mName := range sortedKeys(methods) {
|
||||
m, _ := methods[mName].(map[string]interface{})
|
||||
httpMethod := registry.GetStrFromMap(m, "httpMethod")
|
||||
@@ -500,13 +517,26 @@ func schemaRun(opts *SchemaOptions) error {
|
||||
}
|
||||
fmt.Fprintf(out, "\n%sUsage: lark-cli schema %s.%s.<method>%s\n", output.Dim, serviceName, resName, output.Reset)
|
||||
} else {
|
||||
output.PrintJson(out, resource)
|
||||
// For JSON output, filter methods in a copy to avoid mutating the registry.
|
||||
if mode.IsActive() {
|
||||
filtered := make(map[string]interface{})
|
||||
for k, v := range resource {
|
||||
filtered[k] = v
|
||||
}
|
||||
if methods, ok := resource["methods"].(map[string]interface{}); ok {
|
||||
filtered["methods"] = filterMethodsByStrictMode(methods, mode)
|
||||
}
|
||||
output.PrintJson(out, filtered)
|
||||
} else {
|
||||
output.PrintJson(out, resource)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
methodName := remaining[0]
|
||||
methods, _ := resource["methods"].(map[string]interface{})
|
||||
methods = filterMethodsByStrictMode(methods, mode)
|
||||
method, ok := methods[methodName].(map[string]interface{})
|
||||
if !ok {
|
||||
var mNames []string
|
||||
@@ -525,3 +555,67 @@ func schemaRun(opts *SchemaOptions) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterSpecByStrictMode returns a shallow copy of spec with each resource's methods
|
||||
// filtered by strict mode. Returns the original spec when strict mode is off.
|
||||
func filterSpecByStrictMode(spec map[string]interface{}, mode core.StrictMode) map[string]interface{} {
|
||||
if !mode.IsActive() {
|
||||
return spec
|
||||
}
|
||||
result := make(map[string]interface{}, len(spec))
|
||||
for k, v := range spec {
|
||||
result[k] = v
|
||||
}
|
||||
resources, _ := spec["resources"].(map[string]interface{})
|
||||
if resources == nil {
|
||||
return result
|
||||
}
|
||||
filteredRes := make(map[string]interface{}, len(resources))
|
||||
for resName, resVal := range resources {
|
||||
resMap, ok := resVal.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
methods, _ := resMap["methods"].(map[string]interface{})
|
||||
filtered := filterMethodsByStrictMode(methods, mode)
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
resCopy := make(map[string]interface{}, len(resMap))
|
||||
for k, v := range resMap {
|
||||
resCopy[k] = v
|
||||
}
|
||||
resCopy["methods"] = filtered
|
||||
filteredRes[resName] = resCopy
|
||||
}
|
||||
result["resources"] = filteredRes
|
||||
return result
|
||||
}
|
||||
|
||||
// filterMethodsByStrictMode removes methods incompatible with the active strict mode.
|
||||
// Returns the original map unmodified when strict mode is off.
|
||||
func filterMethodsByStrictMode(methods map[string]interface{}, mode core.StrictMode) map[string]interface{} {
|
||||
if !mode.IsActive() || methods == nil {
|
||||
return methods
|
||||
}
|
||||
token := registry.IdentityToAccessToken(string(mode.ForcedIdentity()))
|
||||
filtered := make(map[string]interface{}, len(methods))
|
||||
for name, val := range methods {
|
||||
m, ok := val.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tokens, _ := m["accessTokens"].([]interface{})
|
||||
if tokens == nil {
|
||||
filtered[name] = val
|
||||
continue
|
||||
}
|
||||
for _, t := range tokens {
|
||||
if ts, ok := t.(string); ok && ts == token {
|
||||
filtered[name] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
@@ -182,3 +182,49 @@ func TestHasFileFields(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteSchemaPathForSpec(t *testing.T) {
|
||||
resources := map[string]interface{}{
|
||||
"records": map[string]interface{}{
|
||||
"methods": map[string]interface{}{
|
||||
"create": map[string]interface{}{},
|
||||
"list": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
"record_permissions": map[string]interface{}{
|
||||
"methods": map[string]interface{}{
|
||||
"get": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := completeSchemaPathForSpec("base", resources, "records.cr")
|
||||
if len(got) != 1 || got[0] != "base.records.create" {
|
||||
t.Fatalf("completions = %v, want [base.records.create]", got)
|
||||
}
|
||||
|
||||
got = completeSchemaPathForSpec("base", resources, "record")
|
||||
if len(got) != 2 || got[0] != "base.record_permissions." || got[1] != "base.records." {
|
||||
t.Fatalf("resource completions = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterSpecByStrictMode_RemovesIncompatibleMethodsFromCompletionSource(t *testing.T) {
|
||||
spec := map[string]interface{}{
|
||||
"resources": map[string]interface{}{
|
||||
"records": map[string]interface{}{
|
||||
"methods": map[string]interface{}{
|
||||
"list": map[string]interface{}{"accessTokens": []interface{}{"tenant"}},
|
||||
"create": map[string]interface{}{"accessTokens": []interface{}{"user"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filtered := filterSpecByStrictMode(spec, core.StrictModeBot)
|
||||
resources, _ := filtered["resources"].(map[string]interface{})
|
||||
got := completeSchemaPathForSpec("base", resources, "records.")
|
||||
if len(got) != 1 || got[0] != "base.records.list" {
|
||||
t.Fatalf("filtered completions = %v, want [base.records.list]", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ import (
|
||||
|
||||
// RegisterServiceCommands registers all service commands from from_meta specs.
|
||||
func RegisterServiceCommands(parent *cobra.Command, f *cmdutil.Factory) {
|
||||
RegisterServiceCommandsWithContext(context.Background(), parent, f)
|
||||
}
|
||||
|
||||
func RegisterServiceCommandsWithContext(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) {
|
||||
for _, project := range registry.ListFromMetaProjects() {
|
||||
spec := registry.LoadFromMeta(project)
|
||||
if spec == nil {
|
||||
@@ -38,11 +42,15 @@ func RegisterServiceCommands(parent *cobra.Command, f *cmdutil.Factory) {
|
||||
if resources == nil {
|
||||
continue
|
||||
}
|
||||
registerService(parent, spec, resources, f)
|
||||
registerServiceWithContext(ctx, parent, spec, resources, f)
|
||||
}
|
||||
}
|
||||
|
||||
func registerService(parent *cobra.Command, spec map[string]interface{}, resources map[string]interface{}, f *cmdutil.Factory) {
|
||||
registerServiceWithContext(context.Background(), parent, spec, resources, f)
|
||||
}
|
||||
|
||||
func registerServiceWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, resources map[string]interface{}, f *cmdutil.Factory) {
|
||||
specName := registry.GetStrFromMap(spec, "name")
|
||||
specDesc := registry.GetServiceDescription(specName, "en")
|
||||
if specDesc == "" {
|
||||
@@ -70,11 +78,11 @@ func registerService(parent *cobra.Command, spec map[string]interface{}, resourc
|
||||
if resMap == nil {
|
||||
continue
|
||||
}
|
||||
registerResource(svc, spec, resName, resMap, f)
|
||||
registerResourceWithContext(ctx, svc, spec, resName, resMap, f)
|
||||
}
|
||||
}
|
||||
|
||||
func registerResource(parent *cobra.Command, spec map[string]interface{}, name string, resource map[string]interface{}, f *cmdutil.Factory) {
|
||||
func registerResourceWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, name string, resource map[string]interface{}, f *cmdutil.Factory) {
|
||||
res := &cobra.Command{
|
||||
Use: name,
|
||||
Short: name + " operations",
|
||||
@@ -87,7 +95,7 @@ func registerResource(parent *cobra.Command, spec map[string]interface{}, name s
|
||||
if methodMap == nil {
|
||||
continue
|
||||
}
|
||||
registerMethod(res, spec, methodMap, methodName, name, f)
|
||||
registerMethodWithContext(ctx, res, spec, methodMap, methodName, name, f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,12 +128,16 @@ func detectFileFields(method map[string]interface{}) []string {
|
||||
return cmdutil.DetectFileFields(method)
|
||||
}
|
||||
|
||||
func registerMethod(parent *cobra.Command, spec map[string]interface{}, method map[string]interface{}, name string, resName string, f *cmdutil.Factory) {
|
||||
parent.AddCommand(NewCmdServiceMethod(f, spec, method, name, resName, nil))
|
||||
func registerMethodWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, method map[string]interface{}, name string, resName string, f *cmdutil.Factory) {
|
||||
parent.AddCommand(NewCmdServiceMethodWithContext(ctx, f, spec, method, name, resName, nil))
|
||||
}
|
||||
|
||||
// NewCmdServiceMethod creates a command for a dynamically registered service method.
|
||||
func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{}, name, resName string, runF func(*ServiceMethodOptions) error) *cobra.Command {
|
||||
return NewCmdServiceMethodWithContext(context.Background(), f, spec, method, name, resName, runF)
|
||||
}
|
||||
|
||||
func NewCmdServiceMethodWithContext(ctx context.Context, f *cmdutil.Factory, spec, method map[string]interface{}, name, resName string, runF func(*ServiceMethodOptions) error) *cobra.Command {
|
||||
desc := registry.GetStrFromMap(method, "description")
|
||||
httpMethod := registry.GetStrFromMap(method, "httpMethod")
|
||||
specName := registry.GetStrFromMap(spec, "name")
|
||||
@@ -159,7 +171,7 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{}
|
||||
case "POST", "PUT", "PATCH", "DELETE":
|
||||
cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)")
|
||||
}
|
||||
cmd.Flags().StringVar(&asStr, "as", "auto", "identity type: user | bot | auto (default)")
|
||||
cmdutil.AddAPIIdentityFlag(ctx, cmd, f, &asStr)
|
||||
cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses")
|
||||
cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages")
|
||||
cmd.Flags().IntVar(&opts.PageLimit, "page-limit", 10, "max pages to fetch with --page-all (0 = unlimited)")
|
||||
@@ -177,11 +189,7 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{}
|
||||
cmd.Flags().StringVar(&opts.File, "file", "", "file to upload ([field=]path, supports - for stdin)")
|
||||
}
|
||||
}
|
||||
|
||||
_ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return []string{"user", "bot"}, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
_ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
cmdutil.RegisterFlagCompletion(cmd, "format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
|
||||
|
||||
@@ -121,6 +121,24 @@ func TestRegisterService_MergesExistingCommand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCmdServiceMethod_StrictModeHidesAsFlag(t *testing.T) {
|
||||
f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{
|
||||
AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2,
|
||||
})
|
||||
|
||||
cmd := NewCmdServiceMethod(f, driveSpec(), driveMethod("GET", nil), "copy", "files", nil)
|
||||
flag := cmd.Flags().Lookup("as")
|
||||
if flag == nil {
|
||||
t.Fatal("expected --as flag to be registered")
|
||||
}
|
||||
if !flag.Hidden {
|
||||
t.Fatal("expected --as flag to be hidden in strict mode")
|
||||
}
|
||||
if got := flag.DefValue; got != "bot" {
|
||||
t.Fatalf("default value = %q, want %q", got, "bot")
|
||||
}
|
||||
}
|
||||
|
||||
// ── NewCmdServiceMethod flags ──
|
||||
|
||||
func TestNewCmdServiceMethod_GETHasNoDataFlag(t *testing.T) {
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
package credential
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
@@ -11,12 +14,28 @@ var (
|
||||
)
|
||||
|
||||
// Register registers a credential Provider.
|
||||
// Providers are consulted in registration order.
|
||||
// Providers are consulted in priority order (lowest value first).
|
||||
// Providers that implement Priority() int are sorted accordingly;
|
||||
// those that do not default to priority 10.
|
||||
// Typically called from init() via blank import.
|
||||
func Register(p Provider) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
providers = append(providers, p)
|
||||
sort.SliceStable(providers, func(i, j int) bool {
|
||||
return providerPriority(providers[i]) < providerPriority(providers[j])
|
||||
})
|
||||
}
|
||||
|
||||
// providerPriority returns the priority of a provider.
|
||||
// If the provider implements interface{ Priority() int }, that value is used;
|
||||
// otherwise 10 is returned as the default priority.
|
||||
// Lower values are consulted first.
|
||||
func providerPriority(p Provider) int {
|
||||
if pp, ok := p.(interface{ Priority() int }); ok {
|
||||
return pp.Priority()
|
||||
}
|
||||
return 10
|
||||
}
|
||||
|
||||
// Providers returns all registered providers (snapshot).
|
||||
|
||||
@@ -37,6 +37,32 @@ func TestRegisterAndProviders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type priorityProvider struct {
|
||||
stubProvider
|
||||
priority int
|
||||
}
|
||||
|
||||
func (p *priorityProvider) Priority() int { return p.priority }
|
||||
|
||||
func TestRegister_PriorityOrder(t *testing.T) {
|
||||
mu.Lock()
|
||||
old := providers
|
||||
providers = nil
|
||||
mu.Unlock()
|
||||
defer func() { mu.Lock(); providers = old; mu.Unlock() }()
|
||||
|
||||
Register(&stubProvider{name: "env"}) // priority 10 (default)
|
||||
Register(&priorityProvider{stubProvider: stubProvider{name: "sidecar"}, priority: 0}) // priority 0 (first)
|
||||
|
||||
got := Providers()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(got))
|
||||
}
|
||||
if got[0].Name() != "sidecar" || got[1].Name() != "env" {
|
||||
t.Errorf("expected sidecar before env, got %s, %s", got[0].Name(), got[1].Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviders_ReturnsSnapshot(t *testing.T) {
|
||||
mu.Lock()
|
||||
old := providers
|
||||
|
||||
131
extension/credential/sidecar/provider.go
Normal file
131
extension/credential/sidecar/provider.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build authsidecar
|
||||
|
||||
// Package sidecar provides a noop credential provider for the auth sidecar
|
||||
// proxy mode. When LARKSUITE_CLI_AUTH_PROXY is set, this provider supplies
|
||||
// placeholder credentials so the CLI's auth pipeline can proceed normally.
|
||||
// Real tokens are never present in the sandbox; the sidecar transport
|
||||
// interceptor routes requests to the trusted sidecar process instead.
|
||||
package sidecar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/larksuite/cli/extension/credential"
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
"github.com/larksuite/cli/sidecar"
|
||||
)
|
||||
|
||||
// Provider is the noop credential provider for sidecar mode.
|
||||
type Provider struct{}
|
||||
|
||||
func (p *Provider) Name() string { return "sidecar" }
|
||||
func (p *Provider) Priority() int { return 0 }
|
||||
|
||||
// ResolveAccount returns a minimal Account when sidecar mode is active.
|
||||
// The account contains AppID and Brand from environment variables, a
|
||||
// placeholder secret, and SupportedIdentities derived from STRICT_MODE.
|
||||
// Returns nil, nil when sidecar mode is not active (AUTH_PROXY not set).
|
||||
func (p *Provider) ResolveAccount(ctx context.Context) (*credential.Account, error) {
|
||||
proxyAddr := os.Getenv(envvars.CliAuthProxy)
|
||||
if proxyAddr == "" {
|
||||
return nil, nil // not in sidecar mode, skip
|
||||
}
|
||||
|
||||
if err := sidecar.ValidateProxyAddr(proxyAddr); err != nil {
|
||||
return nil, &credential.BlockError{
|
||||
Provider: "sidecar",
|
||||
Reason: fmt.Sprintf("invalid %s %q: %v", envvars.CliAuthProxy, proxyAddr, err),
|
||||
}
|
||||
}
|
||||
|
||||
appID := os.Getenv(envvars.CliAppID)
|
||||
if appID == "" {
|
||||
return nil, &credential.BlockError{
|
||||
Provider: "sidecar",
|
||||
Reason: envvars.CliAuthProxy + " is set but " + envvars.CliAppID + " is missing",
|
||||
}
|
||||
}
|
||||
|
||||
if os.Getenv(envvars.CliProxyKey) == "" {
|
||||
return nil, &credential.BlockError{
|
||||
Provider: "sidecar",
|
||||
Reason: envvars.CliAuthProxy + " is set but " + envvars.CliProxyKey + " is missing",
|
||||
}
|
||||
}
|
||||
|
||||
brand := credential.Brand(os.Getenv(envvars.CliBrand))
|
||||
if brand == "" {
|
||||
brand = credential.BrandFeishu
|
||||
}
|
||||
|
||||
acct := &credential.Account{
|
||||
AppID: appID,
|
||||
AppSecret: credential.NoAppSecret,
|
||||
Brand: brand,
|
||||
}
|
||||
|
||||
// Parse DefaultAs
|
||||
switch id := credential.Identity(os.Getenv(envvars.CliDefaultAs)); id {
|
||||
case "", credential.IdentityAuto:
|
||||
acct.DefaultAs = id
|
||||
case credential.IdentityUser, credential.IdentityBot:
|
||||
acct.DefaultAs = id
|
||||
default:
|
||||
return nil, &credential.BlockError{
|
||||
Provider: "sidecar",
|
||||
Reason: fmt.Sprintf("invalid %s %q (want user, bot, or auto)", envvars.CliDefaultAs, id),
|
||||
}
|
||||
}
|
||||
|
||||
// Parse SupportedIdentities from STRICT_MODE, default to SupportsAll.
|
||||
switch strictMode := os.Getenv(envvars.CliStrictMode); strictMode {
|
||||
case "bot":
|
||||
acct.SupportedIdentities = credential.SupportsBot
|
||||
case "user":
|
||||
acct.SupportedIdentities = credential.SupportsUser
|
||||
case "off", "":
|
||||
acct.SupportedIdentities = credential.SupportsAll
|
||||
default:
|
||||
return nil, &credential.BlockError{
|
||||
Provider: "sidecar",
|
||||
Reason: fmt.Sprintf("invalid %s %q (want bot, user, or off)", envvars.CliStrictMode, strictMode),
|
||||
}
|
||||
}
|
||||
|
||||
return acct, nil
|
||||
}
|
||||
|
||||
// ResolveToken returns a sentinel token whose value encodes the token type.
|
||||
// The transport interceptor reads this sentinel to determine the identity
|
||||
// (user vs bot), strips it, and the sidecar injects the real token.
|
||||
// Returns nil, nil when sidecar mode is not active.
|
||||
func (p *Provider) ResolveToken(ctx context.Context, req credential.TokenSpec) (*credential.Token, error) {
|
||||
if os.Getenv(envvars.CliAuthProxy) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var sentinel string
|
||||
switch req.Type {
|
||||
case credential.TokenTypeUAT:
|
||||
sentinel = sidecar.SentinelUAT
|
||||
case credential.TokenTypeTAT:
|
||||
sentinel = sidecar.SentinelTAT
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &credential.Token{
|
||||
Value: sentinel,
|
||||
Scopes: "", // empty → scope pre-check is skipped
|
||||
Source: "sidecar",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
credential.Register(&Provider{})
|
||||
}
|
||||
188
extension/credential/sidecar/provider_test.go
Normal file
188
extension/credential/sidecar/provider_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build authsidecar
|
||||
|
||||
package sidecar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/extension/credential"
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
"github.com/larksuite/cli/sidecar"
|
||||
)
|
||||
|
||||
func setEnv(t *testing.T, key, value string) {
|
||||
t.Helper()
|
||||
old, hadOld := os.LookupEnv(key)
|
||||
os.Setenv(key, value)
|
||||
t.Cleanup(func() {
|
||||
if hadOld {
|
||||
os.Setenv(key, old)
|
||||
} else {
|
||||
os.Unsetenv(key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func unsetEnv(t *testing.T, key string) {
|
||||
t.Helper()
|
||||
old, hadOld := os.LookupEnv(key)
|
||||
os.Unsetenv(key)
|
||||
t.Cleanup(func() {
|
||||
if hadOld {
|
||||
os.Setenv(key, old)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveAccount_NotActive(t *testing.T) {
|
||||
unsetEnv(t, envvars.CliAuthProxy)
|
||||
|
||||
p := &Provider{}
|
||||
acct, err := p.ResolveAccount(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if acct != nil {
|
||||
t.Fatal("expected nil account when AUTH_PROXY not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAccount_Active(t *testing.T) {
|
||||
setEnv(t, envvars.CliAuthProxy, "http://127.0.0.1:16384")
|
||||
setEnv(t, envvars.CliProxyKey, "test-key")
|
||||
setEnv(t, envvars.CliAppID, "cli_test123")
|
||||
setEnv(t, envvars.CliBrand, "lark")
|
||||
unsetEnv(t, envvars.CliDefaultAs)
|
||||
unsetEnv(t, envvars.CliStrictMode)
|
||||
|
||||
p := &Provider{}
|
||||
acct, err := p.ResolveAccount(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if acct == nil {
|
||||
t.Fatal("expected non-nil account")
|
||||
}
|
||||
if acct.AppID != "cli_test123" {
|
||||
t.Errorf("AppID = %q, want %q", acct.AppID, "cli_test123")
|
||||
}
|
||||
if acct.Brand != credential.BrandLark {
|
||||
t.Errorf("Brand = %q, want %q", acct.Brand, credential.BrandLark)
|
||||
}
|
||||
if acct.AppSecret != credential.NoAppSecret {
|
||||
t.Errorf("AppSecret should be NoAppSecret, got %q", acct.AppSecret)
|
||||
}
|
||||
if acct.SupportedIdentities != credential.SupportsAll {
|
||||
t.Errorf("SupportedIdentities = %d, want %d (SupportsAll)", acct.SupportedIdentities, credential.SupportsAll)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAccount_MissingProxyKey(t *testing.T) {
|
||||
setEnv(t, envvars.CliAuthProxy, "http://127.0.0.1:16384")
|
||||
unsetEnv(t, envvars.CliProxyKey)
|
||||
setEnv(t, envvars.CliAppID, "cli_test")
|
||||
|
||||
p := &Provider{}
|
||||
_, err := p.ResolveAccount(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error when PROXY_KEY is missing")
|
||||
}
|
||||
if _, ok := err.(*credential.BlockError); !ok {
|
||||
t.Fatalf("expected BlockError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAccount_MissingAppID(t *testing.T) {
|
||||
setEnv(t, envvars.CliAuthProxy, "http://127.0.0.1:16384")
|
||||
setEnv(t, envvars.CliProxyKey, "test-key")
|
||||
unsetEnv(t, envvars.CliAppID)
|
||||
|
||||
p := &Provider{}
|
||||
_, err := p.ResolveAccount(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error when APP_ID is missing")
|
||||
}
|
||||
if _, ok := err.(*credential.BlockError); !ok {
|
||||
t.Fatalf("expected BlockError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAccount_StrictMode(t *testing.T) {
|
||||
setEnv(t, envvars.CliAuthProxy, "http://127.0.0.1:16384")
|
||||
setEnv(t, envvars.CliProxyKey, "test-key")
|
||||
setEnv(t, envvars.CliAppID, "cli_test")
|
||||
|
||||
tests := []struct {
|
||||
mode string
|
||||
want credential.IdentitySupport
|
||||
}{
|
||||
{"bot", credential.SupportsBot},
|
||||
{"user", credential.SupportsUser},
|
||||
{"off", credential.SupportsAll},
|
||||
{"", credential.SupportsAll},
|
||||
}
|
||||
|
||||
p := &Provider{}
|
||||
for _, tt := range tests {
|
||||
t.Run("strict_"+tt.mode, func(t *testing.T) {
|
||||
if tt.mode == "" {
|
||||
unsetEnv(t, envvars.CliStrictMode)
|
||||
} else {
|
||||
setEnv(t, envvars.CliStrictMode, tt.mode)
|
||||
}
|
||||
acct, err := p.ResolveAccount(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if acct.SupportedIdentities != tt.want {
|
||||
t.Errorf("SupportedIdentities = %d, want %d", acct.SupportedIdentities, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToken_NotActive(t *testing.T) {
|
||||
unsetEnv(t, envvars.CliAuthProxy)
|
||||
|
||||
p := &Provider{}
|
||||
tok, err := p.ResolveToken(context.Background(), credential.TokenSpec{Type: credential.TokenTypeUAT})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if tok != nil {
|
||||
t.Fatal("expected nil token when AUTH_PROXY not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToken_Sentinels(t *testing.T) {
|
||||
setEnv(t, envvars.CliAuthProxy, "http://127.0.0.1:16384")
|
||||
setEnv(t, envvars.CliProxyKey, "test-key")
|
||||
|
||||
p := &Provider{}
|
||||
|
||||
// UAT
|
||||
tok, err := p.ResolveToken(context.Background(), credential.TokenSpec{Type: credential.TokenTypeUAT})
|
||||
if err != nil {
|
||||
t.Fatalf("UAT: unexpected error: %v", err)
|
||||
}
|
||||
if tok.Value != sidecar.SentinelUAT {
|
||||
t.Errorf("UAT value = %q, want %q", tok.Value, sidecar.SentinelUAT)
|
||||
}
|
||||
if tok.Scopes != "" {
|
||||
t.Errorf("UAT scopes should be empty, got %q", tok.Scopes)
|
||||
}
|
||||
|
||||
// TAT
|
||||
tok, err = p.ResolveToken(context.Background(), credential.TokenSpec{Type: credential.TokenTypeTAT})
|
||||
if err != nil {
|
||||
t.Fatalf("TAT: unexpected error: %v", err)
|
||||
}
|
||||
if tok.Value != sidecar.SentinelTAT {
|
||||
t.Errorf("TAT value = %q, want %q", tok.Value, sidecar.SentinelTAT)
|
||||
}
|
||||
}
|
||||
51
extension/transport/errors.go
Normal file
51
extension/transport/errors.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ErrAborted is a sentinel matched by errors.Is on any extension-triggered
|
||||
// round-trip abort. Callers that only need to know whether an error was
|
||||
// caused by an extension interception should use:
|
||||
//
|
||||
// if errors.Is(err, transport.ErrAborted) { ... }
|
||||
var ErrAborted = errors.New("round trip aborted by extension")
|
||||
|
||||
// AbortError is returned by the built-in middleware when an AbortableInterceptor
|
||||
// short-circuits a request via PreRoundTripE. It wraps the extension's original
|
||||
// reason and carries the extension's Provider.Name() for traceability.
|
||||
//
|
||||
// Use errors.As to recover the typed error:
|
||||
//
|
||||
// var aErr *transport.AbortError
|
||||
// if errors.As(err, &aErr) {
|
||||
// log.Printf("blocked by %s: %v", aErr.Extension, aErr.Reason)
|
||||
// }
|
||||
//
|
||||
// errors.Is(err, transport.ErrAborted) also works, and errors.Is against the
|
||||
// inner reason still works via Unwrap.
|
||||
type AbortError struct {
|
||||
// Extension is the name of the Provider whose interceptor aborted the
|
||||
// request (from Provider.Name()). May be empty if the provider did not
|
||||
// supply a name.
|
||||
Extension string
|
||||
// Reason is the original non-nil error returned by PreRoundTripE.
|
||||
Reason error
|
||||
}
|
||||
|
||||
func (e *AbortError) Error() string {
|
||||
if e.Extension != "" {
|
||||
return fmt.Sprintf("extension %q aborted round trip: %v", e.Extension, e.Reason)
|
||||
}
|
||||
return fmt.Sprintf("extension aborted round trip: %v", e.Reason)
|
||||
}
|
||||
|
||||
// Unwrap lets errors.Is / errors.As traverse to the underlying Reason.
|
||||
func (e *AbortError) Unwrap() error { return e.Reason }
|
||||
|
||||
// Is enables errors.Is(err, ErrAborted) at any nesting depth.
|
||||
func (e *AbortError) Is(target error) bool { return target == ErrAborted }
|
||||
103
extension/transport/errors_test.go
Normal file
103
extension/transport/errors_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAbortError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *AbortError
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with extension name",
|
||||
err: &AbortError{Extension: "audit", Reason: errors.New("bad")},
|
||||
want: `extension "audit" aborted round trip: bad`,
|
||||
},
|
||||
{
|
||||
name: "without extension name",
|
||||
err: &AbortError{Reason: errors.New("bad")},
|
||||
want: "extension aborted round trip: bad",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.err.Error(); got != tt.want {
|
||||
t.Fatalf("Error() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAbortError_Unwrap(t *testing.T) {
|
||||
reason := errors.New("bad")
|
||||
e := &AbortError{Reason: reason}
|
||||
if got := e.Unwrap(); got != reason {
|
||||
t.Fatalf("Unwrap() = %v, want %v", got, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAbortError_IsErrAborted(t *testing.T) {
|
||||
e := &AbortError{Reason: errors.New("bad")}
|
||||
if !errors.Is(e, ErrAborted) {
|
||||
t.Fatal("errors.Is(e, ErrAborted) = false, want true")
|
||||
}
|
||||
// Sanity: not matched by unrelated sentinels.
|
||||
if errors.Is(e, errors.New("other")) {
|
||||
t.Fatal("errors.Is matched unrelated sentinel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAbortError_UnwrapReachesInnerSentinel(t *testing.T) {
|
||||
// Extensions often return typed/sentinel errors; callers should still be
|
||||
// able to errors.Is against those after the middleware wraps them.
|
||||
innerSentinel := errors.New("policy-deny-42")
|
||||
e := &AbortError{Reason: fmt.Errorf("wrapped: %w", innerSentinel)}
|
||||
if !errors.Is(e, innerSentinel) {
|
||||
t.Fatal("errors.Is(e, innerSentinel) = false, want true (Unwrap chain broken)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAbortError_As(t *testing.T) {
|
||||
reason := errors.New("bad")
|
||||
base := &AbortError{Extension: "audit", Reason: reason}
|
||||
|
||||
// Direct As.
|
||||
var aErr *AbortError
|
||||
if !errors.As(base, &aErr) {
|
||||
t.Fatal("errors.As(base, *AbortError) = false")
|
||||
}
|
||||
if aErr.Extension != "audit" || aErr.Reason != reason {
|
||||
t.Fatalf("aErr = %+v, want {audit, bad}", aErr)
|
||||
}
|
||||
|
||||
// Nested As: even when the *AbortError is wrapped in another error,
|
||||
// errors.As must still find it via Unwrap chain.
|
||||
wrapped := fmt.Errorf("outer: %w", base)
|
||||
var aErr2 *AbortError
|
||||
if !errors.As(wrapped, &aErr2) {
|
||||
t.Fatal("errors.As(wrapped, *AbortError) = false")
|
||||
}
|
||||
if aErr2 != base {
|
||||
t.Fatalf("aErr2 = %p, want %p", aErr2, base)
|
||||
}
|
||||
|
||||
// errors.Is still matches the sentinel through the outer wrapper.
|
||||
if !errors.Is(wrapped, ErrAborted) {
|
||||
t.Fatal("errors.Is(wrapped, ErrAborted) = false via nested wrap")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrAborted_IsItselfSentinel(t *testing.T) {
|
||||
// Guard against accidental re-assignment of ErrAborted: a bare ErrAborted
|
||||
// value should still satisfy errors.Is(err, ErrAborted) for symmetry.
|
||||
if !errors.Is(ErrAborted, ErrAborted) {
|
||||
t.Fatal("errors.Is(ErrAborted, ErrAborted) = false")
|
||||
}
|
||||
}
|
||||
178
extension/transport/sidecar/interceptor.go
Normal file
178
extension/transport/sidecar/interceptor.go
Normal file
@@ -0,0 +1,178 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build authsidecar
|
||||
|
||||
// Package sidecar provides a transport interceptor for the auth sidecar
|
||||
// proxy mode. When LARKSUITE_CLI_AUTH_PROXY is set (an HTTP URL), all
|
||||
// outgoing requests are rewritten to the sidecar address. The interceptor
|
||||
// strips placeholder credentials, injects proxy headers, and signs each
|
||||
// request with HMAC-SHA256. No custom DialContext is needed — Go's
|
||||
// standard http.Transport connects to the sidecar via plain HTTP.
|
||||
package sidecar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/extension/transport"
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
"github.com/larksuite/cli/sidecar"
|
||||
)
|
||||
|
||||
// Provider implements transport.Provider for the sidecar mode.
|
||||
type Provider struct{}
|
||||
|
||||
func (p *Provider) Name() string { return "sidecar" }
|
||||
|
||||
// ResolveInterceptor returns a SidecarInterceptor when sidecar mode is active.
|
||||
// Returns nil when sidecar mode is disabled or the proxy address is invalid;
|
||||
// in the latter case a warning is emitted to stderr and requests fall back to
|
||||
// the non-sidecar transport path (where the credential layer will typically
|
||||
// block them for lack of a valid account).
|
||||
func (p *Provider) ResolveInterceptor(ctx context.Context) transport.Interceptor {
|
||||
proxyAddr := os.Getenv(envvars.CliAuthProxy)
|
||||
if proxyAddr == "" {
|
||||
return nil
|
||||
}
|
||||
if err := sidecar.ValidateProxyAddr(proxyAddr); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: invalid %s, sidecar interceptor disabled: %v\n", envvars.CliAuthProxy, err)
|
||||
return nil
|
||||
}
|
||||
key := os.Getenv(envvars.CliProxyKey)
|
||||
return &Interceptor{
|
||||
key: []byte(key),
|
||||
sidecarHost: sidecar.ProxyHost(proxyAddr),
|
||||
}
|
||||
}
|
||||
|
||||
// Interceptor rewrites requests for the sidecar proxy.
|
||||
type Interceptor struct {
|
||||
key []byte // HMAC signing key
|
||||
sidecarHost string // sidecar host:port for URL rewriting
|
||||
}
|
||||
|
||||
// PreRoundTrip rewrites the request for sidecar routing when it carries a
|
||||
// sentinel token. Requests without a sentinel token (e.g. pre-signed download
|
||||
// URLs) are passed through unmodified.
|
||||
//
|
||||
// Supports two auth patterns:
|
||||
// - Standard OpenAPI: Authorization: Bearer <sentinel>
|
||||
// - MCP protocol: X-Lark-MCP-UAT/TAT: <sentinel>
|
||||
func (i *Interceptor) PreRoundTrip(req *http.Request) func(resp *http.Response, err error) {
|
||||
identity, authHeader := detectSentinel(req)
|
||||
if identity == "" {
|
||||
return nil // not a sidecar-managed request, pass through
|
||||
}
|
||||
|
||||
// 1. Buffer the body first, before mutating any request state. A partial
|
||||
// read would sign a truncated body and cause a misleading HMAC mismatch
|
||||
// on the sidecar side; bail out early and let the request fall through
|
||||
// unmodified so the credential layer can surface an actionable error.
|
||||
var bodyBytes []byte
|
||||
if req.Body != nil {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(req.Body)
|
||||
_ = req.Body.Close() // release original body (fd/pipe/etc.) after buffering
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: sidecar interceptor failed to read request body: %v\n", err)
|
||||
return nil
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if req.GetBody != nil {
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(bodyBytes)), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Save original target (scheme://host)
|
||||
originalScheme := "https"
|
||||
if req.URL.Scheme != "" {
|
||||
originalScheme = req.URL.Scheme
|
||||
}
|
||||
originalHost := req.URL.Host
|
||||
req.Header.Set(sidecar.HeaderProxyTarget, originalScheme+"://"+originalHost)
|
||||
|
||||
// 3. Set identity and tell sidecar which header to inject real token into
|
||||
req.Header.Set(sidecar.HeaderProxyIdentity, identity)
|
||||
req.Header.Set(sidecar.HeaderProxyAuthHeader, authHeader)
|
||||
|
||||
// 4. Strip placeholder auth header(s)
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del(sidecar.HeaderMCPUAT)
|
||||
req.Header.Del(sidecar.HeaderMCPTAT)
|
||||
|
||||
bodySHA := sidecar.BodySHA256(bodyBytes)
|
||||
req.Header.Set(sidecar.HeaderBodySHA256, bodySHA)
|
||||
|
||||
pathAndQuery := req.URL.RequestURI()
|
||||
ts := sidecar.Timestamp()
|
||||
// Cover identity and authHeader in the signature so an on-path attacker
|
||||
// within the replay window cannot flip the injected token's identity or
|
||||
// redirect the token into a different header.
|
||||
sig := sidecar.Sign(i.key, sidecar.CanonicalRequest{
|
||||
Version: sidecar.ProtocolV1,
|
||||
Method: req.Method,
|
||||
Host: originalHost,
|
||||
PathAndQuery: pathAndQuery,
|
||||
BodySHA256: bodySHA,
|
||||
Timestamp: ts,
|
||||
Identity: identity,
|
||||
AuthHeader: authHeader,
|
||||
})
|
||||
req.Header.Set(sidecar.HeaderProxyVersion, sidecar.ProtocolV1)
|
||||
req.Header.Set(sidecar.HeaderProxyTimestamp, ts)
|
||||
req.Header.Set(sidecar.HeaderProxySignature, sig)
|
||||
|
||||
// 5. Rewrite URL to route through sidecar
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = i.sidecarHost
|
||||
|
||||
return nil // no post-hook needed
|
||||
}
|
||||
|
||||
// detectSentinel checks both standard Authorization and MCP auth headers for
|
||||
// sentinel tokens. Returns the identity ("user"/"bot") and the header name
|
||||
// that carried the sentinel.
|
||||
//
|
||||
// Returns ("", "") when the request carries no sentinel token — typically
|
||||
// requests that require no auth (e.g. pre-signed download URLs where the
|
||||
// token is embedded in the URL query parameters).
|
||||
func detectSentinel(req *http.Request) (identity, authHeader string) {
|
||||
// Check standard Authorization: Bearer <sentinel>
|
||||
if auth := req.Header.Get("Authorization"); auth != "" {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
switch token {
|
||||
case sidecar.SentinelUAT:
|
||||
return sidecar.IdentityUser, "Authorization"
|
||||
case sidecar.SentinelTAT:
|
||||
return sidecar.IdentityBot, "Authorization"
|
||||
}
|
||||
}
|
||||
// Check MCP headers: X-Lark-MCP-UAT/TAT: <sentinel>
|
||||
if v := req.Header.Get(sidecar.HeaderMCPUAT); v == sidecar.SentinelUAT {
|
||||
return sidecar.IdentityUser, sidecar.HeaderMCPUAT
|
||||
}
|
||||
if v := req.Header.Get(sidecar.HeaderMCPTAT); v == sidecar.SentinelTAT {
|
||||
return sidecar.IdentityBot, sidecar.HeaderMCPTAT
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proxyAddr := os.Getenv(envvars.CliAuthProxy)
|
||||
if proxyAddr == "" {
|
||||
return
|
||||
}
|
||||
if err := sidecar.ValidateProxyAddr(proxyAddr); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "WARNING: ignoring invalid %s: %v\n", envvars.CliAuthProxy, err)
|
||||
return
|
||||
}
|
||||
transport.Register(&Provider{})
|
||||
}
|
||||
265
extension/transport/sidecar/interceptor_test.go
Normal file
265
extension/transport/sidecar/interceptor_test.go
Normal file
@@ -0,0 +1,265 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build authsidecar
|
||||
|
||||
package sidecar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/sidecar"
|
||||
)
|
||||
|
||||
// failingBody is a ReadCloser that errors on Read and tracks Close calls.
|
||||
type failingBody struct {
|
||||
err error
|
||||
closed bool
|
||||
readCall bool
|
||||
}
|
||||
|
||||
func (b *failingBody) Read(p []byte) (int, error) {
|
||||
b.readCall = true
|
||||
return 0, b.err
|
||||
}
|
||||
|
||||
func (b *failingBody) Close() error {
|
||||
b.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestInterceptor_PreRoundTrip(t *testing.T) {
|
||||
key := []byte("test-key-for-hmac-signing-32byte!")
|
||||
interceptor := &Interceptor{key: key, sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
body := []byte(`{"msg":"hello"}`)
|
||||
req, _ := http.NewRequest("POST", "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=chat_id", io.NopCloser(bytes.NewReader(body)))
|
||||
req.Header.Set("Authorization", "Bearer "+sidecar.SentinelUAT)
|
||||
req.Header.Set("X-Cli-Source", "lark-cli")
|
||||
|
||||
post := interceptor.PreRoundTrip(req)
|
||||
|
||||
if post != nil {
|
||||
t.Error("expected nil post hook")
|
||||
}
|
||||
|
||||
// URL should be rewritten to sidecar
|
||||
if req.URL.Scheme != "http" {
|
||||
t.Errorf("scheme = %q, want %q", req.URL.Scheme, "http")
|
||||
}
|
||||
if req.URL.Host != "127.0.0.1:16384" {
|
||||
t.Errorf("host = %q, want %q", req.URL.Host, "127.0.0.1:16384")
|
||||
}
|
||||
|
||||
// Original target should be preserved
|
||||
target := req.Header.Get(sidecar.HeaderProxyTarget)
|
||||
if target != "https://open.feishu.cn" {
|
||||
t.Errorf("target = %q, want %q", target, "https://open.feishu.cn")
|
||||
}
|
||||
|
||||
// Identity should be user (from SentinelUAT)
|
||||
if identity := req.Header.Get(sidecar.HeaderProxyIdentity); identity != sidecar.IdentityUser {
|
||||
t.Errorf("identity = %q, want %q", identity, sidecar.IdentityUser)
|
||||
}
|
||||
|
||||
// Authorization should be stripped
|
||||
if auth := req.Header.Get("Authorization"); auth != "" {
|
||||
t.Errorf("Authorization header should be stripped, got %q", auth)
|
||||
}
|
||||
|
||||
// HMAC headers should be set
|
||||
if sig := req.Header.Get(sidecar.HeaderProxySignature); sig == "" {
|
||||
t.Error("signature header should be set")
|
||||
}
|
||||
if ts := req.Header.Get(sidecar.HeaderProxyTimestamp); ts == "" {
|
||||
t.Error("timestamp header should be set")
|
||||
}
|
||||
if sha := req.Header.Get(sidecar.HeaderBodySHA256); sha == "" {
|
||||
t.Error("body SHA256 header should be set")
|
||||
}
|
||||
if v := req.Header.Get(sidecar.HeaderProxyVersion); v != sidecar.ProtocolV1 {
|
||||
t.Errorf("version header = %q, want %q", v, sidecar.ProtocolV1)
|
||||
}
|
||||
|
||||
// Non-proxy headers should be preserved
|
||||
if src := req.Header.Get("X-Cli-Source"); src != "lark-cli" {
|
||||
t.Errorf("X-Cli-Source should be preserved, got %q", src)
|
||||
}
|
||||
|
||||
// Body should still be readable
|
||||
readBody, _ := io.ReadAll(req.Body)
|
||||
if !bytes.Equal(readBody, body) {
|
||||
t.Errorf("body should be preserved after PreRoundTrip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_BotIdentity(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://open.feishu.cn/open-apis/calendar/v4/events", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+sidecar.SentinelTAT)
|
||||
|
||||
interceptor.PreRoundTrip(req)
|
||||
|
||||
if identity := req.Header.Get(sidecar.HeaderProxyIdentity); identity != sidecar.IdentityBot {
|
||||
t.Errorf("identity = %q, want %q", identity, sidecar.IdentityBot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_NonSentinelToken_PassThrough(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
origURL := "https://some-cdn.example.com/presigned-download?token=abc"
|
||||
req, _ := http.NewRequest("GET", origURL, nil)
|
||||
req.Header.Set("Authorization", "Bearer some-real-token")
|
||||
|
||||
post := interceptor.PreRoundTrip(req)
|
||||
|
||||
// Should NOT be rewritten — no sentinel token
|
||||
if post != nil {
|
||||
t.Error("expected nil post hook for pass-through")
|
||||
}
|
||||
if req.URL.String() != origURL {
|
||||
t.Errorf("URL should be unchanged, got %q", req.URL.String())
|
||||
}
|
||||
if req.Header.Get(sidecar.HeaderProxyTarget) != "" {
|
||||
t.Error("proxy target header should not be set for pass-through")
|
||||
}
|
||||
if req.Header.Get("Authorization") != "Bearer some-real-token" {
|
||||
t.Error("Authorization should be preserved for pass-through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_NoAuth_PassThrough(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
origURL := "https://cdn.feishu.cn/download/file"
|
||||
req, _ := http.NewRequest("GET", origURL, nil)
|
||||
|
||||
interceptor.PreRoundTrip(req)
|
||||
|
||||
// No Authorization header at all — should pass through
|
||||
if req.URL.String() != origURL {
|
||||
t.Errorf("URL should be unchanged for no-auth request, got %q", req.URL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_MCP_UAT(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
req, _ := http.NewRequest("POST", "https://mcp.feishu.cn/mcp/v1/tools/call", bytes.NewReader([]byte(`{"jsonrpc":"2.0"}`)))
|
||||
req.Header.Set(sidecar.HeaderMCPUAT, sidecar.SentinelUAT)
|
||||
|
||||
interceptor.PreRoundTrip(req)
|
||||
|
||||
// Should be intercepted and rewritten
|
||||
if req.URL.Host != "127.0.0.1:16384" {
|
||||
t.Errorf("host = %q, want sidecar host", req.URL.Host)
|
||||
}
|
||||
if identity := req.Header.Get(sidecar.HeaderProxyIdentity); identity != sidecar.IdentityUser {
|
||||
t.Errorf("identity = %q, want %q", identity, sidecar.IdentityUser)
|
||||
}
|
||||
if ah := req.Header.Get(sidecar.HeaderProxyAuthHeader); ah != sidecar.HeaderMCPUAT {
|
||||
t.Errorf("auth header = %q, want %q", ah, sidecar.HeaderMCPUAT)
|
||||
}
|
||||
// MCP sentinel should be stripped
|
||||
if v := req.Header.Get(sidecar.HeaderMCPUAT); v != "" {
|
||||
t.Errorf("MCP-UAT should be stripped, got %q", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_MCP_TAT(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
req, _ := http.NewRequest("POST", "https://mcp.feishu.cn/mcp/v1/tools/call", bytes.NewReader([]byte(`{}`)))
|
||||
req.Header.Set(sidecar.HeaderMCPTAT, sidecar.SentinelTAT)
|
||||
|
||||
interceptor.PreRoundTrip(req)
|
||||
|
||||
if identity := req.Header.Get(sidecar.HeaderProxyIdentity); identity != sidecar.IdentityBot {
|
||||
t.Errorf("identity = %q, want %q", identity, sidecar.IdentityBot)
|
||||
}
|
||||
if ah := req.Header.Get(sidecar.HeaderProxyAuthHeader); ah != sidecar.HeaderMCPTAT {
|
||||
t.Errorf("auth header = %q, want %q", ah, sidecar.HeaderMCPTAT)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_StandardAuth_SetsAuthorizationHeader(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://open.feishu.cn/open-apis/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+sidecar.SentinelUAT)
|
||||
|
||||
interceptor.PreRoundTrip(req)
|
||||
|
||||
if ah := req.Header.Get(sidecar.HeaderProxyAuthHeader); ah != "Authorization" {
|
||||
t.Errorf("auth header = %q, want %q", ah, "Authorization")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInterceptor_BodyReadError verifies that when io.ReadAll on the request
|
||||
// body fails partway, PreRoundTrip skips the rewrite entirely rather than
|
||||
// signing a truncated body (which would produce a misleading HMAC mismatch on
|
||||
// the sidecar side) and releases the original body.
|
||||
func TestInterceptor_BodyReadError(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
const origURL = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||
body := &failingBody{err: errors.New("disk gremlin")}
|
||||
|
||||
req, _ := http.NewRequest("POST", origURL, body)
|
||||
req.Header.Set("Authorization", "Bearer "+sidecar.SentinelUAT)
|
||||
|
||||
post := interceptor.PreRoundTrip(req)
|
||||
|
||||
if post != nil {
|
||||
t.Error("expected nil post hook on body read failure")
|
||||
}
|
||||
|
||||
// Original body must be closed to avoid leaking fd/pipe-like resources.
|
||||
if !body.readCall {
|
||||
t.Error("expected ReadAll to have attempted reading from the body")
|
||||
}
|
||||
if !body.closed {
|
||||
t.Error("expected original body to be Close()'d after read failure")
|
||||
}
|
||||
|
||||
// URL must NOT be rewritten — request should fall through to the next
|
||||
// layer (credential) which can surface a meaningful error.
|
||||
if req.URL.String() != origURL {
|
||||
t.Errorf("URL should be unchanged on read failure, got %q", req.URL.String())
|
||||
}
|
||||
|
||||
// No proxy/HMAC headers should leak onto the request.
|
||||
for _, h := range []string{
|
||||
sidecar.HeaderProxyVersion,
|
||||
sidecar.HeaderProxyTarget,
|
||||
sidecar.HeaderProxySignature,
|
||||
sidecar.HeaderProxyTimestamp,
|
||||
sidecar.HeaderBodySHA256,
|
||||
sidecar.HeaderProxyIdentity,
|
||||
sidecar.HeaderProxyAuthHeader,
|
||||
} {
|
||||
if v := req.Header.Get(h); v != "" {
|
||||
t.Errorf("%s should not be set on read failure, got %q", h, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_EmptyBody(t *testing.T) {
|
||||
interceptor := &Interceptor{key: []byte("key"), sidecarHost: "127.0.0.1:16384"}
|
||||
|
||||
req, _ := http.NewRequest("GET", "https://open.feishu.cn/path", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+sidecar.SentinelTAT)
|
||||
interceptor.PreRoundTrip(req)
|
||||
|
||||
sha := req.Header.Get(sidecar.HeaderBodySHA256)
|
||||
expectedEmpty := sidecar.BodySHA256(nil)
|
||||
if sha != expectedEmpty {
|
||||
t.Errorf("body SHA256 = %q, want empty-string SHA256 %q", sha, expectedEmpty)
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,31 @@ type Provider interface {
|
||||
//
|
||||
// The returned function (if non-nil) is called after the built-in chain
|
||||
// completes. Use it for logging, ending trace spans, or recording metrics.
|
||||
//
|
||||
// Body note: the middleware Clones the caller's request before invoking the
|
||||
// interceptor, which copies headers/URL/etc. but shares the underlying
|
||||
// io.ReadCloser. Extensions that read req.Body are responsible for restoring
|
||||
// a replayable body (e.g. via req.GetBody) before returning, otherwise the
|
||||
// built-in chain will see an exhausted stream.
|
||||
type Interceptor interface {
|
||||
PreRoundTrip(req *http.Request) func(resp *http.Response, err error)
|
||||
}
|
||||
|
||||
// AbortableInterceptor is an optional extension of Interceptor that lets an
|
||||
// extension reject a request before the built-in chain runs. Extensions that
|
||||
// implement this interface are detected by the built-in middleware via a
|
||||
// type assertion; both methods must be present, but when an extension
|
||||
// implements PreRoundTripE the middleware will NOT call PreRoundTrip.
|
||||
//
|
||||
// Returning a non-nil error from PreRoundTripE aborts the request: the
|
||||
// built-in chain is not executed and the middleware returns an *AbortError
|
||||
// wrapping the reason. The returned post function (if non-nil) is still
|
||||
// invoked with (nil, reason) so that extensions can unwind any state they
|
||||
// created in the pre hook (spans, metrics, audit records).
|
||||
//
|
||||
// Extensions that only care about the abortable variant can provide a no-op
|
||||
// PreRoundTrip method alongside PreRoundTripE to satisfy Interceptor.
|
||||
type AbortableInterceptor interface {
|
||||
Interceptor
|
||||
PreRoundTripE(req *http.Request) (post func(resp *http.Response, err error), err error)
|
||||
}
|
||||
|
||||
@@ -200,7 +200,7 @@ func PollDeviceToken(ctx context.Context, httpClient *http.Client, appId, appSec
|
||||
errStr := getStr(data, "error")
|
||||
|
||||
if errStr == "" && getStr(data, "access_token") != "" {
|
||||
fmt.Fprintf(errOut, "[lark-cli] device-flow: token obtained successfully\n")
|
||||
fmt.Fprintf(errOut, "[lark-cli] device-flow: token response received\n")
|
||||
refreshToken := getStr(data, "refresh_token")
|
||||
tokenExpiresIn := getInt(data, "expires_in", 7200)
|
||||
refreshExpiresIn := getInt(data, "refresh_token_expires_in", 604800)
|
||||
|
||||
37
internal/cmdutil/completion.go
Normal file
37
internal/cmdutil/completion.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Cobra keeps completion callbacks in a package-global map keyed by
|
||||
// *pflag.Flag with no removal path, so registrations made for a *cobra.Command
|
||||
// outlive the command itself. Skip registration when the current invocation
|
||||
// will not serve a completion request.
|
||||
var flagCompletionsDisabled atomic.Bool
|
||||
|
||||
// SetFlagCompletionsDisabled switches RegisterFlagCompletion between
|
||||
// registering and no-op. Typically set once at process start.
|
||||
func SetFlagCompletionsDisabled(disabled bool) {
|
||||
flagCompletionsDisabled.Store(disabled)
|
||||
}
|
||||
|
||||
// FlagCompletionsDisabled reports the current switch state.
|
||||
func FlagCompletionsDisabled() bool {
|
||||
return flagCompletionsDisabled.Load()
|
||||
}
|
||||
|
||||
// RegisterFlagCompletion wraps (*cobra.Command).RegisterFlagCompletionFunc
|
||||
// and honors the package switch. The underlying error is swallowed to match
|
||||
// the `_ = cmd.RegisterFlagCompletionFunc(...)` style already used here.
|
||||
func RegisterFlagCompletion(cmd *cobra.Command, flagName string, fn cobra.CompletionFunc) {
|
||||
if flagCompletionsDisabled.Load() {
|
||||
return
|
||||
}
|
||||
_ = cmd.RegisterFlagCompletionFunc(flagName, fn)
|
||||
}
|
||||
78
internal/cmdutil/completion_test.go
Normal file
78
internal/cmdutil/completion_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestSetFlagCompletionsDisabled_RoundTrip(t *testing.T) {
|
||||
t.Cleanup(func() { SetFlagCompletionsDisabled(false) })
|
||||
|
||||
if FlagCompletionsDisabled() {
|
||||
t.Fatal("expected default false")
|
||||
}
|
||||
SetFlagCompletionsDisabled(true)
|
||||
if !FlagCompletionsDisabled() {
|
||||
t.Fatal("expected true after Set(true)")
|
||||
}
|
||||
SetFlagCompletionsDisabled(false)
|
||||
if FlagCompletionsDisabled() {
|
||||
t.Fatal("expected false after Set(false)")
|
||||
}
|
||||
}
|
||||
|
||||
// When disabled, a *cobra.Command must be collectable after the caller drops
|
||||
// its reference — i.e. the wrapper did not touch cobra's global map.
|
||||
func TestRegisterFlagCompletion_Disabled_DoesNotRetainCommand(t *testing.T) {
|
||||
SetFlagCompletionsDisabled(true)
|
||||
t.Cleanup(func() { SetFlagCompletionsDisabled(false) })
|
||||
|
||||
const N = 5
|
||||
var collected atomic.Int32
|
||||
func() {
|
||||
for range N {
|
||||
cmd := &cobra.Command{Use: "x"}
|
||||
cmd.Flags().String("foo", "", "")
|
||||
RegisterFlagCompletion(cmd, "foo", func(_ *cobra.Command, _ []string, _ string) ([]cobra.Completion, cobra.ShellCompDirective) {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
runtime.SetFinalizer(cmd, func(_ *cobra.Command) { collected.Add(1) })
|
||||
}
|
||||
}()
|
||||
// Finalizers run on a dedicated goroutine after GC; loop to give it time.
|
||||
for range 30 {
|
||||
runtime.GC()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
if got := collected.Load(); int(got) != N {
|
||||
t.Fatalf("expected %d *cobra.Command finalizers to fire when completions disabled, got %d", N, got)
|
||||
}
|
||||
}
|
||||
|
||||
// When enabled, the registered completion must be reachable via cobra.
|
||||
func TestRegisterFlagCompletion_Enabled_DoesRegister(t *testing.T) {
|
||||
SetFlagCompletionsDisabled(false)
|
||||
|
||||
cmd := &cobra.Command{Use: "x"}
|
||||
cmd.Flags().String("foo", "", "")
|
||||
want := []cobra.Completion{"a", "b"}
|
||||
RegisterFlagCompletion(cmd, "foo", func(_ *cobra.Command, _ []string, _ string) ([]cobra.Completion, cobra.ShellCompDirective) {
|
||||
return want, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
|
||||
fn, ok := cmd.GetFlagCompletionFunc("foo")
|
||||
if !ok {
|
||||
t.Fatal("expected completion func to be registered")
|
||||
}
|
||||
got, _ := fn(cmd, nil, "")
|
||||
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
|
||||
t.Fatalf("unexpected completion result: %v", got)
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
"golang.org/x/term"
|
||||
|
||||
extcred "github.com/larksuite/cli/extension/credential"
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
@@ -34,27 +32,24 @@ import (
|
||||
// Phase 2: Credential (sole data source for account info)
|
||||
// Phase 3: Config derived from Credential
|
||||
// Phase 4: LarkClient derived from Credential
|
||||
func NewDefault(inv InvocationContext) *Factory {
|
||||
func NewDefault(streams *IOStreams, inv InvocationContext) *Factory {
|
||||
streams = normalizeStreams(streams)
|
||||
f := &Factory{
|
||||
Keychain: keychain.Default(),
|
||||
Invocation: inv,
|
||||
}
|
||||
f.IOStreams = &IOStreams{
|
||||
In: os.Stdin,
|
||||
Out: os.Stdout,
|
||||
ErrOut: os.Stderr,
|
||||
IsTerminal: term.IsTerminal(int(os.Stdin.Fd())),
|
||||
IOStreams: streams,
|
||||
}
|
||||
|
||||
// Phase 0: FileIO provider (no dependency)
|
||||
f.FileIOProvider = fileio.GetProvider()
|
||||
|
||||
// Phase 1: HttpClient (no credential dependency)
|
||||
f.HttpClient = cachedHttpClientFunc()
|
||||
f.HttpClient = cachedHttpClientFunc(f)
|
||||
|
||||
// Phase 2: Credential (sole data source)
|
||||
// Keychain is read via closure so callers can replace f.Keychain after construction.
|
||||
f.Credential = buildCredentialProvider(credentialDeps{
|
||||
Keychain: f.Keychain,
|
||||
Keychain: func() keychain.KeychainAccess { return f.Keychain },
|
||||
Profile: inv.Profile,
|
||||
HttpClient: f.HttpClient,
|
||||
ErrOut: f.IOStreams.ErrOut,
|
||||
@@ -93,11 +88,11 @@ func safeRedirectPolicy(req *http.Request, via []*http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cachedHttpClientFunc() func() (*http.Client, error) {
|
||||
func cachedHttpClientFunc(f *Factory) func() (*http.Client, error) {
|
||||
return sync.OnceValues(func() (*http.Client, error) {
|
||||
util.WarnIfProxied(os.Stderr)
|
||||
util.WarnIfProxied(f.IOStreams.ErrOut)
|
||||
|
||||
var transport http.RoundTripper = util.NewBaseTransport()
|
||||
var transport http.RoundTripper = util.SharedTransport()
|
||||
transport = &RetryTransport{Base: transport}
|
||||
transport = &SecurityHeaderTransport{Base: transport}
|
||||
transport = &auth.SecurityPolicyTransport{Base: transport} // Add our global response interceptor
|
||||
@@ -122,7 +117,7 @@ func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) {
|
||||
lark.WithLogLevel(larkcore.LogLevelError),
|
||||
lark.WithHeaders(BaseSecurityHeaders()),
|
||||
}
|
||||
util.WarnIfProxied(os.Stderr)
|
||||
util.WarnIfProxied(f.IOStreams.ErrOut)
|
||||
opts = append(opts, lark.WithHttpClient(&http.Client{
|
||||
Transport: buildSDKTransport(),
|
||||
CheckRedirect: safeRedirectPolicy,
|
||||
@@ -134,15 +129,16 @@ func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) {
|
||||
}
|
||||
|
||||
func buildSDKTransport() http.RoundTripper {
|
||||
var sdkTransport http.RoundTripper = util.NewBaseTransport()
|
||||
var sdkTransport http.RoundTripper = util.SharedTransport()
|
||||
sdkTransport = &RetryTransport{Base: sdkTransport}
|
||||
sdkTransport = &UserAgentTransport{Base: sdkTransport}
|
||||
sdkTransport = &BuildHeaderTransport{Base: sdkTransport}
|
||||
sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport}
|
||||
return wrapWithExtension(sdkTransport)
|
||||
}
|
||||
|
||||
type credentialDeps struct {
|
||||
Keychain keychain.KeychainAccess
|
||||
Keychain func() keychain.KeychainAccess
|
||||
Profile string
|
||||
HttpClient func() (*http.Client, error)
|
||||
ErrOut io.Writer
|
||||
|
||||
@@ -6,14 +6,10 @@ package cmdutil
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
_ "github.com/larksuite/cli/extension/credential/env"
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
exttransport "github.com/larksuite/cli/extension/transport"
|
||||
internalauth "github.com/larksuite/cli/internal/auth"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/credential"
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
@@ -63,7 +59,7 @@ func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) {
|
||||
t.Fatalf("SaveMultiAppConfig() error = %v", err)
|
||||
}
|
||||
|
||||
f := NewDefault(InvocationContext{Profile: "target"})
|
||||
f := NewDefault(nil, InvocationContext{Profile: "target"})
|
||||
if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeBot {
|
||||
t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeBot)
|
||||
}
|
||||
@@ -103,7 +99,7 @@ func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testi
|
||||
t.Fatalf("SaveMultiAppConfig() error = %v", err)
|
||||
}
|
||||
|
||||
f := NewDefault(InvocationContext{Profile: "missing"})
|
||||
f := NewDefault(nil, InvocationContext{Profile: "missing"})
|
||||
if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff {
|
||||
t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeOff)
|
||||
}
|
||||
@@ -120,22 +116,6 @@ func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) {
|
||||
transport := buildSDKTransport()
|
||||
|
||||
sec, ok := transport.(*internalauth.SecurityPolicyTransport)
|
||||
if !ok {
|
||||
t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport)
|
||||
}
|
||||
ua, ok := sec.Base.(*UserAgentTransport)
|
||||
if !ok {
|
||||
t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base)
|
||||
}
|
||||
if _, ok := ua.Base.(*RetryTransport); !ok {
|
||||
t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) {
|
||||
t.Setenv(envvars.CliAppID, "env-app")
|
||||
t.Setenv(envvars.CliAppSecret, "env-secret")
|
||||
@@ -144,7 +124,7 @@ func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) {
|
||||
t.Setenv(envvars.CliTenantAccessToken, "")
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
|
||||
f := NewDefault(InvocationContext{})
|
||||
f := NewDefault(nil, InvocationContext{})
|
||||
cmd := newCmdWithAsFlag("auto", false)
|
||||
|
||||
got := f.ResolveAs(context.Background(), cmd, "auto")
|
||||
@@ -164,7 +144,7 @@ func TestNewDefault_ConfigReturnsCliConfigCopyOfCredentialAccount(t *testing.T)
|
||||
t.Setenv(envvars.CliTenantAccessToken, "")
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
|
||||
f := NewDefault(InvocationContext{})
|
||||
f := NewDefault(nil, InvocationContext{})
|
||||
|
||||
acct, err := f.Credential.ResolveAccount(context.Background())
|
||||
if err != nil {
|
||||
@@ -189,7 +169,7 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin
|
||||
t.Setenv(envvars.CliTenantAccessToken, "")
|
||||
t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())
|
||||
|
||||
f := NewDefault(InvocationContext{})
|
||||
f := NewDefault(nil, InvocationContext{})
|
||||
|
||||
acct, err := f.Credential.ResolveAccount(context.Background())
|
||||
if err != nil {
|
||||
@@ -217,7 +197,7 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing.
|
||||
fileio.Register(provider)
|
||||
t.Cleanup(func() { fileio.Register(prev) })
|
||||
|
||||
f := NewDefault(InvocationContext{})
|
||||
f := NewDefault(nil, InvocationContext{})
|
||||
if f.FileIOProvider != provider {
|
||||
t.Fatalf("NewDefault() provider = %T, want %T", f.FileIOProvider, provider)
|
||||
}
|
||||
@@ -232,170 +212,3 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing.
|
||||
t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls)
|
||||
}
|
||||
}
|
||||
|
||||
type stubTransportProvider struct {
|
||||
interceptor exttransport.Interceptor
|
||||
}
|
||||
|
||||
func (s *stubTransportProvider) Name() string { return "stub" }
|
||||
func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor {
|
||||
if s.interceptor != nil {
|
||||
return s.interceptor
|
||||
}
|
||||
return &stubTransportImpl{}
|
||||
}
|
||||
|
||||
type stubTransportImpl struct{}
|
||||
|
||||
func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerCapturingInterceptor sets custom headers in PreRoundTrip and records
|
||||
// whether PostRoundTrip was called, to verify execution order.
|
||||
type headerCapturingInterceptor struct {
|
||||
preCalled bool
|
||||
postCalled bool
|
||||
}
|
||||
|
||||
func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) {
|
||||
h.preCalled = true
|
||||
// Set a custom header that should survive (no built-in override)
|
||||
req.Header.Set("X-Custom-Trace", "ext-trace-123")
|
||||
// Try to override a security header — should be overwritten by SecurityHeaderTransport
|
||||
req.Header.Set(HeaderSource, "ext-tampered")
|
||||
return func(resp *http.Response, err error) {
|
||||
h.postCalled = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionInterceptor_ExecutionOrder(t *testing.T) {
|
||||
var receivedHeaders http.Header
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ic := &headerCapturingInterceptor{}
|
||||
exttransport.Register(&stubTransportProvider{interceptor: ic})
|
||||
t.Cleanup(func() { exttransport.Register(nil) })
|
||||
|
||||
// Use HTTP transport chain (has SecurityHeaderTransport)
|
||||
var base http.RoundTripper = http.DefaultTransport
|
||||
base = &RetryTransport{Base: base}
|
||||
base = &SecurityHeaderTransport{Base: base}
|
||||
transport := wrapWithExtension(base)
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
req, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// PreRoundTrip was called
|
||||
if !ic.preCalled {
|
||||
t.Fatal("PreRoundTrip was not called")
|
||||
}
|
||||
// PostRoundTrip (closure) was called
|
||||
if !ic.postCalled {
|
||||
t.Fatal("PostRoundTrip closure was not called")
|
||||
}
|
||||
// Custom header set by extension survives (no built-in override)
|
||||
if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" {
|
||||
t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123")
|
||||
}
|
||||
// Security header overridden by extension is restored by SecurityHeaderTransport
|
||||
if got := receivedHeaders.Get(HeaderSource); got != SourceValue {
|
||||
t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) {
|
||||
type ctxKeyType string
|
||||
const testKey ctxKeyType = "original"
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
var ctxValue any
|
||||
|
||||
// Use a custom transport that captures the context value seen by the built-in chain
|
||||
capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
ctxValue = req.Context().Value(testKey)
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
})
|
||||
|
||||
// Interceptor that tries to tamper with context
|
||||
tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) {
|
||||
// Try to replace context with a new one
|
||||
*req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered"))
|
||||
return nil
|
||||
})
|
||||
|
||||
mid := &extensionMiddleware{Base: capturer, Ext: tamperIC}
|
||||
|
||||
origCtx := context.WithValue(context.Background(), testKey, "original")
|
||||
req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil)
|
||||
resp, err := mid.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Built-in chain should see original context, not tampered
|
||||
if ctxValue != "original" {
|
||||
t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original")
|
||||
}
|
||||
}
|
||||
|
||||
// interceptorFunc adapts a function to exttransport.Interceptor.
|
||||
type interceptorFunc func(*http.Request) func(*http.Response, error)
|
||||
|
||||
func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) }
|
||||
|
||||
func TestBuildSDKTransport_WithExtension(t *testing.T) {
|
||||
exttransport.Register(&stubTransportProvider{})
|
||||
t.Cleanup(func() { exttransport.Register(nil) })
|
||||
|
||||
transport := buildSDKTransport()
|
||||
|
||||
// Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base
|
||||
mid, ok := transport.(*extensionMiddleware)
|
||||
if !ok {
|
||||
t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport)
|
||||
}
|
||||
sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base)
|
||||
}
|
||||
ua, ok := sec.Base.(*UserAgentTransport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base)
|
||||
}
|
||||
if _, ok := ua.Base.(*RetryTransport); !ok {
|
||||
t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSDKTransport_WithoutExtension(t *testing.T) {
|
||||
exttransport.Register(nil)
|
||||
|
||||
transport := buildSDKTransport()
|
||||
|
||||
sec, ok := transport.(*internalauth.SecurityPolicyTransport)
|
||||
if !ok {
|
||||
t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport)
|
||||
}
|
||||
ua, ok := sec.Base.(*UserAgentTransport)
|
||||
if !ok {
|
||||
t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base)
|
||||
}
|
||||
if _, ok := ua.Base.(*RetryTransport); !ok {
|
||||
t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCachedHttpClientFunc_ReturnsSameInstance(t *testing.T) {
|
||||
fn := cachedHttpClientFunc()
|
||||
fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}})
|
||||
|
||||
c1, err := fn()
|
||||
if err != nil {
|
||||
@@ -28,7 +29,7 @@ func TestCachedHttpClientFunc_ReturnsSameInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCachedHttpClientFunc_HasTimeout(t *testing.T) {
|
||||
fn := cachedHttpClientFunc()
|
||||
fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}})
|
||||
c, _ := fn()
|
||||
if c.Timeout == 0 {
|
||||
t.Error("expected non-zero timeout")
|
||||
@@ -36,7 +37,7 @@ func TestCachedHttpClientFunc_HasTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCachedHttpClientFunc_HasRedirectPolicy(t *testing.T) {
|
||||
fn := cachedHttpClientFunc()
|
||||
fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}})
|
||||
c, _ := fn()
|
||||
if c.CheckRedirect == nil {
|
||||
t.Error("expected CheckRedirect to be set (safeRedirectPolicy)")
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
@@ -122,9 +123,22 @@ func BuildFormdata(fileIO fileio.FileIO, fieldName, filePath string, isStdin boo
|
||||
// Add top-level JSON keys as text form fields.
|
||||
if m, ok := dataJSON.(map[string]any); ok {
|
||||
for k, v := range m {
|
||||
fd.AddField(k, fmt.Sprintf("%v", v))
|
||||
fd.AddField(k, formatFormFieldValue(v))
|
||||
}
|
||||
}
|
||||
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// formatFormFieldValue renders a JSON-unmarshalled value as a multipart form
|
||||
// field string. float64 is handled specially: fmt's default %v/%g switches to
|
||||
// scientific notation for values >= ~1e6 (e.g. "1.185356e+06"), which some
|
||||
// backends reject when parsing the field as an integer. Use decimal notation
|
||||
// instead so size / block_num / offset-style numeric fields round-trip cleanly.
|
||||
// All other types fall through to %v.
|
||||
func formatFormFieldValue(v any) string {
|
||||
if n, ok := v.(float64); ok {
|
||||
return strconv.FormatFloat(n, 'f', -1, 64)
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
@@ -336,3 +336,40 @@ func TestBuildFormdata(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFormatFormFieldValue locks in the fix for the float64 -> scientific
|
||||
// notation bug. JSON numbers unmarshal to float64, and fmt's default %v for
|
||||
// float64 delegates to %g which switches to scientific notation at ~1e6
|
||||
// (e.g. 1185356 -> "1.185356e+06"). Backends that parse the form field as an
|
||||
// integer reject that, surfacing as a generic "params error".
|
||||
func TestFormatFormFieldValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in any
|
||||
want string
|
||||
}{
|
||||
{"float64 large integer avoids scientific", float64(1185356), "1185356"},
|
||||
{"float64 below scientific threshold", float64(358934), "358934"},
|
||||
{"float64 zero", float64(0), "0"},
|
||||
{"float64 huge", float64(20 * 1024 * 1024), "20971520"},
|
||||
{"float64 negative", float64(-42), "-42"},
|
||||
{"float64 fractional preserved", float64(3.14), "3.14"},
|
||||
{"string pass-through", "hello", "hello"},
|
||||
{"bool true", true, "true"},
|
||||
{"int via %v", 42, "42"},
|
||||
{"int64 via %v", int64(9007199254740992), "9007199254740992"},
|
||||
}
|
||||
|
||||
for _, temp := range tests {
|
||||
tt := temp
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatFormFieldValue(tt.in)
|
||||
if got != tt.want {
|
||||
t.Fatalf("formatFormFieldValue(%v) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
68
internal/cmdutil/identity_flag.go
Normal file
68
internal/cmdutil/identity_flag.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// AddAPIIdentityFlag registers the standard --as flag shape used by api/service commands.
|
||||
func AddAPIIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, target *string) {
|
||||
addIdentityFlag(ctx, cmd, f, target, identityFlagConfig{
|
||||
defaultValue: "auto",
|
||||
usage: "identity type: user | bot | auto (default)",
|
||||
completionValues: []string{"user", "bot"},
|
||||
})
|
||||
}
|
||||
|
||||
// AddShortcutIdentityFlag registers the standard --as flag shape used by shortcuts.
|
||||
func AddShortcutIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, authTypes []string) {
|
||||
if len(authTypes) == 0 {
|
||||
authTypes = []string{"user"}
|
||||
}
|
||||
addIdentityFlag(ctx, cmd, f, nil, identityFlagConfig{
|
||||
defaultValue: authTypes[0],
|
||||
usage: "identity type: " + strings.Join(authTypes, " | "),
|
||||
completionValues: authTypes,
|
||||
})
|
||||
}
|
||||
|
||||
type identityFlagConfig struct {
|
||||
defaultValue string
|
||||
usage string
|
||||
completionValues []string
|
||||
}
|
||||
|
||||
// addIdentityFlag centralizes --as registration and strict-mode UX.
|
||||
// When strict mode is active, the flag is still accepted for compatibility
|
||||
// but hidden from help/completion and locked to the forced identity by default.
|
||||
func addIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, target *string, cfg identityFlagConfig) {
|
||||
if forced := f.ResolveStrictMode(ctx).ForcedIdentity(); forced != "" {
|
||||
// Keep registering --as in strict mode even though it is hidden.
|
||||
// This preserves parser compatibility for existing invocations that still pass
|
||||
// --as, and keeps downstream GetString("as") / ResolveAs paths stable.
|
||||
// The usage text below is effectively placeholder text because the flag is hidden.
|
||||
registerIdentityFlag(cmd, target, string(forced),
|
||||
fmt.Sprintf("identity locked to %s by strict mode (admin-managed)", forced))
|
||||
_ = cmd.Flags().MarkHidden("as")
|
||||
return
|
||||
}
|
||||
|
||||
registerIdentityFlag(cmd, target, cfg.defaultValue, cfg.usage)
|
||||
RegisterFlagCompletion(cmd, "as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return cfg.completionValues, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
}
|
||||
|
||||
func registerIdentityFlag(cmd *cobra.Command, target *string, defaultValue, usage string) {
|
||||
if target != nil {
|
||||
cmd.Flags().StringVar(target, "as", defaultValue, usage)
|
||||
return
|
||||
}
|
||||
cmd.Flags().String("as", defaultValue, usage)
|
||||
}
|
||||
68
internal/cmdutil/identity_flag_test.go
Normal file
68
internal/cmdutil/identity_flag_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestAddAPIIdentityFlag_NonStrictMode(t *testing.T) {
|
||||
f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"})
|
||||
cmd := &cobra.Command{Use: "test"}
|
||||
|
||||
AddAPIIdentityFlag(context.Background(), cmd, f, nil)
|
||||
|
||||
flag := cmd.Flags().Lookup("as")
|
||||
if flag == nil {
|
||||
t.Fatal("expected --as flag to be registered")
|
||||
}
|
||||
if flag.Hidden {
|
||||
t.Fatal("expected --as flag to be visible outside strict mode")
|
||||
}
|
||||
if got := flag.DefValue; got != "auto" {
|
||||
t.Fatalf("default value = %q, want %q", got, "auto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddAPIIdentityFlag_StrictModeHidesFlagAndLocksDefault(t *testing.T) {
|
||||
f, _, _, _ := TestFactory(t, &core.CliConfig{
|
||||
AppID: "a", AppSecret: "s", SupportedIdentities: 2,
|
||||
})
|
||||
cmd := &cobra.Command{Use: "test"}
|
||||
|
||||
AddAPIIdentityFlag(context.Background(), cmd, f, nil)
|
||||
|
||||
flag := cmd.Flags().Lookup("as")
|
||||
if flag == nil {
|
||||
t.Fatal("expected --as flag to be registered")
|
||||
}
|
||||
if !flag.Hidden {
|
||||
t.Fatal("expected --as flag to be hidden in strict mode")
|
||||
}
|
||||
if got := flag.DefValue; got != "bot" {
|
||||
t.Fatalf("default value = %q, want %q", got, "bot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddShortcutIdentityFlag_UsesAuthTypes(t *testing.T) {
|
||||
f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"})
|
||||
cmd := &cobra.Command{Use: "test"}
|
||||
|
||||
AddShortcutIdentityFlag(context.Background(), cmd, f, []string{"bot"})
|
||||
|
||||
flag := cmd.Flags().Lookup("as")
|
||||
if flag == nil {
|
||||
t.Fatal("expected --as flag to be registered")
|
||||
}
|
||||
if flag.Hidden {
|
||||
t.Fatal("expected --as flag to be visible outside strict mode")
|
||||
}
|
||||
if got := flag.DefValue; got != "bot" {
|
||||
t.Fatalf("default value = %q, want %q", got, "bot")
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,12 @@
|
||||
|
||||
package cmdutil
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// IOStreams provides the standard input/output/error streams.
|
||||
// Commands should use these instead of os.Stdin/Stdout/Stderr
|
||||
@@ -14,3 +19,45 @@ type IOStreams struct {
|
||||
ErrOut io.Writer
|
||||
IsTerminal bool
|
||||
}
|
||||
|
||||
// NewIOStreams builds an IOStreams from arbitrary readers/writers.
|
||||
// IsTerminal is derived from in's underlying *os.File, if any; non-file
|
||||
// readers (bytes.Buffer, strings.Reader, …) yield IsTerminal=false.
|
||||
func NewIOStreams(in io.Reader, out, errOut io.Writer) *IOStreams {
|
||||
isTerminal := false
|
||||
if f, ok := in.(*os.File); ok {
|
||||
isTerminal = term.IsTerminal(int(f.Fd()))
|
||||
}
|
||||
return &IOStreams{In: in, Out: out, ErrOut: errOut, IsTerminal: isTerminal}
|
||||
}
|
||||
|
||||
// SystemIO creates an IOStreams wired to the process's standard file descriptors.
|
||||
//
|
||||
//nolint:forbidigo // entry point for real stdio
|
||||
func SystemIO() *IOStreams {
|
||||
return NewIOStreams(os.Stdin, os.Stdout, os.Stderr)
|
||||
}
|
||||
|
||||
// normalizeStreams returns a fresh IOStreams with any nil field filled from
|
||||
// SystemIO(). Callers constructing a partial struct like &IOStreams{Out: buf}
|
||||
// get a usable result without nil writers leaking into RoundTripper warnings,
|
||||
// Cobra I/O, or credential-provider error paths.
|
||||
func normalizeStreams(s *IOStreams) *IOStreams {
|
||||
if s == nil {
|
||||
return SystemIO()
|
||||
}
|
||||
out := *s
|
||||
if out.In == nil || out.Out == nil || out.ErrOut == nil {
|
||||
sys := SystemIO()
|
||||
if out.In == nil {
|
||||
out.In = sys.In
|
||||
}
|
||||
if out.Out == nil {
|
||||
out.Out = sys.Out
|
||||
}
|
||||
if out.ErrOut == nil {
|
||||
out.ErrOut = sys.ErrOut
|
||||
}
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRetryTransport_NoRetry(t *testing.T) {
|
||||
calls := 0
|
||||
base := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil
|
||||
})
|
||||
rt := &RetryTransport{Base: base, MaxRetries: 0}
|
||||
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("expected 1 call, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryTransport_RetryOn500(t *testing.T) {
|
||||
calls := 0
|
||||
base := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil
|
||||
}
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil
|
||||
})
|
||||
rt := &RetryTransport{Base: base, MaxRetries: 3, Delay: 1 * time.Millisecond}
|
||||
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("expected 200 after retries, got %d", resp.StatusCode)
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryTransport_DefaultNoRetry(t *testing.T) {
|
||||
calls := 0
|
||||
base := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil
|
||||
})
|
||||
rt := &RetryTransport{Base: base} // default MaxRetries=0
|
||||
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 500 {
|
||||
t.Errorf("expected 500 with no retries, got %d", resp.StatusCode)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("expected 1 call with default config, got %d", calls)
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,14 @@ package cmdutil
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/larksuite/cli/extension/credential"
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
exttransport "github.com/larksuite/cli/extension/transport"
|
||||
"github.com/larksuite/cli/internal/build"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
)
|
||||
@@ -14,12 +21,21 @@ import (
|
||||
const (
|
||||
HeaderSource = "X-Cli-Source"
|
||||
HeaderVersion = "X-Cli-Version"
|
||||
HeaderBuild = "X-Cli-Build"
|
||||
HeaderShortcut = "X-Cli-Shortcut"
|
||||
HeaderExecutionId = "X-Cli-Execution-Id"
|
||||
|
||||
SourceValue = "lark-cli"
|
||||
|
||||
HeaderUserAgent = "User-Agent"
|
||||
|
||||
// BuildKindOfficial / BuildKindExtended / BuildKindUnknown are the values
|
||||
// reported in the X-Cli-Build header; see DetectBuildKind for semantics.
|
||||
BuildKindOfficial = "official"
|
||||
BuildKindExtended = "extended"
|
||||
BuildKindUnknown = "unknown"
|
||||
|
||||
officialModulePath = "github.com/larksuite/cli"
|
||||
)
|
||||
|
||||
// UserAgentValue returns the User-Agent value: "lark-cli/{version}".
|
||||
@@ -32,10 +48,108 @@ func BaseSecurityHeaders() http.Header {
|
||||
h := make(http.Header)
|
||||
h.Set(HeaderSource, SourceValue)
|
||||
h.Set(HeaderVersion, build.Version)
|
||||
h.Set(HeaderBuild, DetectBuildKind())
|
||||
h.Set(HeaderUserAgent, UserAgentValue())
|
||||
return h
|
||||
}
|
||||
|
||||
var (
|
||||
buildKindOnce sync.Once
|
||||
buildKindVal string
|
||||
)
|
||||
|
||||
// DetectBuildKind reports whether this binary is the official CLI, an
|
||||
// extended/repackaged build, or unknown. The result is cached via sync.Once
|
||||
// so it is computed only on the first call.
|
||||
//
|
||||
// IMPORTANT: must NOT be called from any package init(). Go's init ordering
|
||||
// follows the import graph; ISV providers registered via blank import may not
|
||||
// have run yet, which would misclassify an extended build as official. Call
|
||||
// only when handling an actual request (e.g. from BaseSecurityHeaders).
|
||||
func DetectBuildKind() string {
|
||||
buildKindOnce.Do(func() {
|
||||
buildKindVal = computeBuildKind()
|
||||
})
|
||||
return buildKindVal
|
||||
}
|
||||
|
||||
// computeBuildKind performs the actual detection without any caching.
|
||||
// Exposed for tests. Gathers runtime/global inputs and delegates the pure
|
||||
// branching logic to classifyBuild so that logic can be unit-tested without
|
||||
// mutating process-wide provider registries.
|
||||
func computeBuildKind() string {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
mainPath := ""
|
||||
if ok {
|
||||
mainPath = info.Main.Path
|
||||
}
|
||||
|
||||
credProviders := credential.Providers()
|
||||
creds := make([]any, len(credProviders))
|
||||
for i, p := range credProviders {
|
||||
creds[i] = p
|
||||
}
|
||||
|
||||
var tp any
|
||||
if p := exttransport.GetProvider(); p != nil {
|
||||
tp = p
|
||||
}
|
||||
var fp any
|
||||
if p := fileio.GetProvider(); p != nil {
|
||||
fp = p
|
||||
}
|
||||
return classifyBuild(mainPath, ok, creds, tp, fp)
|
||||
}
|
||||
|
||||
// classifyBuild is the pure classification logic used by computeBuildKind.
|
||||
// Callers supply concrete values so every branch is reachable from tests
|
||||
// without touching debug.ReadBuildInfo or the extension registries.
|
||||
//
|
||||
// Priority order mirrors the design doc:
|
||||
// 1. no build info → unknown
|
||||
// 2. main module path not the official one → extended (ISV wrapper)
|
||||
// 3. any non-builtin provider (credential / transport / fileio) → extended
|
||||
// 4. otherwise → official
|
||||
func classifyBuild(mainPath string, haveBuildInfo bool, credProviders []any, transportProvider, fileioProvider any) string {
|
||||
if !haveBuildInfo {
|
||||
return BuildKindUnknown
|
||||
}
|
||||
if mainPath != "" && mainPath != officialModulePath {
|
||||
return BuildKindExtended
|
||||
}
|
||||
for _, p := range credProviders {
|
||||
if !isBuiltinProvider(p) {
|
||||
return BuildKindExtended
|
||||
}
|
||||
}
|
||||
if transportProvider != nil && !isBuiltinProvider(transportProvider) {
|
||||
return BuildKindExtended
|
||||
}
|
||||
if fileioProvider != nil && !isBuiltinProvider(fileioProvider) {
|
||||
return BuildKindExtended
|
||||
}
|
||||
return BuildKindOfficial
|
||||
}
|
||||
|
||||
// isBuiltinProvider reports whether p is declared under the official module
|
||||
// path. Third-party providers live under their own module and fail this check.
|
||||
// Using reflect.PkgPath makes this robust against Name() spoofing since
|
||||
// package paths are fixed at compile time.
|
||||
func isBuiltinProvider(p any) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
t := reflect.TypeOf(p)
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
pkg := t.PkgPath()
|
||||
return pkg == officialModulePath || strings.HasPrefix(pkg, officialModulePath+"/")
|
||||
}
|
||||
|
||||
// ── Context utilities ──
|
||||
|
||||
type ctxKey string
|
||||
|
||||
34
internal/cmdutil/secheader_sidecar_test.go
Normal file
34
internal/cmdutil/secheader_sidecar_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build authsidecar
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
sidecarcred "github.com/larksuite/cli/extension/credential/sidecar"
|
||||
sidecartrans "github.com/larksuite/cli/extension/transport/sidecar"
|
||||
)
|
||||
|
||||
// TestIsBuiltinProvider_SidecarProviders locks the classification for the
|
||||
// sidecar-mode providers enumerated in design doc §3.3.2 as "官方自带". These
|
||||
// types only compile when the `authsidecar` build tag is active, so the test
|
||||
// is guarded by the same tag.
|
||||
func TestIsBuiltinProvider_SidecarProviders(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
provider any
|
||||
}{
|
||||
{"sidecar credential provider", &sidecarcred.Provider{}},
|
||||
{"sidecar transport provider", &sidecartrans.Provider{}},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !isBuiltinProvider(tc.provider) {
|
||||
t.Fatalf("%T must be classified as builtin (PkgPath under %s)", tc.provider, officialModulePath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
262
internal/cmdutil/secheader_test.go
Normal file
262
internal/cmdutil/secheader_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/extension/credential"
|
||||
envcred "github.com/larksuite/cli/extension/credential/env"
|
||||
"github.com/larksuite/cli/internal/vfs/localfileio"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isBuiltinProvider
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// cmdutilLocalProvider has PkgPath under the official module
|
||||
// ("github.com/larksuite/cli/internal/cmdutil") and should be classified
|
||||
// as builtin.
|
||||
type cmdutilLocalProvider struct{}
|
||||
|
||||
// Name intentionally returns a value that mimics an external provider; the
|
||||
// PkgPath-based classifier must ignore it. See TestIsBuiltinProvider_PkgPathNotSpoofableByName.
|
||||
func (cmdutilLocalProvider) Name() string { return "external-spoofed-provider" }
|
||||
func (cmdutilLocalProvider) ResolveAccount(context.Context) (*credential.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (cmdutilLocalProvider) ResolveToken(context.Context, credential.TokenSpec) (*credential.Token, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestIsBuiltinProvider_Nil(t *testing.T) {
|
||||
if isBuiltinProvider(nil) {
|
||||
t.Fatal("isBuiltinProvider(nil) = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBuiltinProvider_TypeUnderOfficialModule(t *testing.T) {
|
||||
if !isBuiltinProvider(&cmdutilLocalProvider{}) {
|
||||
t.Fatal("type under github.com/larksuite/cli/... should be builtin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBuiltinProvider_StdlibTypeIsNotBuiltin(t *testing.T) {
|
||||
// A standard library type has PkgPath "net/http" — outside official module.
|
||||
// This covers the non-builtin branch, which we cannot trigger from inside
|
||||
// this test file using a locally-defined type.
|
||||
if isBuiltinProvider(&http.Server{}) {
|
||||
t.Fatal("stdlib type classified as builtin, PkgPath check is broken")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBuiltinProvider_PkgPathNotSpoofableByName(t *testing.T) {
|
||||
// Name() returns a string, but classification uses reflect.Type.PkgPath
|
||||
// which is compile-time fixed. The local type returns a name that looks
|
||||
// like an ISV provider; it must still classify as builtin.
|
||||
p := &cmdutilLocalProvider{}
|
||||
if p.Name() != "external-spoofed-provider" {
|
||||
t.Fatalf("sanity check: Name() = %q, spoof value lost", p.Name())
|
||||
}
|
||||
if !isBuiltinProvider(p) {
|
||||
t.Fatal("isBuiltinProvider should decide by PkgPath, not Name()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsBuiltinProvider_NonPointerValues covers the non-pointer reflect branch.
|
||||
// The existing tests only exercise pointer receivers (&T{}); when a provider
|
||||
// is passed by value the reflect.Kind is not Ptr and t.Elem() is skipped.
|
||||
func TestIsBuiltinProvider_NonPointerValues(t *testing.T) {
|
||||
if !isBuiltinProvider(cmdutilLocalProvider{}) {
|
||||
t.Fatal("non-pointer local type should be builtin (PkgPath still under official module)")
|
||||
}
|
||||
// http.Server as a non-pointer — PkgPath "net/http", not under official.
|
||||
if isBuiltinProvider(http.Server{}) {
|
||||
t.Fatal("non-pointer stdlib type should not be builtin")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsBuiltinProvider_RealBuiltinProviders locks down the classification
|
||||
// for the concrete providers enumerated in design doc §3.3.2 as "官方自带":
|
||||
// env credential provider and local fileio provider. If any of these is
|
||||
// moved out of the official module tree in the future, this test must flip
|
||||
// red so the new package path is explicitly considered.
|
||||
//
|
||||
// The sidecar providers (extension/credential/sidecar and
|
||||
// extension/transport/sidecar) are guarded by the `authsidecar` build tag
|
||||
// and covered in secheader_sidecar_test.go under that tag.
|
||||
func TestIsBuiltinProvider_RealBuiltinProviders(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
provider any
|
||||
}{
|
||||
{"env credential provider", &envcred.Provider{}},
|
||||
{"local fileio provider", &localfileio.Provider{}},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !isBuiltinProvider(tc.provider) {
|
||||
t.Fatalf("%T must be classified as builtin (PkgPath under %s)", tc.provider, officialModulePath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// computeBuildKind
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeBuildKind_ReturnsKnownValue(t *testing.T) {
|
||||
// Under `go test`, Main.Path is typically the module being tested
|
||||
// ("github.com/larksuite/cli"); the concrete return may still be
|
||||
// official, extended, or unknown depending on Main.Path and the
|
||||
// registered providers. Just assert it's one of the defined values.
|
||||
got := computeBuildKind()
|
||||
switch got {
|
||||
case BuildKindOfficial, BuildKindExtended, BuildKindUnknown:
|
||||
default:
|
||||
t.Fatalf("computeBuildKind() = %q, want one of official/extended/unknown", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyBuild — pure branching logic
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// These tests cover every branch of classifyBuild with explicit inputs,
|
||||
// which is impossible from computeBuildKind alone because debug.ReadBuildInfo
|
||||
// and the process-wide provider registries can't be reshaped in a test.
|
||||
|
||||
func TestClassifyBuild_NoBuildInfo_ReturnsUnknown(t *testing.T) {
|
||||
if got := classifyBuild("", false, nil, nil, nil); got != BuildKindUnknown {
|
||||
t.Fatalf("classifyBuild(haveBuildInfo=false) = %q, want %q", got, BuildKindUnknown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_ExtendedMainPath_ReturnsExtended(t *testing.T) {
|
||||
cases := []string{
|
||||
"github.com/acme/lark-cli-wrapper",
|
||||
"example.com/isv/lark",
|
||||
"gitlab.mycorp.internal/tools/lark-cli-fork",
|
||||
}
|
||||
for _, mp := range cases {
|
||||
t.Run(mp, func(t *testing.T) {
|
||||
if got := classifyBuild(mp, true, nil, nil, nil); got != BuildKindExtended {
|
||||
t.Fatalf("mainPath=%q classifyBuild = %q, want %q", mp, got, BuildKindExtended)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_OfficialMainPath_NoProviders_ReturnsOfficial(t *testing.T) {
|
||||
if got := classifyBuild(officialModulePath, true, nil, nil, nil); got != BuildKindOfficial {
|
||||
t.Fatalf("classifyBuild(official, no providers) = %q, want %q", got, BuildKindOfficial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_EmptyMainPath_DoesNotTriggerExtended(t *testing.T) {
|
||||
// An empty Main.Path (rare, e.g. `go run` pre-1.18) must not be treated
|
||||
// as extended by itself — the classifier falls through to provider checks.
|
||||
if got := classifyBuild("", true, nil, nil, nil); got != BuildKindOfficial {
|
||||
t.Fatalf("classifyBuild(empty mainPath, no providers) = %q, want %q", got, BuildKindOfficial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_NonBuiltinCredentialProvider_ReturnsExtended(t *testing.T) {
|
||||
// Any non-builtin credential provider flips the verdict to extended.
|
||||
got := classifyBuild(officialModulePath, true, []any{&http.Server{}}, nil, nil)
|
||||
if got != BuildKindExtended {
|
||||
t.Fatalf("classifyBuild with external credential = %q, want %q", got, BuildKindExtended)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_MixedCredentialProviders_ExtendedWins(t *testing.T) {
|
||||
// Even if most providers are builtin, a single external one decides.
|
||||
providers := []any{&cmdutilLocalProvider{}, &http.Server{}}
|
||||
if got := classifyBuild(officialModulePath, true, providers, nil, nil); got != BuildKindExtended {
|
||||
t.Fatalf("classifyBuild mixed providers = %q, want %q", got, BuildKindExtended)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_NonBuiltinTransportProvider_ReturnsExtended(t *testing.T) {
|
||||
got := classifyBuild(officialModulePath, true, nil, &http.Server{}, nil)
|
||||
if got != BuildKindExtended {
|
||||
t.Fatalf("classifyBuild with external transport = %q, want %q", got, BuildKindExtended)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_NonBuiltinFileioProvider_ReturnsExtended(t *testing.T) {
|
||||
got := classifyBuild(officialModulePath, true, nil, nil, &http.Server{})
|
||||
if got != BuildKindExtended {
|
||||
t.Fatalf("classifyBuild with external fileio = %q, want %q", got, BuildKindExtended)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyBuild_AllBuiltinProviders_ReturnsOfficial(t *testing.T) {
|
||||
// All three slots filled with builtin providers must still classify as official.
|
||||
got := classifyBuild(
|
||||
officialModulePath, true,
|
||||
[]any{&cmdutilLocalProvider{}},
|
||||
&cmdutilLocalProvider{},
|
||||
&cmdutilLocalProvider{},
|
||||
)
|
||||
if got != BuildKindOfficial {
|
||||
t.Fatalf("classifyBuild all-builtin = %q, want %q", got, BuildKindOfficial)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClassifyBuild_MainPathPriorityOverProviders documents that the main
|
||||
// module path takes precedence: even with only builtin providers, a non-
|
||||
// official main path still yields extended.
|
||||
func TestClassifyBuild_MainPathPriorityOverProviders(t *testing.T) {
|
||||
got := classifyBuild(
|
||||
"github.com/acme/lark-wrapper", true,
|
||||
[]any{&cmdutilLocalProvider{}},
|
||||
&cmdutilLocalProvider{},
|
||||
&cmdutilLocalProvider{},
|
||||
)
|
||||
if got != BuildKindExtended {
|
||||
t.Fatalf("main-path override failed: got %q, want %q", got, BuildKindExtended)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DetectBuildKind — sync.Once caching
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestDetectBuildKind_StableAcrossCalls(t *testing.T) {
|
||||
a := DetectBuildKind()
|
||||
b := DetectBuildKind()
|
||||
if a != b {
|
||||
t.Fatalf("DetectBuildKind() returned different values on repeat: %q vs %q", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BaseSecurityHeaders
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBaseSecurityHeaders_IncludesBuildHeader(t *testing.T) {
|
||||
h := BaseSecurityHeaders()
|
||||
v := h.Get(HeaderBuild)
|
||||
if v == "" {
|
||||
t.Fatal("BaseSecurityHeaders missing X-Cli-Build header")
|
||||
}
|
||||
switch v {
|
||||
case BuildKindOfficial, BuildKindExtended, BuildKindUnknown:
|
||||
default:
|
||||
t.Fatalf("X-Cli-Build = %q, want one of official/extended/unknown", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseSecurityHeaders_AllRequiredHeaders(t *testing.T) {
|
||||
h := BaseSecurityHeaders()
|
||||
for _, key := range []string{HeaderSource, HeaderVersion, HeaderBuild, HeaderUserAgent} {
|
||||
if h.Get(key) == "" {
|
||||
t.Errorf("BaseSecurityHeaders missing %s", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -72,6 +72,24 @@ func (t *UserAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
|
||||
return util.FallbackTransport().RoundTrip(req)
|
||||
}
|
||||
|
||||
// BuildHeaderTransport is an http.RoundTripper that force-writes the
|
||||
// X-Cli-Build header before every request. Used in the SDK transport chain,
|
||||
// where SecurityHeaderTransport is not installed, to prevent extensions from
|
||||
// tampering with the build classification. The direct HTTP chain is already
|
||||
// covered by SecurityHeaderTransport iterating BaseSecurityHeaders.
|
||||
type BuildHeaderTransport struct {
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *BuildHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
req.Header.Set(HeaderBuild, DetectBuildKind())
|
||||
if t.Base != nil {
|
||||
return t.Base.RoundTrip(req)
|
||||
}
|
||||
return util.FallbackTransport().RoundTrip(req)
|
||||
}
|
||||
|
||||
// SecurityHeaderTransport is an http.RoundTripper that injects CLI security
|
||||
// headers into every request. Shortcut headers are read from the request context.
|
||||
type SecurityHeaderTransport struct {
|
||||
@@ -104,20 +122,47 @@ func (t *SecurityHeaderTransport) RoundTrip(req *http.Request) (*http.Response,
|
||||
}
|
||||
|
||||
// extensionMiddleware wraps the built-in transport chain with pre/post hooks.
|
||||
// The built-in chain always executes and cannot be skipped or overridden.
|
||||
// The original request context is restored after PreRoundTrip to prevent
|
||||
// The built-in chain always executes unless the extension is an
|
||||
// exttransport.AbortableInterceptor and its PreRoundTripE returns a non-nil
|
||||
// error; it cannot otherwise be skipped or overridden.
|
||||
//
|
||||
// The original request context is restored after the pre hook to prevent
|
||||
// extensions from tampering with cancellation, deadlines, or built-in values.
|
||||
// Cloning the request isolates header/URL/etc. mutations from the caller's
|
||||
// request object; req.Body is intentionally shared — extensions that consume
|
||||
// it are responsible for rewinding (see Interceptor doc).
|
||||
type extensionMiddleware struct {
|
||||
Base http.RoundTripper
|
||||
Ext exttransport.Interceptor
|
||||
Base http.RoundTripper
|
||||
Ext exttransport.Interceptor
|
||||
ExtName string // Provider.Name(), captured at wrap time for *AbortError.Extension
|
||||
}
|
||||
|
||||
// RoundTrip calls PreRoundTrip, restores the original context, executes
|
||||
// the built-in chain, then calls the post hook if non-nil.
|
||||
// RoundTrip invokes the interceptor pre hook, restores the original context,
|
||||
// executes the built-in chain (unless aborted), then calls the post hook if
|
||||
// non-nil. When the extension implements AbortableInterceptor and returns a
|
||||
// non-nil error from PreRoundTripE, the built-in chain is skipped and an
|
||||
// *exttransport.AbortError is returned; the post hook is still invoked with
|
||||
// (nil, reason) so extensions can unwind resources.
|
||||
func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
origCtx := req.Context()
|
||||
req = req.Clone(origCtx) // isolate caller's request before extension mutations
|
||||
post := m.Ext.PreRoundTrip(req)
|
||||
req = req.Clone(origCtx)
|
||||
|
||||
var (
|
||||
post func(*http.Response, error)
|
||||
abortEr error
|
||||
)
|
||||
if a, ok := m.Ext.(exttransport.AbortableInterceptor); ok {
|
||||
post, abortEr = a.PreRoundTripE(req)
|
||||
} else {
|
||||
post = m.Ext.PreRoundTrip(req)
|
||||
}
|
||||
if abortEr != nil {
|
||||
if post != nil {
|
||||
post(nil, abortEr)
|
||||
}
|
||||
return nil, &exttransport.AbortError{Extension: m.ExtName, Reason: abortEr}
|
||||
}
|
||||
|
||||
req = req.WithContext(origCtx) // restore original context
|
||||
resp, err := m.Base.RoundTrip(req)
|
||||
if post != nil {
|
||||
@@ -137,5 +182,5 @@ func wrapWithExtension(transport http.RoundTripper) http.RoundTripper {
|
||||
if tr == nil {
|
||||
return transport
|
||||
}
|
||||
return &extensionMiddleware{Base: transport, Ext: tr}
|
||||
return &extensionMiddleware{Base: transport, Ext: tr, ExtName: p.Name()}
|
||||
}
|
||||
|
||||
531
internal/cmdutil/transport_test.go
Normal file
531
internal/cmdutil/transport_test.go
Normal file
@@ -0,0 +1,531 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
exttransport "github.com/larksuite/cli/extension/transport"
|
||||
internalauth "github.com/larksuite/cli/internal/auth"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RetryTransport
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryTransport_NoRetry(t *testing.T) {
|
||||
calls := 0
|
||||
base := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil
|
||||
})
|
||||
rt := &RetryTransport{Base: base, MaxRetries: 0}
|
||||
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("expected 1 call, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryTransport_RetryOn500(t *testing.T) {
|
||||
calls := 0
|
||||
base := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil
|
||||
}
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil
|
||||
})
|
||||
rt := &RetryTransport{Base: base, MaxRetries: 3, Delay: 1 * time.Millisecond}
|
||||
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("expected 200 after retries, got %d", resp.StatusCode)
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryTransport_DefaultNoRetry(t *testing.T) {
|
||||
calls := 0
|
||||
base := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil
|
||||
})
|
||||
rt := &RetryTransport{Base: base} // default MaxRetries=0
|
||||
req, _ := http.NewRequest("GET", "http://example.com/test", nil)
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 500 {
|
||||
t.Errorf("expected 500 with no retries, got %d", resp.StatusCode)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("expected 1 call with default config, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildSDKTransport chain composition
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) {
|
||||
transport := buildSDKTransport()
|
||||
|
||||
// Chain: SecurityPolicy → BuildHeader → UserAgent → Retry → Base
|
||||
sec, ok := transport.(*internalauth.SecurityPolicyTransport)
|
||||
if !ok {
|
||||
t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport)
|
||||
}
|
||||
bh, ok := sec.Base.(*BuildHeaderTransport)
|
||||
if !ok {
|
||||
t.Fatalf("layer after SecurityPolicy = %T, want *BuildHeaderTransport", sec.Base)
|
||||
}
|
||||
ua, ok := bh.Base.(*UserAgentTransport)
|
||||
if !ok {
|
||||
t.Fatalf("layer after BuildHeader = %T, want *UserAgentTransport", bh.Base)
|
||||
}
|
||||
if _, ok := ua.Base.(*RetryTransport); !ok {
|
||||
t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSDKTransport_WithExtension(t *testing.T) {
|
||||
exttransport.Register(&stubTransportProvider{})
|
||||
t.Cleanup(func() { exttransport.Register(nil) })
|
||||
|
||||
transport := buildSDKTransport()
|
||||
|
||||
// Chain: extensionMiddleware → SecurityPolicy → BuildHeader → UserAgent → Retry → Base
|
||||
mid, ok := transport.(*extensionMiddleware)
|
||||
if !ok {
|
||||
t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport)
|
||||
}
|
||||
sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport)
|
||||
if !ok {
|
||||
t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base)
|
||||
}
|
||||
bh, ok := sec.Base.(*BuildHeaderTransport)
|
||||
if !ok {
|
||||
t.Fatalf("layer after SecurityPolicy = %T, want *BuildHeaderTransport", sec.Base)
|
||||
}
|
||||
ua, ok := bh.Base.(*UserAgentTransport)
|
||||
if !ok {
|
||||
t.Fatalf("layer after BuildHeader = %T, want *UserAgentTransport", bh.Base)
|
||||
}
|
||||
if _, ok := ua.Base.(*RetryTransport); !ok {
|
||||
t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSDKTransport_WithoutExtension(t *testing.T) {
|
||||
exttransport.Register(nil)
|
||||
|
||||
transport := buildSDKTransport()
|
||||
|
||||
// Chain: SecurityPolicy → BuildHeader → UserAgent → Retry → Base
|
||||
sec, ok := transport.(*internalauth.SecurityPolicyTransport)
|
||||
if !ok {
|
||||
t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport)
|
||||
}
|
||||
bh, ok := sec.Base.(*BuildHeaderTransport)
|
||||
if !ok {
|
||||
t.Fatalf("layer after SecurityPolicy = %T, want *BuildHeaderTransport", sec.Base)
|
||||
}
|
||||
ua, ok := bh.Base.(*UserAgentTransport)
|
||||
if !ok {
|
||||
t.Fatalf("layer after BuildHeader = %T, want *UserAgentTransport", bh.Base)
|
||||
}
|
||||
if _, ok := ua.Base.(*RetryTransport); !ok {
|
||||
t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extensionMiddleware — legacy Interceptor path
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type stubTransportProvider struct {
|
||||
interceptor exttransport.Interceptor
|
||||
}
|
||||
|
||||
func (s *stubTransportProvider) Name() string { return "stub" }
|
||||
func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor {
|
||||
if s.interceptor != nil {
|
||||
return s.interceptor
|
||||
}
|
||||
return &stubTransportImpl{}
|
||||
}
|
||||
|
||||
type stubTransportImpl struct{}
|
||||
|
||||
func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerCapturingInterceptor sets custom headers in PreRoundTrip and records
|
||||
// whether PostRoundTrip was called, to verify execution order.
|
||||
type headerCapturingInterceptor struct {
|
||||
preCalled bool
|
||||
postCalled bool
|
||||
}
|
||||
|
||||
func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) {
|
||||
h.preCalled = true
|
||||
// Set a custom header that should survive (no built-in override)
|
||||
req.Header.Set("X-Custom-Trace", "ext-trace-123")
|
||||
// Try to override a security header — should be overwritten by SecurityHeaderTransport
|
||||
req.Header.Set(HeaderSource, "ext-tampered")
|
||||
return func(resp *http.Response, err error) {
|
||||
h.postCalled = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionInterceptor_ExecutionOrder(t *testing.T) {
|
||||
var receivedHeaders http.Header
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ic := &headerCapturingInterceptor{}
|
||||
exttransport.Register(&stubTransportProvider{interceptor: ic})
|
||||
t.Cleanup(func() { exttransport.Register(nil) })
|
||||
|
||||
// Use HTTP transport chain (has SecurityHeaderTransport)
|
||||
var base http.RoundTripper = http.DefaultTransport
|
||||
base = &RetryTransport{Base: base}
|
||||
base = &SecurityHeaderTransport{Base: base}
|
||||
transport := wrapWithExtension(base)
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
req, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// PreRoundTrip was called
|
||||
if !ic.preCalled {
|
||||
t.Fatal("PreRoundTrip was not called")
|
||||
}
|
||||
// PostRoundTrip (closure) was called
|
||||
if !ic.postCalled {
|
||||
t.Fatal("PostRoundTrip closure was not called")
|
||||
}
|
||||
// Custom header set by extension survives (no built-in override)
|
||||
if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" {
|
||||
t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123")
|
||||
}
|
||||
// Security header overridden by extension is restored by SecurityHeaderTransport
|
||||
if got := receivedHeaders.Get(HeaderSource); got != SourceValue {
|
||||
t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue)
|
||||
}
|
||||
}
|
||||
|
||||
// buildTamperingInterceptor tries to delete and spoof X-Cli-Build via
|
||||
// PreRoundTrip. The SDK chain's BuildHeaderTransport must restore the real
|
||||
// value before the request leaves the process.
|
||||
type buildTamperingInterceptor struct{}
|
||||
|
||||
func (buildTamperingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) {
|
||||
req.Header.Del(HeaderBuild)
|
||||
req.Header.Set(HeaderBuild, "ext-tampered-build")
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestBuildHeaderTransport_SDKChain_OverridesTamperedHeader verifies that the
|
||||
// X-Cli-Build header is force-written by BuildHeaderTransport in the SDK
|
||||
// transport chain, even when an extension tries to delete or spoof it. This
|
||||
// closes the gap where the SDK chain had no equivalent of
|
||||
// SecurityHeaderTransport (see design doc §3.3.3).
|
||||
func TestBuildHeaderTransport_SDKChain_OverridesTamperedHeader(t *testing.T) {
|
||||
var receivedBuild string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBuild = r.Header.Get(HeaderBuild)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exttransport.Register(&stubTransportProvider{interceptor: buildTamperingInterceptor{}})
|
||||
t.Cleanup(func() { exttransport.Register(nil) })
|
||||
|
||||
// Replicate the SDK chain layering used by buildSDKTransport.
|
||||
var base http.RoundTripper = http.DefaultTransport
|
||||
base = &RetryTransport{Base: base}
|
||||
base = &UserAgentTransport{Base: base}
|
||||
base = &BuildHeaderTransport{Base: base}
|
||||
transport := wrapWithExtension(base)
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
req, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if receivedBuild == "ext-tampered-build" {
|
||||
t.Fatalf("%s = %q, extension tampering leaked to network", HeaderBuild, receivedBuild)
|
||||
}
|
||||
want := DetectBuildKind()
|
||||
if receivedBuild != want {
|
||||
t.Fatalf("%s = %q, want %q", HeaderBuild, receivedBuild, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildHeaderTransport_OverridesEvenWithoutTamper verifies that even if
|
||||
// no extension is registered, BuildHeaderTransport writes X-Cli-Build.
|
||||
func TestBuildHeaderTransport_OverridesEvenWithoutTamper(t *testing.T) {
|
||||
var receivedBuild string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBuild = r.Header.Get(HeaderBuild)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
transport := &BuildHeaderTransport{Base: http.DefaultTransport}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
req, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if receivedBuild == "" {
|
||||
t.Fatalf("%s header missing, BuildHeaderTransport did not inject", HeaderBuild)
|
||||
}
|
||||
want := DetectBuildKind()
|
||||
if receivedBuild != want {
|
||||
t.Fatalf("%s = %q, want %q", HeaderBuild, receivedBuild, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildHeaderTransport_NilBase_UsesFallback verifies that when Base is nil,
|
||||
// the transport still sets X-Cli-Build and routes the request through
|
||||
// util.FallbackTransport rather than panicking. This covers the fallback
|
||||
// branch in RoundTrip that is otherwise unreachable with a non-nil Base.
|
||||
func TestBuildHeaderTransport_NilBase_UsesFallback(t *testing.T) {
|
||||
var receivedBuild string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBuild = r.Header.Get(HeaderBuild)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
transport := &BuildHeaderTransport{Base: nil}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
req, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request via nil-Base transport failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
want := DetectBuildKind()
|
||||
if receivedBuild != want {
|
||||
t.Fatalf("%s = %q, want %q (header must be set even on nil-Base path)",
|
||||
HeaderBuild, receivedBuild, want)
|
||||
}
|
||||
}
|
||||
|
||||
// interceptorFunc adapts a function to exttransport.Interceptor.
|
||||
type interceptorFunc func(*http.Request) func(*http.Response, error)
|
||||
|
||||
func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) }
|
||||
|
||||
func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) {
|
||||
type ctxKeyType string
|
||||
const testKey ctxKeyType = "original"
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
var ctxValue any
|
||||
|
||||
// Use a custom transport that captures the context value seen by the built-in chain
|
||||
capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
ctxValue = req.Context().Value(testKey)
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
})
|
||||
|
||||
// Interceptor that tries to tamper with context
|
||||
tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) {
|
||||
// Try to replace context with a new one
|
||||
*req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered"))
|
||||
return nil
|
||||
})
|
||||
|
||||
mid := &extensionMiddleware{Base: capturer, Ext: tamperIC}
|
||||
|
||||
origCtx := context.WithValue(context.Background(), testKey, "original")
|
||||
req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil)
|
||||
resp, err := mid.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Built-in chain should see original context, not tampered
|
||||
if ctxValue != "original" {
|
||||
t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extensionMiddleware — PreRoundTripE abort path
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// abortingInterceptor implements exttransport.AbortableInterceptor and
|
||||
// records invocation of the pre and post hooks. These middleware tests only
|
||||
// assert middleware-level integration; pure *AbortError behavior
|
||||
// (Error/Unwrap/Is/As) is covered in extension/transport/errors_test.go.
|
||||
type abortingInterceptor struct {
|
||||
reason error // if non-nil, PreRoundTripE returns this to abort
|
||||
nilPost bool // if true, PreRoundTripE returns a nil post func
|
||||
preECalled bool
|
||||
postCalled bool
|
||||
postResp *http.Response
|
||||
postErr error
|
||||
}
|
||||
|
||||
// PreRoundTrip is a no-op that satisfies the legacy Interceptor method; the
|
||||
// middleware never calls it when PreRoundTripE is present.
|
||||
func (*abortingInterceptor) PreRoundTrip(*http.Request) func(*http.Response, error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *abortingInterceptor) PreRoundTripE(req *http.Request) (func(*http.Response, error), error) {
|
||||
a.preECalled = true
|
||||
if a.nilPost {
|
||||
return nil, a.reason
|
||||
}
|
||||
return func(resp *http.Response, err error) {
|
||||
a.postCalled = true
|
||||
a.postResp = resp
|
||||
a.postErr = err
|
||||
}, a.reason
|
||||
}
|
||||
|
||||
func TestExtensionMiddleware_PreRoundTripEAbort(t *testing.T) {
|
||||
innerErr := errors.New("denied by policy")
|
||||
|
||||
t.Run("skips base and wires AbortError fields", func(t *testing.T) {
|
||||
ic := &abortingInterceptor{reason: innerErr}
|
||||
baseCalls := 0
|
||||
base := roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
baseCalls++
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"}
|
||||
req, _ := http.NewRequest("GET", "http://example.invalid/", nil)
|
||||
resp, err := mid.RoundTrip(req)
|
||||
|
||||
if resp != nil {
|
||||
t.Fatalf("resp = %v, want nil on abort", resp)
|
||||
}
|
||||
if baseCalls != 0 {
|
||||
t.Fatalf("base RoundTrip called %d times on abort, want 0", baseCalls)
|
||||
}
|
||||
if !ic.preECalled {
|
||||
t.Fatal("PreRoundTripE was not called")
|
||||
}
|
||||
|
||||
var aErr *exttransport.AbortError
|
||||
if !errors.As(err, &aErr) {
|
||||
t.Fatalf("errors.As(*AbortError) = false, err = %v (%T)", err, err)
|
||||
}
|
||||
if aErr.Extension != "stub" || aErr.Reason != innerErr {
|
||||
t.Fatalf("AbortError = %+v, want {Extension:stub Reason:%v}", aErr, innerErr)
|
||||
}
|
||||
|
||||
// Post must see the original inner err, not the *AbortError wrapper.
|
||||
if !ic.postCalled {
|
||||
t.Fatal("post hook was not called on abort")
|
||||
}
|
||||
if ic.postResp != nil {
|
||||
t.Fatalf("post resp = %v, want nil", ic.postResp)
|
||||
}
|
||||
if ic.postErr != innerErr {
|
||||
t.Fatalf("post err = %v, want original inner err %v", ic.postErr, innerErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil post still returns AbortError", func(t *testing.T) {
|
||||
ic := &abortingInterceptor{reason: innerErr, nilPost: true}
|
||||
base := roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Fatal("base must not be called on abort")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"}
|
||||
req, _ := http.NewRequest("GET", "http://example.invalid/", nil)
|
||||
_, err := mid.RoundTrip(req)
|
||||
|
||||
var aErr *exttransport.AbortError
|
||||
if !errors.As(err, &aErr) {
|
||||
t.Fatalf("errors.As(*AbortError) = false, err = %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtensionMiddleware_PreRoundTripEHappyPath(t *testing.T) {
|
||||
ic := &abortingInterceptor{} // reason == nil → no abort
|
||||
baseCalls := 0
|
||||
base := roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
baseCalls++
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
})
|
||||
|
||||
mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"}
|
||||
req, _ := http.NewRequest("GET", "http://example.invalid/", nil)
|
||||
resp, err := mid.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("happy path returned err: %v", err)
|
||||
}
|
||||
if resp == nil || resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("resp = %v, want 200", resp)
|
||||
}
|
||||
if baseCalls != 1 {
|
||||
t.Fatalf("base RoundTrip called %d times, want 1", baseCalls)
|
||||
}
|
||||
if !ic.preECalled {
|
||||
t.Fatal("PreRoundTripE was not called")
|
||||
}
|
||||
if !ic.postCalled || ic.postErr != nil {
|
||||
t.Fatalf("post hook not called or err != nil: called=%v err=%v", ic.postCalled, ic.postErr)
|
||||
}
|
||||
}
|
||||
@@ -21,11 +21,14 @@ import (
|
||||
|
||||
// DefaultAccountProvider resolves account from config.json via keychain.
|
||||
type DefaultAccountProvider struct {
|
||||
keychain keychain.KeychainAccess
|
||||
keychain func() keychain.KeychainAccess
|
||||
profile string
|
||||
}
|
||||
|
||||
func NewDefaultAccountProvider(kc keychain.KeychainAccess, profile string) *DefaultAccountProvider {
|
||||
func NewDefaultAccountProvider(kc func() keychain.KeychainAccess, profile string) *DefaultAccountProvider {
|
||||
if kc == nil {
|
||||
kc = keychain.Default
|
||||
}
|
||||
return &DefaultAccountProvider{keychain: kc, profile: profile}
|
||||
}
|
||||
|
||||
@@ -36,7 +39,7 @@ func (p *DefaultAccountProvider) ResolveAccount(ctx context.Context) (*Account,
|
||||
return nil, &core.ConfigError{Code: 2, Type: "config", Message: "not configured", Hint: "run `lark-cli config init --new` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete setup."}
|
||||
}
|
||||
|
||||
cfg, err := core.ResolveConfigFromMulti(multi, p.keychain, p.profile)
|
||||
cfg, err := core.ResolveConfigFromMulti(multi, p.keychain(), p.profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/credential"
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
"github.com/larksuite/cli/internal/keychain"
|
||||
)
|
||||
|
||||
type noopKC struct{}
|
||||
@@ -99,7 +100,7 @@ func TestFullChain_ConfigStrictMode(t *testing.T) {
|
||||
}
|
||||
|
||||
ep := &envprovider.Provider{}
|
||||
defaultAcct := credential.NewDefaultAccountProvider(&noopKC{}, "")
|
||||
defaultAcct := credential.NewDefaultAccountProvider(func() keychain.KeychainAccess { return &noopKC{} }, "")
|
||||
|
||||
cp := credential.NewCredentialProvider(
|
||||
[]extcred.Provider{ep},
|
||||
|
||||
@@ -11,4 +11,8 @@ const (
|
||||
CliTenantAccessToken = "LARKSUITE_CLI_TENANT_ACCESS_TOKEN"
|
||||
CliDefaultAs = "LARKSUITE_CLI_DEFAULT_AS"
|
||||
CliStrictMode = "LARKSUITE_CLI_STRICT_MODE"
|
||||
|
||||
// Sidecar proxy (auth proxy mode)
|
||||
CliAuthProxy = "LARKSUITE_CLI_AUTH_PROXY" // sidecar HTTP address, e.g. "http://127.0.0.1:16384"
|
||||
CliProxyKey = "LARKSUITE_CLI_PROXY_KEY" // HMAC signing key shared with sidecar
|
||||
)
|
||||
|
||||
@@ -38,6 +38,17 @@ const (
|
||||
LarkErrDriveResourceContention = 1061045 // resource contention occurred, please retry
|
||||
LarkErrDriveCrossTenantUnit = 1064510 // cross tenant and unit not support
|
||||
LarkErrDriveCrossBrand = 1064511 // cross brand not support
|
||||
|
||||
// Sheets float image: width/height/offset out of range or invalid.
|
||||
LarkErrSheetsFloatImageInvalidDims = 1310246
|
||||
|
||||
// Drive permission apply: per-user-per-document submission limit (5/day) reached.
|
||||
LarkErrDrivePermApplyRateLimit = 1063006
|
||||
// Drive permission apply: request is not applicable for this document
|
||||
// (e.g. the document is configured to disallow access requests, or the
|
||||
// caller already holds the requested permission, or the target type does
|
||||
// not accept apply operations).
|
||||
LarkErrDrivePermApplyNotApplicable = 1063007
|
||||
)
|
||||
|
||||
// ClassifyLarkError maps a Lark API error code + message to (exitCode, errType, hint).
|
||||
@@ -73,6 +84,20 @@ func ClassifyLarkError(code int, msg string) (int, string, string) {
|
||||
return ExitAPI, "cross_tenant_unit", "operate on source and target within the same tenant and region/unit"
|
||||
case LarkErrDriveCrossBrand:
|
||||
return ExitAPI, "cross_brand", "operate on source and target within the same brand environment"
|
||||
|
||||
// sheets-specific constraints that benefit from actionable hints
|
||||
case LarkErrSheetsFloatImageInvalidDims:
|
||||
return ExitAPI, "invalid_params",
|
||||
"check --width / --height / --offset-x / --offset-y: " +
|
||||
"width/height must be >= 20 px; offsets must be >= 0 and less than the anchor cell's width/height"
|
||||
|
||||
// drive permission-apply specific guidance
|
||||
case LarkErrDrivePermApplyRateLimit:
|
||||
return ExitAPI, "rate_limit",
|
||||
"permission-apply quota reached: each user may request access on the same document at most 5 times per day; wait or ask the owner directly"
|
||||
case LarkErrDrivePermApplyNotApplicable:
|
||||
return ExitAPI, "invalid_params",
|
||||
"this document does not accept a permission-apply request (common causes: the document is configured to disallow access requests, the caller already holds the permission, or the target type does not support apply); contact the owner directly"
|
||||
}
|
||||
|
||||
return ExitAPI, "api_error", ""
|
||||
|
||||
@@ -40,6 +40,27 @@ func TestClassifyLarkError_DriveCreateShortcutConstraints(t *testing.T) {
|
||||
wantType: "cross_brand",
|
||||
wantHint: "same brand environment",
|
||||
},
|
||||
{
|
||||
name: "sheets float image invalid dims",
|
||||
code: LarkErrSheetsFloatImageInvalidDims,
|
||||
wantExitCode: ExitAPI,
|
||||
wantType: "invalid_params",
|
||||
wantHint: "--width / --height / --offset-x / --offset-y",
|
||||
},
|
||||
{
|
||||
name: "drive permission apply rate limit",
|
||||
code: LarkErrDrivePermApplyRateLimit,
|
||||
wantExitCode: ExitAPI,
|
||||
wantType: "rate_limit",
|
||||
wantHint: "5 times per day",
|
||||
},
|
||||
{
|
||||
name: "drive permission apply not applicable",
|
||||
code: LarkErrDrivePermApplyNotApplicable,
|
||||
wantExitCode: ExitAPI,
|
||||
wantType: "invalid_params",
|
||||
wantHint: "does not accept a permission-apply request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -61,7 +61,7 @@ func httpClient() *http.Client {
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: fetchTimeout,
|
||||
Transport: util.NewBaseTransport(),
|
||||
Transport: util.SharedTransport(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -72,31 +72,47 @@ func WarnIfProxied(w io.Writer) {
|
||||
})
|
||||
}
|
||||
|
||||
// NewBaseTransport creates an *http.Transport cloned from http.DefaultTransport.
|
||||
// If LARK_CLI_NO_PROXY is set, proxy support is disabled.
|
||||
// Each call returns a new instance; use FallbackTransport for a shared singleton.
|
||||
func NewBaseTransport() *http.Transport {
|
||||
// noProxyTransport is a proxy-disabled clone of http.DefaultTransport,
|
||||
// lazily built the first time LARK_CLI_NO_PROXY is observed set.
|
||||
var noProxyTransport = sync.OnceValue(func() *http.Transport {
|
||||
def, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return &http.Transport{}
|
||||
}
|
||||
t := def.Clone()
|
||||
if os.Getenv(EnvNoProxy) != "" {
|
||||
t.Proxy = nil
|
||||
}
|
||||
t.Proxy = nil
|
||||
return t
|
||||
}
|
||||
|
||||
// fallbackTransport is a lazily-initialized singleton used by transport
|
||||
// decorators when their Base field is nil, preserving connection pooling.
|
||||
var fallbackTransport = sync.OnceValue(func() *http.Transport {
|
||||
return NewBaseTransport()
|
||||
})
|
||||
|
||||
// FallbackTransport returns a shared *http.Transport singleton suitable for
|
||||
// use as a fallback when a transport decorator's Base is nil.
|
||||
// Unlike NewBaseTransport (which clones per call), this reuses a single
|
||||
// instance so that TCP connections and TLS sessions are pooled.
|
||||
func FallbackTransport() *http.Transport {
|
||||
return fallbackTransport()
|
||||
// SharedTransport returns the base http.RoundTripper for CLI HTTP clients.
|
||||
//
|
||||
// By default it returns http.DefaultTransport — the stdlib-provided
|
||||
// process-wide singleton — so every HTTP client in the process shares one
|
||||
// TCP connection pool, TLS session cache, and HTTP/2 state. When
|
||||
// LARK_CLI_NO_PROXY is set it returns a separate proxy-disabled singleton
|
||||
// clone; LARK_CLI_NO_PROXY is checked on every call, but the clone is built
|
||||
// at most once.
|
||||
//
|
||||
// The returned RoundTripper MUST NOT be mutated. Callers that need a
|
||||
// customized transport should assert to *http.Transport and Clone() it.
|
||||
// Using a shared base is required so persistConn readLoop/writeLoop
|
||||
// goroutines are reused; cloning per call leaks them until IdleConnTimeout
|
||||
// (~90s) fires.
|
||||
func SharedTransport() http.RoundTripper {
|
||||
if os.Getenv(EnvNoProxy) != "" {
|
||||
return noProxyTransport()
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
// FallbackTransport returns a shared *http.Transport singleton. It is a
|
||||
// thin wrapper over SharedTransport retained so modules that were already
|
||||
// on the leak-free singleton path (internal/auth, internal/cmdutil
|
||||
// transport decorators) do not have to migrate. New code should prefer
|
||||
// SharedTransport and treat the base as an http.RoundTripper.
|
||||
func FallbackTransport() *http.Transport {
|
||||
if t, ok := SharedTransport().(*http.Transport); ok {
|
||||
return t
|
||||
}
|
||||
return noProxyTransport()
|
||||
}
|
||||
|
||||
@@ -28,19 +28,65 @@ func TestDetectProxyEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBaseTransport_Default(t *testing.T) {
|
||||
func TestSharedTransport_DefaultReturnsStdlibSingleton(t *testing.T) {
|
||||
t.Setenv(EnvNoProxy, "")
|
||||
tr := NewBaseTransport()
|
||||
if tr.Proxy == nil {
|
||||
t.Error("expected proxy func to be set when LARK_CLI_NO_PROXY is not set")
|
||||
tr := SharedTransport()
|
||||
if tr != http.DefaultTransport {
|
||||
t.Error("SharedTransport should return http.DefaultTransport when LARK_CLI_NO_PROXY is unset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBaseTransport_NoProxy(t *testing.T) {
|
||||
func TestSharedTransport_NoProxyReturnsClone(t *testing.T) {
|
||||
t.Setenv(EnvNoProxy, "1")
|
||||
tr := NewBaseTransport()
|
||||
if tr.Proxy != nil {
|
||||
t.Error("expected proxy func to be nil when LARK_CLI_NO_PROXY=1")
|
||||
tr := SharedTransport()
|
||||
if tr == http.DefaultTransport {
|
||||
t.Fatal("SharedTransport should return a clone, not DefaultTransport, when LARK_CLI_NO_PROXY is set")
|
||||
}
|
||||
ht, ok := tr.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", tr)
|
||||
}
|
||||
if ht.Proxy != nil {
|
||||
t.Error("no-proxy transport should have Proxy == nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedTransport_NoProxyIsCachedSingleton(t *testing.T) {
|
||||
t.Setenv(EnvNoProxy, "1")
|
||||
a := SharedTransport()
|
||||
b := SharedTransport()
|
||||
if a != b {
|
||||
t.Error("repeated SharedTransport calls with LARK_CLI_NO_PROXY set must return the same instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedTransport_EnvUnsetAfterSetFallsBackToDefault(t *testing.T) {
|
||||
// Simulate a process that first runs with LARK_CLI_NO_PROXY=1 (populating
|
||||
// the no-proxy singleton), then unsets it. Subsequent calls must return
|
||||
// http.DefaultTransport, NOT the cached no-proxy clone.
|
||||
t.Setenv(EnvNoProxy, "1")
|
||||
noProxy := SharedTransport()
|
||||
if noProxy == http.DefaultTransport {
|
||||
t.Fatal("precondition: first call with env set should not return DefaultTransport")
|
||||
}
|
||||
|
||||
t.Setenv(EnvNoProxy, "")
|
||||
after := SharedTransport()
|
||||
if after != http.DefaultTransport {
|
||||
t.Errorf("after unsetting LARK_CLI_NO_PROXY, SharedTransport must return http.DefaultTransport, got %T (%p)", after, after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedTransport_NoProxyOverridesSystemProxy(t *testing.T) {
|
||||
t.Setenv("HTTPS_PROXY", "http://should-be-ignored:8888")
|
||||
t.Setenv(EnvNoProxy, "1")
|
||||
|
||||
ht, ok := SharedTransport().(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", SharedTransport())
|
||||
}
|
||||
if ht.Proxy != nil {
|
||||
t.Error("LARK_CLI_NO_PROXY should override system proxy settings")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,35 +202,3 @@ func TestWarnIfProxied_RedactsCredentials(t *testing.T) {
|
||||
t.Errorf("warning should contain redacted proxy URL, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBaseTransport_IsHTTPTransport(t *testing.T) {
|
||||
t.Setenv(EnvNoProxy, "")
|
||||
tr := NewBaseTransport()
|
||||
|
||||
// Should be a valid *http.Transport that can be used
|
||||
var rt http.RoundTripper = tr
|
||||
_ = rt
|
||||
|
||||
// Verify it's not the same pointer as DefaultTransport (should be a clone)
|
||||
if tr == http.DefaultTransport {
|
||||
t.Error("NewBaseTransport should return a clone, not DefaultTransport itself")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBaseTransport_RespectsNoProxyEnv(t *testing.T) {
|
||||
// Simulate: user sets both system proxy and our disable flag
|
||||
t.Setenv("HTTPS_PROXY", "http://should-be-ignored:8888")
|
||||
t.Setenv(EnvNoProxy, "1")
|
||||
|
||||
tr := NewBaseTransport()
|
||||
if tr.Proxy != nil {
|
||||
t.Error("LARK_CLI_NO_PROXY should override system proxy settings")
|
||||
}
|
||||
|
||||
// Clean up and verify proxy is restored
|
||||
t.Setenv(EnvNoProxy, "")
|
||||
tr2 := NewBaseTransport()
|
||||
if tr2.Proxy == nil {
|
||||
t.Error("proxy should be enabled when LARK_CLI_NO_PROXY is unset")
|
||||
}
|
||||
}
|
||||
|
||||
11
main_authsidecar.go
Normal file
11
main_authsidecar.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build authsidecar
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/larksuite/cli/extension/credential/sidecar" // activate sidecar credential provider
|
||||
_ "github.com/larksuite/cli/extension/transport/sidecar" // activate sidecar transport interceptor
|
||||
)
|
||||
54
main_noauthsidecar.go
Normal file
54
main_noauthsidecar.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build !authsidecar
|
||||
|
||||
// This file is the fail-closed guard for builds that do NOT include the
|
||||
// `authsidecar` tag. The sidecar credential-isolation feature is only
|
||||
// compiled in under that tag; deploying the plain build into an environment
|
||||
// that expects sidecar isolation would silently fall back to direct env
|
||||
// credential use — exactly the failure mode the feature is meant to prevent.
|
||||
//
|
||||
// When LARKSUITE_CLI_AUTH_PROXY is set, we refuse to run rather than ignore
|
||||
// the variable. The operator either rebuilt without realizing (wrong
|
||||
// artifact) or the sandbox inherited the var by accident; both cases want
|
||||
// a loud startup error, not a mysterious token leak on the first API call.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if code := checkNoAuthsidecarBuild(os.Getenv, os.Stderr); code != 0 {
|
||||
os.Exit(code)
|
||||
}
|
||||
}
|
||||
|
||||
// checkNoAuthsidecarBuild returns a non-zero exit code (and writes a
|
||||
// human-readable reason to stderr) when the environment asks for sidecar
|
||||
// isolation that this binary cannot provide. Factored out from init() so
|
||||
// tests can exercise the decision without actually calling os.Exit.
|
||||
func checkNoAuthsidecarBuild(getenv func(string) string, stderr io.Writer) int {
|
||||
v := getenv(envvars.CliAuthProxy)
|
||||
if v == "" {
|
||||
return 0
|
||||
}
|
||||
fmt.Fprintf(stderr,
|
||||
"ERROR: %s is set, but this lark-cli binary was built WITHOUT the "+
|
||||
"'authsidecar' build tag.\n"+
|
||||
"The sidecar credential-isolation feature is compiled out — "+
|
||||
"running would bypass isolation and\n"+
|
||||
"send any real credentials present in the environment directly "+
|
||||
"to the Lark API.\n\n"+
|
||||
"To fix, either:\n"+
|
||||
" - rebuild the CLI with: go build -tags authsidecar\n"+
|
||||
" - or unset %s if sidecar isolation is not required\n",
|
||||
envvars.CliAuthProxy, envvars.CliAuthProxy)
|
||||
return 2
|
||||
}
|
||||
52
main_noauthsidecar_test.go
Normal file
52
main_noauthsidecar_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build !authsidecar
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
)
|
||||
|
||||
func TestCheckNoAuthsidecarBuild_Unset(t *testing.T) {
|
||||
var stderr bytes.Buffer
|
||||
code := checkNoAuthsidecarBuild(func(string) string { return "" }, &stderr)
|
||||
if code != 0 {
|
||||
t.Errorf("exit code = %d, want 0 when AUTH_PROXY is unset", code)
|
||||
}
|
||||
if stderr.Len() != 0 {
|
||||
t.Errorf("stderr should be empty, got %q", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckNoAuthsidecarBuild_Set verifies that deploying a plain build into
|
||||
// a sandbox that expects sidecar isolation fails loudly at startup instead
|
||||
// of silently leaking credentials through the env provider path.
|
||||
func TestCheckNoAuthsidecarBuild_Set(t *testing.T) {
|
||||
var stderr bytes.Buffer
|
||||
env := func(k string) string {
|
||||
if k == envvars.CliAuthProxy {
|
||||
return "http://127.0.0.1:16384"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
code := checkNoAuthsidecarBuild(env, &stderr)
|
||||
if code == 0 {
|
||||
t.Fatal("expected non-zero exit code when AUTH_PROXY is set")
|
||||
}
|
||||
msg := stderr.String()
|
||||
for _, want := range []string{
|
||||
envvars.CliAuthProxy,
|
||||
"authsidecar", // build-tag name must appear so operators can act on it
|
||||
"rebuild",
|
||||
} {
|
||||
if !strings.Contains(msg, want) {
|
||||
t.Errorf("stderr message missing %q; got:\n%s", want, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@larksuite/cli",
|
||||
"version": "1.0.14",
|
||||
"version": "1.0.17",
|
||||
"description": "The official CLI for Lark/Feishu open platform",
|
||||
"bin": {
|
||||
"lark-cli": "scripts/run.js"
|
||||
|
||||
@@ -38,11 +38,11 @@ const messages = {
|
||||
step3Fail: "应用配置失败。运行以下命令重试: lark-cli config init --new",
|
||||
step4: "授权",
|
||||
step4NotFound: "未找到 lark-cli,跳过授权",
|
||||
step4Confirm: "允许 AI 访问你的飞书数据(消息、文档、日历等)?",
|
||||
step4Confirm: "是否允许 AI 访问你个人的消息、文档、日历等飞书 / Lark 数据,并以你的名义执行操作?",
|
||||
step4Skip: "跳过授权。后续运行 lark-cli auth login 完成授权",
|
||||
step4Done: "授权完成",
|
||||
step4Fail: "授权失败。运行以下命令重试: lark-cli auth login",
|
||||
done: "安装完成!\n现在可以对你的 AI 工具(Claude Code、Trae 等)说:\"Feishu/Lark CLI 能帮我做什么?结合我的情况推荐一下从哪里开始\"",
|
||||
done: "安装完成!\n可以和你的 AI 工具(如 Claude Code、Trae等)说:\"飞书/Lark CLI 能帮我做什么?结合我的情况推荐一下从哪里开始\"",
|
||||
cancelled: "安装已取消",
|
||||
},
|
||||
en: {
|
||||
@@ -66,7 +66,7 @@ const messages = {
|
||||
step3Fail: "Failed to configure app. Run manually: lark-cli config init --new",
|
||||
step4: "Authorization",
|
||||
step4NotFound: "lark-cli not found. Skipping authorization",
|
||||
step4Confirm: "Allow AI to access your Feishu/Lark data (messages, docs, calendar, etc.)?",
|
||||
step4Confirm: "Allow the AI to access your messages, documents, calendar, and more in Feishu/Lark, and perform actions on your behalf?",
|
||||
step4Skip: "Skipped. Run lark-cli auth login to authorize later",
|
||||
step4Done: "Authorization complete",
|
||||
step4Fail: "Failed to authorize. Run lark-cli auth login to retry",
|
||||
|
||||
@@ -137,6 +137,8 @@ func TestDryRunRecordOps(t *testing.T) {
|
||||
"bitable_file",
|
||||
"PATCH /open-apis/base/v3/bases/app_x/tables/tbl_1/records/rec_1",
|
||||
"report-final.pdf",
|
||||
`"mime_type":"\u003cdetected_mime_type\u003e"`,
|
||||
`"size":"\u003cfile_size\u003e"`,
|
||||
"deprecated_set_attachment",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -67,11 +67,15 @@ func runShortcutWithAuthTypes(t *testing.T, shortcut common.Shortcut, authTypes
|
||||
parent.SilenceErrors = true
|
||||
parent.SilenceUsage = true
|
||||
stdout.Reset()
|
||||
if stderr, ok := factory.IOStreams.ErrOut.(*bytes.Buffer); ok {
|
||||
stderr.Reset()
|
||||
}
|
||||
return parent.ExecuteContext(context.Background())
|
||||
}
|
||||
|
||||
func TestBaseWorkspaceExecuteCreate(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactory(t)
|
||||
stderr, _ := factory.IOStreams.ErrOut.(*bytes.Buffer)
|
||||
permStub := &httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/permissions/app_x/members?need_notification=false&type=bitable",
|
||||
@@ -96,6 +100,9 @@ func TestBaseWorkspaceExecuteCreate(t *testing.T) {
|
||||
if data["created"] != true {
|
||||
t.Fatalf("created = %#v, want true", data["created"])
|
||||
}
|
||||
if !strings.Contains(stderr.String(), baseCreateHint) {
|
||||
t.Fatalf("stderr = %q, want %q", stderr.String(), baseCreateHint)
|
||||
}
|
||||
base, _ := data["base"].(map[string]interface{})
|
||||
if got := common.GetString(base, "app_token"); got != "app_x" {
|
||||
t.Fatalf("base.app_token = %q, want %q", got, "app_x")
|
||||
@@ -184,6 +191,7 @@ func TestBaseWorkspaceExecuteGetAndCopy(t *testing.T) {
|
||||
|
||||
func TestBaseWorkspaceExecuteCreateBotAutoGrantSkippedWithoutCurrentUser(t *testing.T) {
|
||||
factory, stdout, reg := newExecuteFactoryWithUserOpenID(t, "")
|
||||
stderr, _ := factory.IOStreams.ErrOut.(*bytes.Buffer)
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/base/v3/bases",
|
||||
@@ -198,6 +206,9 @@ func TestBaseWorkspaceExecuteCreateBotAutoGrantSkippedWithoutCurrentUser(t *test
|
||||
}
|
||||
|
||||
data := decodeBaseEnvelope(t, stdout)
|
||||
if !strings.Contains(stderr.String(), baseCreateHint) {
|
||||
t.Fatalf("stderr = %q, want %q", stderr.String(), baseCreateHint)
|
||||
}
|
||||
grant, _ := data["permission_grant"].(map[string]interface{})
|
||||
if grant["status"] != common.PermissionGrantSkipped {
|
||||
t.Fatalf("permission_grant.status = %#v, want %q", grant["status"], common.PermissionGrantSkipped)
|
||||
@@ -1219,7 +1230,9 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) {
|
||||
!strings.Contains(updateBody, `"image_height":480`) ||
|
||||
!strings.Contains(updateBody, `"deprecated_set_attachment":true`) ||
|
||||
!strings.Contains(updateBody, `"file_token":"file_tok_1"`) ||
|
||||
!strings.Contains(updateBody, `"name":"report.txt"`) {
|
||||
!strings.Contains(updateBody, `"name":"report.txt"`) ||
|
||||
!strings.Contains(updateBody, `"size":16`) ||
|
||||
!strings.Contains(updateBody, `"mime_type":"text/plain"`) {
|
||||
t.Fatalf("update body=%s", updateBody)
|
||||
}
|
||||
})
|
||||
@@ -1370,6 +1383,8 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) {
|
||||
if !strings.Contains(updateBody, `"附件"`) ||
|
||||
!strings.Contains(updateBody, `"file_token":"file_tok_big"`) ||
|
||||
!strings.Contains(updateBody, `"name":"large-report.bin"`) ||
|
||||
!strings.Contains(updateBody, `"size":20971521`) ||
|
||||
!strings.Contains(updateBody, `"mime_type":"application/octet-stream"`) ||
|
||||
!strings.Contains(updateBody, `"deprecated_set_attachment":true`) {
|
||||
t.Fatalf("update body=%s", updateBody)
|
||||
}
|
||||
|
||||
@@ -5,11 +5,14 @@ package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
const baseCreateHint = "Tip: New bases include a default empty table with 5-10 blank records. After finishing table/field setup on this base, ask whether to delete that default table. If yes, run +table-list first, then delete the default table."
|
||||
|
||||
func dryRunBaseGet(_ context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
return common.NewDryRunAPI().
|
||||
GET("/open-apis/base/v3/bases/:base_token").
|
||||
@@ -65,6 +68,7 @@ func executeBaseCreate(runtime *common.RuntimeContext) error {
|
||||
out := map[string]interface{}{"base": data, "created": true}
|
||||
augmentBasePermissionGrant(runtime, out, data)
|
||||
runtime.Out(out, nil)
|
||||
fmt.Fprintln(runtime.IO().ErrOut, baseCreateHint)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ func TestShortcutsCatalog(t *testing.T) {
|
||||
"+table-list", "+table-get", "+table-create", "+table-update", "+table-delete",
|
||||
"+field-list", "+field-get", "+field-create", "+field-update", "+field-delete", "+field-search-options",
|
||||
"+view-list", "+view-get", "+view-create", "+view-delete", "+view-get-filter", "+view-set-filter", "+view-get-visible-fields", "+view-set-visible-fields", "+view-get-group", "+view-set-group", "+view-get-sort", "+view-set-sort", "+view-get-timebar", "+view-set-timebar", "+view-get-card", "+view-set-card", "+view-rename",
|
||||
"+record-list", "+record-search", "+record-get", "+record-upsert", "+record-batch-create", "+record-batch-update", "+record-upload-attachment", "+record-delete",
|
||||
"+record-list", "+record-search", "+record-get", "+record-upsert", "+record-batch-create", "+record-batch-update", "+record-share-link-create", "+record-upload-attachment", "+record-delete",
|
||||
"+record-history-list",
|
||||
"+base-get", "+base-copy", "+base-create",
|
||||
"+role-create", "+role-delete", "+role-update", "+role-list", "+role-get", "+advperm-enable", "+advperm-disable",
|
||||
|
||||
@@ -112,6 +112,56 @@ func dryRunRecordHistoryList(_ context.Context, runtime *common.RuntimeContext)
|
||||
Set("base_token", runtime.Str("base-token"))
|
||||
}
|
||||
|
||||
func dryRunRecordShareBatch(_ context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
recordIDs := deduplicateRecordIDs(runtime)
|
||||
return common.NewDryRunAPI().
|
||||
POST("/open-apis/base/v3/bases/:base_token/tables/:table_id/records/share_links/batch").
|
||||
Body(map[string]interface{}{"record_ids": recordIDs}).
|
||||
Set("base_token", runtime.Str("base-token")).
|
||||
Set("table_id", baseTableID(runtime))
|
||||
}
|
||||
|
||||
const maxShareBatchSize = 100
|
||||
|
||||
func validateRecordShareBatch(runtime *common.RuntimeContext) error {
|
||||
recordIDs := deduplicateRecordIDs(runtime)
|
||||
if len(recordIDs) == 0 {
|
||||
return common.FlagErrorf("--record-ids is required and must not be empty")
|
||||
}
|
||||
if len(recordIDs) > maxShareBatchSize {
|
||||
return common.FlagErrorf("--record-ids exceeds maximum limit of %d (got %d)", maxShareBatchSize, len(recordIDs))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deduplicateRecordIDs(runtime *common.RuntimeContext) []string {
|
||||
raw := runtime.StrSlice("record-ids")
|
||||
seen := make(map[string]bool, len(raw))
|
||||
result := make([]string, 0, len(raw))
|
||||
for _, id := range raw {
|
||||
if id != "" && !seen[id] {
|
||||
seen[id] = true
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func executeRecordShareBatch(runtime *common.RuntimeContext) error {
|
||||
recordIDs := deduplicateRecordIDs(runtime)
|
||||
body := map[string]interface{}{
|
||||
"record_ids": recordIDs,
|
||||
}
|
||||
data, err := baseV3Call(runtime, "POST",
|
||||
baseV3Path("bases", runtime.Str("base-token"), "tables", baseTableID(runtime), "records", "share_links", "batch"),
|
||||
nil, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
runtime.Out(data, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRecordJSON(runtime *common.RuntimeContext) error {
|
||||
pc := newParseCtx(runtime)
|
||||
_, err := parseJSONObject(pc, runtime.Str("json"), "json")
|
||||
|
||||
35
shortcuts/base/record_share_link_create.go
Normal file
35
shortcuts/base/record_share_link_create.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
var BaseRecordShareLinkCreate = common.Shortcut{
|
||||
Service: "base",
|
||||
Command: "+record-share-link-create",
|
||||
Description: "Generate share links for one or more records (max 100 per request)",
|
||||
Risk: "read",
|
||||
Scopes: []string{"base:record:read"},
|
||||
AuthTypes: authTypes(),
|
||||
Flags: []common.Flag{
|
||||
baseTokenFlag(true),
|
||||
tableRefFlag(true),
|
||||
{Name: "record-ids", Type: "string_slice", Desc: "record IDs to generate share links for (comma-separated or repeatable, max 100)", Required: true},
|
||||
},
|
||||
Tips: []string{
|
||||
`Single record: --base-token xxx --table-id tblxxx --record-ids recxxx`,
|
||||
`Multiple records: --base-token xxx --table-id tblxxx --record-ids rec001,rec002,rec003`,
|
||||
},
|
||||
Validate: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
return validateRecordShareBatch(runtime)
|
||||
},
|
||||
DryRun: dryRunRecordShareBatch,
|
||||
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
return executeRecordShareBatch(runtime)
|
||||
},
|
||||
}
|
||||
@@ -4,11 +4,15 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
@@ -105,6 +109,8 @@ func dryRunRecordUploadAttachment(_ context.Context, runtime *common.RuntimeCont
|
||||
map[string]interface{}{
|
||||
"file_token": "<uploaded_file_token>",
|
||||
"name": fileName,
|
||||
"mime_type": "<detected_mime_type>",
|
||||
"size": "<file_size>",
|
||||
"deprecated_set_attachment": true,
|
||||
},
|
||||
},
|
||||
@@ -243,10 +249,14 @@ func normalizeAttachmentForPatch(attachment map[string]interface{}) map[string]i
|
||||
}
|
||||
|
||||
func uploadAttachmentToBase(runtime *common.RuntimeContext, filePath, fileName, baseToken string, fileSize int64) (map[string]interface{}, error) {
|
||||
mimeType, err := detectAttachmentMIMEType(runtime.FileIO(), filePath, fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parentNode := baseToken
|
||||
var (
|
||||
fileToken string
|
||||
err error
|
||||
)
|
||||
if fileSize <= common.MaxDriveMediaUploadSinglePartSize {
|
||||
fileToken, err = common.UploadDriveMediaAll(runtime, common.DriveMediaUploadAllConfig{
|
||||
@@ -272,7 +282,78 @@ func uploadAttachmentToBase(runtime *common.RuntimeContext, filePath, fileName,
|
||||
attachment := map[string]interface{}{
|
||||
"file_token": fileToken,
|
||||
"name": fileName,
|
||||
"mime_type": mimeType,
|
||||
"size": fileSize,
|
||||
"deprecated_set_attachment": true,
|
||||
}
|
||||
return attachment, nil
|
||||
}
|
||||
|
||||
func detectAttachmentMIMEType(fio fileio.FileIO, filePath, fileName string) (string, error) {
|
||||
if byExt := strings.TrimSpace(mime.TypeByExtension(strings.ToLower(filepath.Ext(fileName)))); byExt != "" {
|
||||
return stripMIMEParams(byExt), nil
|
||||
}
|
||||
if byExt := strings.TrimSpace(mime.TypeByExtension(strings.ToLower(filepath.Ext(filePath)))); byExt != "" {
|
||||
return stripMIMEParams(byExt), nil
|
||||
}
|
||||
|
||||
f, err := fio.Open(filePath)
|
||||
if err != nil {
|
||||
return "", common.WrapInputStatError(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
n, readErr := f.Read(buf)
|
||||
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
||||
return "", output.ErrValidation("cannot read file: %s", readErr)
|
||||
}
|
||||
return detectAttachmentMIMEFromContent(buf[:n]), nil
|
||||
}
|
||||
|
||||
func stripMIMEParams(value string) string {
|
||||
if i := strings.IndexByte(value, ';'); i != -1 {
|
||||
value = value[:i]
|
||||
}
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func detectAttachmentMIMEFromContent(content []byte) string {
|
||||
if len(content) == 0 {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
if bytes.HasPrefix(content, []byte{0x89, 'P', 'N', 'G', '\r', '\n', 0x1a, '\n'}) {
|
||||
return "image/png"
|
||||
}
|
||||
if bytes.HasPrefix(content, []byte{0xff, 0xd8, 0xff}) {
|
||||
return "image/jpeg"
|
||||
}
|
||||
if bytes.HasPrefix(content, []byte("GIF87a")) || bytes.HasPrefix(content, []byte("GIF89a")) {
|
||||
return "image/gif"
|
||||
}
|
||||
if len(content) >= 12 && bytes.Equal(content[:4], []byte("RIFF")) && bytes.Equal(content[8:12], []byte("WEBP")) {
|
||||
return "image/webp"
|
||||
}
|
||||
if bytes.HasPrefix(content, []byte("%PDF-")) {
|
||||
return "application/pdf"
|
||||
}
|
||||
if looksLikeText(content) {
|
||||
return "text/plain"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
func looksLikeText(content []byte) bool {
|
||||
if !utf8.Valid(content) {
|
||||
return false
|
||||
}
|
||||
for _, r := range string(content) {
|
||||
if r == '\n' || r == '\r' || r == '\t' {
|
||||
continue
|
||||
}
|
||||
if r < 0x20 || r == 0x7f {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
136
shortcuts/base/record_upload_attachment_test.go
Normal file
136
shortcuts/base/record_upload_attachment_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
)
|
||||
|
||||
type attachmentTestFileIO struct {
|
||||
openFile fileio.File
|
||||
openErr error
|
||||
}
|
||||
|
||||
func (f attachmentTestFileIO) Open(string) (fileio.File, error) { return f.openFile, f.openErr }
|
||||
func (attachmentTestFileIO) Stat(string) (fileio.FileInfo, error) {
|
||||
return attachmentTestFileInfo{}, nil
|
||||
}
|
||||
func (attachmentTestFileIO) ResolvePath(path string) (string, error) { return path, nil }
|
||||
func (attachmentTestFileIO) Save(string, fileio.SaveOptions, io.Reader) (fileio.SaveResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type attachmentTestFileInfo struct{}
|
||||
|
||||
func (attachmentTestFileInfo) Size() int64 { return 0 }
|
||||
func (attachmentTestFileInfo) IsDir() bool { return false }
|
||||
func (attachmentTestFileInfo) Mode() fs.FileMode { return 0 }
|
||||
|
||||
type attachmentTestFile struct {
|
||||
*bytes.Reader
|
||||
}
|
||||
|
||||
func newAttachmentTestFile(content []byte) attachmentTestFile {
|
||||
return attachmentTestFile{Reader: bytes.NewReader(content)}
|
||||
}
|
||||
|
||||
func (attachmentTestFile) Close() error { return nil }
|
||||
|
||||
type attachmentReadErrorFile struct{}
|
||||
|
||||
func (attachmentReadErrorFile) Read([]byte) (int, error) { return 0, os.ErrPermission }
|
||||
func (attachmentReadErrorFile) ReadAt([]byte, int64) (int, error) { return 0, io.EOF }
|
||||
func (attachmentReadErrorFile) Close() error { return nil }
|
||||
|
||||
func TestDetectAttachmentMIMETypeUsesExtension(t *testing.T) {
|
||||
got, err := detectAttachmentMIMEType(nil, "ignored", "note.TXT")
|
||||
if err != nil {
|
||||
t.Fatalf("detectAttachmentMIMEType() error = %v", err)
|
||||
}
|
||||
if got != "text/plain" {
|
||||
t.Fatalf("detectAttachmentMIMEType() = %q, want %q", got, "text/plain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectAttachmentMIMETypeFallsBackToSourcePathExtension(t *testing.T) {
|
||||
got, err := detectAttachmentMIMEType(nil, "report.docx", "report")
|
||||
if err != nil {
|
||||
t.Fatalf("detectAttachmentMIMEType() error = %v", err)
|
||||
}
|
||||
if got != "application/vnd.openxmlformats-officedocument.wordprocessingml.document" {
|
||||
t.Fatalf("detectAttachmentMIMEType() = %q, want docx MIME type", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectAttachmentMIMETypeFallsBackToContent(t *testing.T) {
|
||||
fio := attachmentTestFileIO{openFile: newAttachmentTestFile([]byte("hello from base attachment"))}
|
||||
|
||||
got, err := detectAttachmentMIMEType(fio, "note", "note")
|
||||
if err != nil {
|
||||
t.Fatalf("detectAttachmentMIMEType() error = %v", err)
|
||||
}
|
||||
if got != "text/plain" {
|
||||
t.Fatalf("detectAttachmentMIMEType() = %q, want %q", got, "text/plain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectAttachmentMIMETypeWrapsOpenError(t *testing.T) {
|
||||
fio := attachmentTestFileIO{openErr: os.ErrNotExist}
|
||||
|
||||
_, err := detectAttachmentMIMEType(fio, "missing", "missing")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for open failure")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cannot read file") {
|
||||
t.Fatalf("error = %v, want wrapped read failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectAttachmentMIMETypeReturnsReadError(t *testing.T) {
|
||||
fio := attachmentTestFileIO{openFile: attachmentReadErrorFile{}}
|
||||
|
||||
_, err := detectAttachmentMIMEType(fio, "broken", "broken")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for read failure")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cannot read file") {
|
||||
t.Fatalf("error = %v, want read failure", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectAttachmentMIMEFromContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
want string
|
||||
}{
|
||||
{name: "empty", content: nil, want: "application/octet-stream"},
|
||||
{name: "png", content: []byte{0x89, 'P', 'N', 'G', '\r', '\n', 0x1a, '\n'}, want: "image/png"},
|
||||
{name: "jpeg", content: []byte{0xff, 0xd8, 0xff, 0xe0}, want: "image/jpeg"},
|
||||
{name: "gif87a", content: []byte("GIF87a"), want: "image/gif"},
|
||||
{name: "gif89a", content: []byte("GIF89a"), want: "image/gif"},
|
||||
{name: "webp", content: []byte("RIFF1234WEBP"), want: "image/webp"},
|
||||
{name: "pdf", content: []byte("%PDF-1.7"), want: "application/pdf"},
|
||||
{name: "text", content: []byte("hello from base attachment"), want: "text/plain"},
|
||||
{name: "text with newline", content: []byte("hello\nworld\tok"), want: "text/plain"},
|
||||
{name: "control bytes", content: []byte{'h', 'i', 0x00}, want: "application/octet-stream"},
|
||||
{name: "binary fallback", content: []byte{0x00, 0x01, 0x02, 0x03}, want: "application/octet-stream"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := detectAttachmentMIMEFromContent(tt.content)
|
||||
if got != tt.want {
|
||||
t.Fatalf("detectAttachmentMIMEFromContent() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,7 @@ func Shortcuts() []common.Shortcut {
|
||||
BaseRecordUpsert,
|
||||
BaseRecordBatchCreate,
|
||||
BaseRecordBatchUpdate,
|
||||
BaseRecordShareLinkCreate,
|
||||
BaseRecordUploadAttachment,
|
||||
BaseRecordDelete,
|
||||
BaseRecordHistoryList,
|
||||
|
||||
@@ -54,6 +54,7 @@ func fetchInstanceViewRange(ctx context.Context, runtime *common.RuntimeContext,
|
||||
"start_time": fmt.Sprintf("%d", startTime),
|
||||
"end_time": fmt.Sprintf("%d", endTime),
|
||||
}, nil)
|
||||
err = wrapPredefinedError(err)
|
||||
if err != nil {
|
||||
return nil, output.Errorf(output.ExitAPI, "api_error", "API call failed: %s", err)
|
||||
}
|
||||
|
||||
@@ -194,6 +194,7 @@ var CalendarCreate = common.Shortcut{
|
||||
data, err := runtime.CallAPI("POST",
|
||||
fmt.Sprintf("/open-apis/calendar/v4/calendars/%s/events", validate.EncodePathSegment(calendarId)),
|
||||
nil, eventData)
|
||||
err = wrapPredefinedError(err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -221,11 +222,13 @@ var CalendarCreate = common.Shortcut{
|
||||
"attendees": attendees,
|
||||
"need_notification": true,
|
||||
})
|
||||
err = wrapPredefinedError(err)
|
||||
if err != nil {
|
||||
// Rollback: delete the event
|
||||
_, rollbackErr := runtime.RawAPI("DELETE",
|
||||
fmt.Sprintf("/open-apis/calendar/v4/calendars/%s/events/%s", validate.EncodePathSegment(calendarId), validate.EncodePathSegment(eventId)),
|
||||
map[string]interface{}{"need_notification": false}, nil)
|
||||
rollbackErr = wrapPredefinedError(rollbackErr)
|
||||
if rollbackErr != nil {
|
||||
return output.Errorf(output.ExitAPI, "api_error", "failed to add attendees: %v; rollback also failed, orphan event_id=%s needs manual cleanup", rollbackErr, eventId)
|
||||
}
|
||||
|
||||
@@ -102,6 +102,7 @@ var CalendarFreebusy = common.Shortcut{
|
||||
"user_id": userId,
|
||||
"need_rsvp_status": true,
|
||||
})
|
||||
err = wrapPredefinedError(err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -375,6 +375,238 @@ func TestCreate_NoEventIdReturned(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate_CreateEvent_InvalidParamsWithDetail(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/calendars/cal_test123/events",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
"error": map[string]interface{}{
|
||||
"details": []interface{}{
|
||||
map[string]interface{}{"value": "end_time should be later than start_time"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarCreate, []string{
|
||||
"+create",
|
||||
"--summary", "Bad Time",
|
||||
"--start", "2025-03-21T10:00:00+08:00",
|
||||
"--end", "2025-03-21T11:00:00+08:00",
|
||||
"--calendar-id", "cal_test123",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
if !strings.Contains(exitErr.Detail.Message, "end_time should be later than start_time") {
|
||||
t.Errorf("expected detail value in message, got %q", exitErr.Detail.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate_CreateEvent_InvalidParamsWithoutDetailValue(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/calendars/cal_test123/events",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarCreate, []string{
|
||||
"+create",
|
||||
"--summary", "Bad Time",
|
||||
"--start", "2025-03-21T10:00:00+08:00",
|
||||
"--end", "2025-03-21T11:00:00+08:00",
|
||||
"--calendar-id", "cal_test123",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate_CreateEvent_InvalidParams_ErrorNotMap(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/calendars/cal_test123/events",
|
||||
RawBody: []byte(`{"code":190014,"msg":"invalid params","error":"just a string"}`),
|
||||
ContentType: "text/plain",
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarCreate, []string{
|
||||
"+create",
|
||||
"--summary", "Bad Time",
|
||||
"--start", "2025-03-21T10:00:00+08:00",
|
||||
"--end", "2025-03-21T11:00:00+08:00",
|
||||
"--calendar-id", "cal_test123",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate_CreateEvent_InvalidParams_NoDetailsKey(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/calendars/cal_test123/events",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
"error": map[string]interface{}{
|
||||
"other_key": "no details here",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarCreate, []string{
|
||||
"+create",
|
||||
"--summary", "Bad Time",
|
||||
"--start", "2025-03-21T10:00:00+08:00",
|
||||
"--end", "2025-03-21T11:00:00+08:00",
|
||||
"--calendar-id", "cal_test123",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate_CreateEvent_InvalidParams_DetailItemNotMap(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/calendars/cal_test123/events",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
"error": map[string]interface{}{
|
||||
"details": []interface{}{nil},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarCreate, []string{
|
||||
"+create",
|
||||
"--summary", "Bad Time",
|
||||
"--start", "2025-03-21T10:00:00+08:00",
|
||||
"--end", "2025-03-21T11:00:00+08:00",
|
||||
"--calendar-id", "cal_test123",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate_WithAttendees_InvalidParamsWithDetail_RollsBack(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/calendars/cal_test123/events",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"event": map[string]interface{}{
|
||||
"event_id": "evt_190014",
|
||||
"summary": "Bad Attendees",
|
||||
"start_time": map[string]interface{}{"timestamp": "1742515200"},
|
||||
"end_time": map[string]interface{}{"timestamp": "1742518800"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/events/evt_190014/attendees",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
"error": map[string]interface{}{
|
||||
"details": []interface{}{
|
||||
map[string]interface{}{"value": "invalid attendee open_id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "DELETE",
|
||||
URL: "/events/evt_190014",
|
||||
Body: map[string]interface{}{"code": 0, "msg": "ok"},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarCreate, []string{
|
||||
"+create",
|
||||
"--summary", "Bad Attendees",
|
||||
"--start", "2025-03-21T00:00:00+08:00",
|
||||
"--end", "2025-03-21T01:00:00+08:00",
|
||||
"--calendar-id", "cal_test123",
|
||||
"--attendee-ids", "ou_invalid",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid attendees with 190014, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid attendee open_id") {
|
||||
t.Errorf("expected detail value in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CalendarAgenda tests
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -645,6 +877,67 @@ func TestAgenda_ExplicitCalendarId(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgenda_InvalidParamsWithDetail(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/events/instance_view",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
"error": map[string]interface{}{
|
||||
"details": []interface{}{
|
||||
map[string]interface{}{"value": "start_time is required"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarAgenda, []string{
|
||||
"+agenda",
|
||||
"--start", "2025-03-21",
|
||||
"--end", "2025-03-21",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgenda_NonExitError_Passthrough(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/events/instance_view",
|
||||
RawBody: []byte("this is not json"),
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarAgenda, []string{
|
||||
"+agenda",
|
||||
"--start", "2025-03-21",
|
||||
"--end", "2025-03-21",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-JSON response, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if errors.As(err, &exitErr) && exitErr.Detail != nil && exitErr.Detail.Code != 0 {
|
||||
t.Fatalf("expected non-API error passthrough, got API error code %d", exitErr.Detail.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CalendarFreebusy tests
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -725,6 +1018,46 @@ func TestFreebusy_APIError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFreebusy_InvalidParamsWithDetail(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, defaultConfig())
|
||||
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/calendar/v4/freebusy/list",
|
||||
Body: map[string]interface{}{
|
||||
"code": errCodeInvalidParamsWithDetail,
|
||||
"msg": "invalid params",
|
||||
"error": map[string]interface{}{
|
||||
"details": []interface{}{
|
||||
map[string]interface{}{"value": "user_id is invalid"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRun(t, CalendarFreebusy, []string{
|
||||
"+freebusy",
|
||||
"--start", "2025-03-21",
|
||||
"--end", "2025-03-21",
|
||||
"--user-id", "ou_someone",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 190014, got nil")
|
||||
}
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
t.Fatalf("expected *output.ExitError, got %T", err)
|
||||
}
|
||||
if exitErr.Detail.Code != errCodeInvalidParamsWithDetail {
|
||||
t.Errorf("expected code %d, got %d", errCodeInvalidParamsWithDetail, exitErr.Detail.Code)
|
||||
}
|
||||
if !strings.Contains(exitErr.Detail.Message, "user_id is invalid") {
|
||||
t.Errorf("expected detail value in message, got %q", exitErr.Detail.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CalendarSuggestion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
66
shortcuts/calendar/errors.go
Normal file
66
shortcuts/calendar/errors.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package calendar
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
)
|
||||
|
||||
const (
|
||||
errCodeInvalidParamsWithDetail = 190014
|
||||
)
|
||||
|
||||
// getErrorDetailValue extracts the first detail value from the output.ErrDetail.
|
||||
// It assumes Detail is a map containing a "details" array of objects with "value" string fields.
|
||||
// For example: {"details": [{"value": "error message 1"}, {"value": "error message 2"}]}
|
||||
// Returns an empty string if the structure doesn't match or the array is empty.
|
||||
func getErrorDetailValue(e *output.ErrDetail) string {
|
||||
if e == nil || e.Detail == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
errMap, ok := e.Detail.(map[string]interface{})
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
details, ok := errMap["details"].([]interface{})
|
||||
if !ok || len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
detailObj, ok := details[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
val, _ := detailObj["value"].(string)
|
||||
return val
|
||||
}
|
||||
|
||||
// wrapPredefinedError wraps an error into *output.ExitError if it matches predefined error codes.
|
||||
// Currently handles error code 190014 (invalid params with detail), extracting the detail value into the message.
|
||||
// If the error is nil or doesn't match predefined codes, returns the original error.
|
||||
func wrapPredefinedError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitErr *output.ExitError
|
||||
if !errors.As(err, &exitErr) || exitErr.Detail == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exitErr.Detail.Code == errCodeInvalidParamsWithDetail {
|
||||
if val := getErrorDetailValue(exitErr.Detail); val != "" {
|
||||
fullMsg := fmt.Sprintf("%s: %s", exitErr.Detail.Message, val)
|
||||
return output.ErrAPI(exitErr.Detail.Code, fullMsg, exitErr.Detail.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -37,6 +37,10 @@ type DriveMediaUploadAllConfig struct {
|
||||
ParentType string
|
||||
ParentNode *string
|
||||
Extra string
|
||||
// Reader, when non-nil, is used as the upload source instead of opening
|
||||
// FilePath. Callers must set FileName and FileSize explicitly. The reader
|
||||
// is NOT closed by UploadDriveMediaAll; the caller owns its lifetime.
|
||||
Reader io.Reader
|
||||
}
|
||||
|
||||
type DriveMediaMultipartUploadConfig struct {
|
||||
@@ -49,11 +53,17 @@ type DriveMediaMultipartUploadConfig struct {
|
||||
}
|
||||
|
||||
func UploadDriveMediaAll(runtime *RuntimeContext, cfg DriveMediaUploadAllConfig) (string, error) {
|
||||
f, err := runtime.FileIO().Open(cfg.FilePath)
|
||||
if err != nil {
|
||||
return "", WrapInputStatError(err)
|
||||
var fileReader io.Reader
|
||||
if cfg.Reader != nil {
|
||||
fileReader = cfg.Reader
|
||||
} else {
|
||||
f, err := runtime.FileIO().Open(cfg.FilePath)
|
||||
if err != nil {
|
||||
return "", WrapInputStatError(err)
|
||||
}
|
||||
defer f.Close()
|
||||
fileReader = f
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fd := larkcore.NewFormdata()
|
||||
fd.AddField("file_name", cfg.FileName)
|
||||
@@ -65,7 +75,7 @@ func UploadDriveMediaAll(runtime *RuntimeContext, cfg DriveMediaUploadAllConfig)
|
||||
if cfg.Extra != "" {
|
||||
fd.AddField("extra", cfg.Extra)
|
||||
}
|
||||
fd.AddFile("file", f)
|
||||
fd.AddFile("file", fileReader)
|
||||
|
||||
apiResp, err := runtime.DoAPI(&larkcore.ApiReq{
|
||||
HttpMethod: http.MethodPost,
|
||||
|
||||
@@ -181,6 +181,12 @@ func (ctx *RuntimeContext) StrArray(name string) []string {
|
||||
return v
|
||||
}
|
||||
|
||||
// StrSlice returns a string-slice flag value (supports CSV splitting and repeated flags).
|
||||
func (ctx *RuntimeContext) StrSlice(name string) []string {
|
||||
v, _ := ctx.Cmd.Flags().GetStringSlice(name)
|
||||
return v
|
||||
}
|
||||
|
||||
// ── API helpers ──
|
||||
|
||||
// CallAPI uses an internal HTTP wrapper with limited control over request/response.
|
||||
@@ -571,12 +577,16 @@ func enhancePermissionError(err error, requiredScopes []string) error {
|
||||
|
||||
// Mount registers the shortcut on a parent command.
|
||||
func (s Shortcut) Mount(parent *cobra.Command, f *cmdutil.Factory) {
|
||||
s.MountWithContext(context.Background(), parent, f)
|
||||
}
|
||||
|
||||
func (s Shortcut) MountWithContext(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) {
|
||||
if s.Execute != nil {
|
||||
s.mountDeclarative(parent, f)
|
||||
s.mountDeclarative(ctx, parent, f)
|
||||
}
|
||||
}
|
||||
|
||||
func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) {
|
||||
func (s Shortcut) mountDeclarative(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) {
|
||||
shortcut := s
|
||||
if len(shortcut.AuthTypes) == 0 {
|
||||
shortcut.AuthTypes = []string{"user"}
|
||||
@@ -592,7 +602,7 @@ func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) {
|
||||
},
|
||||
}
|
||||
cmdutil.SetSupportedIdentities(cmd, shortcut.AuthTypes)
|
||||
registerShortcutFlags(cmd, &shortcut)
|
||||
registerShortcutFlagsWithContext(ctx, cmd, f, &shortcut)
|
||||
cmdutil.SetTips(cmd, shortcut.Tips)
|
||||
parent.AddCommand(cmd)
|
||||
}
|
||||
@@ -823,7 +833,11 @@ func rejectPositionalArgs() cobra.PositionalArgs {
|
||||
}
|
||||
}
|
||||
|
||||
func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) {
|
||||
func registerShortcutFlags(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) {
|
||||
registerShortcutFlagsWithContext(context.Background(), cmd, f, s)
|
||||
}
|
||||
|
||||
func registerShortcutFlagsWithContext(ctx context.Context, cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) {
|
||||
for _, fl := range s.Flags {
|
||||
desc := fl.Desc
|
||||
if len(fl.Enum) > 0 {
|
||||
@@ -849,6 +863,8 @@ func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) {
|
||||
cmd.Flags().Int(fl.Name, d, desc)
|
||||
case "string_array":
|
||||
cmd.Flags().StringArray(fl.Name, nil, desc)
|
||||
case "string_slice":
|
||||
cmd.Flags().StringSlice(fl.Name, nil, desc)
|
||||
default:
|
||||
cmd.Flags().String(fl.Name, fl.Default, desc)
|
||||
}
|
||||
@@ -860,7 +876,7 @@ func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) {
|
||||
}
|
||||
if len(fl.Enum) > 0 {
|
||||
vals := fl.Enum
|
||||
_ = cmd.RegisterFlagCompletionFunc(fl.Name, func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
cmdutil.RegisterFlagCompletion(cmd, fl.Name, func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return vals, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
}
|
||||
@@ -874,13 +890,9 @@ func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) {
|
||||
cmd.Flags().Bool("yes", false, "confirm high-risk operation")
|
||||
}
|
||||
cmd.Flags().StringP("jq", "q", "", "jq expression to filter JSON output")
|
||||
cmd.Flags().String("as", s.AuthTypes[0], "identity type: user | bot")
|
||||
|
||||
_ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return s.AuthTypes, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
cmdutil.AddShortcutIdentityFlag(ctx, cmd, f, s.AuthTypes)
|
||||
if s.HasFormat {
|
||||
_ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
cmdutil.RegisterFlagCompletion(cmd, "format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
|
||||
return []string{"json", "pretty", "table", "ndjson", "csv"}, cobra.ShellCompDirectiveNoFileComp
|
||||
})
|
||||
}
|
||||
|
||||
98
shortcuts/common/runner_flag_completion_test.go
Normal file
98
shortcuts/common/runner_flag_completion_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TestShortcutMount_FlagCompletionsRegistered exercises the two
|
||||
// cmdutil.RegisterFlagCompletion call sites in registerShortcutFlagsWithContext:
|
||||
// the per-flag enum completion (runner.go:879) and the auto-injected --format
|
||||
// completion (runner.go:895).
|
||||
func TestShortcutMount_FlagCompletionsRegistered(t *testing.T) {
|
||||
t.Cleanup(func() { cmdutil.SetFlagCompletionsDisabled(false) })
|
||||
cmdutil.SetFlagCompletionsDisabled(false)
|
||||
|
||||
f, _, _, _ := cmdutil.TestFactory(t, nil)
|
||||
parent := &cobra.Command{Use: "root"}
|
||||
shortcut := Shortcut{
|
||||
Service: "docs",
|
||||
Command: "+fetch",
|
||||
Description: "fetch doc",
|
||||
HasFormat: true,
|
||||
Flags: []Flag{
|
||||
{Name: "sort-by", Desc: "sort", Enum: []string{"asc", "desc"}},
|
||||
},
|
||||
Execute: func(context.Context, *RuntimeContext) error { return nil },
|
||||
}
|
||||
shortcut.Mount(parent, f)
|
||||
|
||||
cmd, _, err := parent.Find([]string{"+fetch"})
|
||||
if err != nil {
|
||||
t.Fatalf("Find() error = %v", err)
|
||||
}
|
||||
|
||||
// Enum flag completion.
|
||||
fn, ok := cmd.GetFlagCompletionFunc("sort-by")
|
||||
if !ok {
|
||||
t.Fatal("expected completion func for --sort-by")
|
||||
}
|
||||
got, _ := fn(cmd, nil, "")
|
||||
if len(got) != 2 || got[0] != "asc" || got[1] != "desc" {
|
||||
t.Fatalf("sort-by completion = %v, want [asc desc]", got)
|
||||
}
|
||||
|
||||
// HasFormat-injected --format completion.
|
||||
fn, ok = cmd.GetFlagCompletionFunc("format")
|
||||
if !ok {
|
||||
t.Fatal("expected completion func for --format")
|
||||
}
|
||||
got, _ = fn(cmd, nil, "")
|
||||
want := []string{"json", "pretty", "table", "ndjson", "csv"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("format completion = %v, want %v", got, want)
|
||||
}
|
||||
for i, v := range want {
|
||||
if got[i] != v {
|
||||
t.Fatalf("format completion[%d] = %q, want %q", i, got[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestShortcutMount_FlagCompletionsDisabled verifies the switch actually
|
||||
// prevents the two registrations from landing in cobra's global map.
|
||||
func TestShortcutMount_FlagCompletionsDisabled(t *testing.T) {
|
||||
t.Cleanup(func() { cmdutil.SetFlagCompletionsDisabled(false) })
|
||||
cmdutil.SetFlagCompletionsDisabled(true)
|
||||
|
||||
f, _, _, _ := cmdutil.TestFactory(t, nil)
|
||||
parent := &cobra.Command{Use: "root"}
|
||||
shortcut := Shortcut{
|
||||
Service: "docs",
|
||||
Command: "+fetch",
|
||||
Description: "fetch doc",
|
||||
HasFormat: true,
|
||||
Flags: []Flag{
|
||||
{Name: "sort-by", Desc: "sort", Enum: []string{"asc", "desc"}},
|
||||
},
|
||||
Execute: func(context.Context, *RuntimeContext) error { return nil },
|
||||
}
|
||||
shortcut.Mount(parent, f)
|
||||
|
||||
cmd, _, err := parent.Find([]string{"+fetch"})
|
||||
if err != nil {
|
||||
t.Fatalf("Find() error = %v", err)
|
||||
}
|
||||
if _, ok := cmd.GetFlagCompletionFunc("sort-by"); ok {
|
||||
t.Fatal("did not expect completion func for --sort-by when disabled")
|
||||
}
|
||||
if _, ok := cmd.GetFlagCompletionFunc("format"); ok {
|
||||
t.Fatal("did not expect completion func for --format when disabled")
|
||||
}
|
||||
}
|
||||
45
shortcuts/common/runner_identity_flag_test.go
Normal file
45
shortcuts/common/runner_identity_flag_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestShortcutMount_StrictModeHidesAsFlag(t *testing.T) {
|
||||
f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{
|
||||
AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2,
|
||||
})
|
||||
parent := &cobra.Command{Use: "root"}
|
||||
shortcut := Shortcut{
|
||||
Service: "docs",
|
||||
Command: "+fetch",
|
||||
Description: "fetch doc",
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
Execute: func(context.Context, *RuntimeContext) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
shortcut.Mount(parent, f)
|
||||
cmd, _, err := parent.Find([]string{"+fetch"})
|
||||
if err != nil {
|
||||
t.Fatalf("Find() error = %v", err)
|
||||
}
|
||||
flag := cmd.Flags().Lookup("as")
|
||||
if flag == nil {
|
||||
t.Fatal("expected --as flag to be registered")
|
||||
}
|
||||
if !flag.Hidden {
|
||||
t.Fatal("expected --as flag to be hidden in strict mode")
|
||||
}
|
||||
if got := flag.DefValue; got != "bot" {
|
||||
t.Fatalf("default value = %q, want %q", got, "bot")
|
||||
}
|
||||
}
|
||||
@@ -145,10 +145,10 @@ func TestRuntimeContext_FileIO_UsesExecutionContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newTestShortcutCmd(s *Shortcut) *cobra.Command {
|
||||
func newTestShortcutCmd(s *Shortcut, f *cmdutil.Factory) *cobra.Command {
|
||||
cmd := &cobra.Command{Use: "test-shortcut"}
|
||||
cmd.SetContext(context.Background())
|
||||
registerShortcutFlags(cmd, s)
|
||||
registerShortcutFlags(cmd, f, s)
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -177,7 +177,7 @@ func TestRunShortcut_JqAndFormatConflict(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd := newTestShortcutCmd(s)
|
||||
cmd := newTestShortcutCmd(s, newTestFactory())
|
||||
cmd.Flags().Set("jq", ".data")
|
||||
cmd.Flags().Set("format", "table")
|
||||
cmd.Flags().Set("as", "bot")
|
||||
@@ -200,7 +200,7 @@ func TestRunShortcut_JqInvalidExpression(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd := newTestShortcutCmd(s)
|
||||
cmd := newTestShortcutCmd(s, newTestFactory())
|
||||
cmd.Flags().Set("jq", "invalid[")
|
||||
cmd.Flags().Set("as", "bot")
|
||||
|
||||
@@ -223,7 +223,7 @@ func TestRunShortcut_JqRuntimeError_PropagatesError(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cmd := newTestShortcutCmd(s)
|
||||
cmd := newTestShortcutCmd(s, newTestFactory())
|
||||
cmd.Flags().Set("jq", ".foo | invalid_func_xyz")
|
||||
cmd.Flags().Set("as", "bot")
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
@@ -35,7 +36,7 @@ var fileViewMap = map[string]int{
|
||||
var DocMediaInsert = common.Shortcut{
|
||||
Service: "docs",
|
||||
Command: "+media-insert",
|
||||
Description: "Insert a local image or file at the end of a Lark document (4-step orchestration + auto-rollback)",
|
||||
Description: "Insert a local image or file into a Lark document (4-step orchestration + auto-rollback); appends to end by default, or inserts relative to a text selection with --selection-with-ellipsis",
|
||||
Risk: "write",
|
||||
Scopes: []string{"docs:document.media:upload", "docx:document:write_only", "docx:document:readonly"},
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
@@ -45,6 +46,8 @@ var DocMediaInsert = common.Shortcut{
|
||||
{Name: "type", Default: "image", Desc: "type: image | file"},
|
||||
{Name: "align", Desc: "alignment: left | center | right"},
|
||||
{Name: "caption", Desc: "image caption text"},
|
||||
{Name: "selection-with-ellipsis", Desc: "plain text (or 'start...end' to disambiguate) matching the target block's content. Media is inserted at the top-level ancestor of the matched block — i.e., when the selection is inside a callout, table cell, or nested list, media lands outside that container, not inside it. Pass 'start...end' (a unique prefix and suffix separated by '...') when the plain text appears in more than one block"},
|
||||
{Name: "before", Type: "bool", Desc: "insert before the matched block instead of after (requires --selection-with-ellipsis)"},
|
||||
{Name: "file-view", Desc: "file block rendering: card (default) | preview | inline; only applies when --type=file. preview renders audio/video as an inline player"},
|
||||
},
|
||||
Validate: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
@@ -55,6 +58,18 @@ var DocMediaInsert = common.Shortcut{
|
||||
if docRef.Kind == "doc" {
|
||||
return output.ErrValidation("docs +media-insert only supports docx documents; use a docx token/URL or a wiki URL that resolves to docx")
|
||||
}
|
||||
rawSelection := runtime.Str("selection-with-ellipsis")
|
||||
trimmedSelection := strings.TrimSpace(rawSelection)
|
||||
// Explicitly reject a flag that was supplied but blank: runtime.Str cannot
|
||||
// distinguish "omitted" from "provided as empty/whitespace", and a silent
|
||||
// trim-to-empty would make +media-insert fall back to append-mode and
|
||||
// write at the wrong location.
|
||||
if rawSelection != "" && trimmedSelection == "" {
|
||||
return output.ErrValidation("--selection-with-ellipsis must not be blank or whitespace-only")
|
||||
}
|
||||
if runtime.Bool("before") && trimmedSelection == "" {
|
||||
return output.ErrValidation("--before requires --selection-with-ellipsis")
|
||||
}
|
||||
if view := runtime.Str("file-view"); view != "" {
|
||||
if _, ok := fileViewMap[view]; !ok {
|
||||
return output.ErrValidation("invalid --file-view value %q, expected one of: card | preview | inline", view)
|
||||
@@ -76,30 +91,71 @@ var DocMediaInsert = common.Shortcut{
|
||||
filePath := runtime.Str("file")
|
||||
mediaType := runtime.Str("type")
|
||||
caption := runtime.Str("caption")
|
||||
selection := strings.TrimSpace(runtime.Str("selection-with-ellipsis"))
|
||||
hasSelection := selection != ""
|
||||
fileViewType := fileViewMap[runtime.Str("file-view")]
|
||||
|
||||
parentType := parentTypeForMediaType(mediaType)
|
||||
createBlockData := buildCreateBlockData(mediaType, 0, fileViewType)
|
||||
createBlockData["index"] = "<children_len>"
|
||||
if hasSelection {
|
||||
createBlockData["index"] = "<locate_index>"
|
||||
} else {
|
||||
createBlockData["index"] = "<children_len>"
|
||||
}
|
||||
batchUpdateData := buildBatchUpdateData("<new_block_id>", mediaType, "<file_token>", runtime.Str("align"), caption)
|
||||
|
||||
d := common.NewDryRunAPI()
|
||||
totalSteps := 4
|
||||
if docRef.Kind == "wiki" {
|
||||
totalSteps++
|
||||
}
|
||||
if hasSelection {
|
||||
totalSteps++
|
||||
}
|
||||
|
||||
positionLabel := map[bool]string{true: "before", false: "after"}[runtime.Bool("before")]
|
||||
|
||||
if docRef.Kind == "wiki" {
|
||||
documentID = "<resolved_docx_token>"
|
||||
stepBase = 2
|
||||
d.Desc("5-step orchestration: resolve wiki → query root → create block → upload file → bind to block (auto-rollback on failure)").
|
||||
d.Desc(fmt.Sprintf("%d-step orchestration: resolve wiki → query root →%s create block → upload file → bind to block (auto-rollback on failure)",
|
||||
totalSteps, map[bool]string{true: " locate-doc →", false: ""}[hasSelection])).
|
||||
GET("/open-apis/wiki/v2/spaces/get_node").
|
||||
Desc("[1] Resolve wiki node to docx document").
|
||||
Params(map[string]interface{}{"token": docRef.Token})
|
||||
} else {
|
||||
d.Desc("4-step orchestration: query root → create block → upload file → bind to block (auto-rollback on failure)")
|
||||
d.Desc(fmt.Sprintf("%d-step orchestration: query root →%s create block → upload file → bind to block (auto-rollback on failure)",
|
||||
totalSteps, map[bool]string{true: " locate-doc →", false: ""}[hasSelection]))
|
||||
}
|
||||
|
||||
d.
|
||||
GET("/open-apis/docx/v1/documents/:document_id/blocks/:document_id").
|
||||
Desc(fmt.Sprintf("[%d] Get document root block", stepBase)).
|
||||
Desc(fmt.Sprintf("[%d] Get document root block", stepBase))
|
||||
|
||||
if hasSelection {
|
||||
mcpEndpoint := common.MCPEndpoint(runtime.Config.Brand)
|
||||
mcpArgs := map[string]interface{}{
|
||||
"doc_id": documentID,
|
||||
"selection_with_ellipsis": selection,
|
||||
"limit": 1,
|
||||
}
|
||||
d.POST(mcpEndpoint).
|
||||
Desc(fmt.Sprintf("[%d] MCP locate-doc: find block matching selection (%s)", stepBase+1, positionLabel)).
|
||||
Body(map[string]interface{}{
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "locate-doc",
|
||||
"arguments": mcpArgs,
|
||||
},
|
||||
}).
|
||||
Set("mcp_tool", "locate-doc").
|
||||
Set("args", mcpArgs)
|
||||
stepBase++
|
||||
}
|
||||
|
||||
d.
|
||||
POST("/open-apis/docx/v1/documents/:document_id/blocks/:document_id/children").
|
||||
Desc(fmt.Sprintf("[%d] Create empty block at document end", stepBase+1)).
|
||||
Desc(fmt.Sprintf("[%d] Create empty block at target position", stepBase+1)).
|
||||
Body(createBlockData)
|
||||
appendDocMediaInsertUploadDryRun(d, runtime.FileIO(), filePath, parentType, stepBase+2)
|
||||
d.PATCH("/open-apis/docx/v1/documents/:document_id/blocks/batch_update").
|
||||
@@ -144,13 +200,31 @@ var DocMediaInsert = common.Shortcut{
|
||||
return err
|
||||
}
|
||||
|
||||
parentBlockID, insertIndex, err := extractAppendTarget(rootData, documentID)
|
||||
parentBlockID, insertIndex, rootChildren, err := extractAppendTarget(rootData, documentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Root block ready: %s (%d children)\n", parentBlockID, insertIndex)
|
||||
|
||||
// Step 2: Create an empty block at the end of the document
|
||||
selection := strings.TrimSpace(runtime.Str("selection-with-ellipsis"))
|
||||
if selection != "" {
|
||||
before := runtime.Bool("before")
|
||||
// Redact the selection when logging — it is copied verbatim from
|
||||
// document content and may contain confidential text.
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Locating block matching selection (%s)\n", redactSelection(selection))
|
||||
idx, err := locateInsertIndex(runtime, documentID, selection, rootChildren, before)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
insertIndex = idx
|
||||
posLabel := "after"
|
||||
if before {
|
||||
posLabel = "before"
|
||||
}
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "locate-doc matched: inserting %s at index %d\n", posLabel, insertIndex)
|
||||
}
|
||||
|
||||
// Step 2: Create an empty block at the target position
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Creating block at index %d\n", insertIndex)
|
||||
|
||||
createData, err := runtime.CallAPI("POST",
|
||||
@@ -224,6 +298,20 @@ func blockTypeForMediaType(mediaType string) int {
|
||||
return 27
|
||||
}
|
||||
|
||||
// redactSelection summarizes --selection-with-ellipsis values for logging and
|
||||
// error messages without echoing raw document text. Returns the rune count and,
|
||||
// for longer strings, a short prefix so operators can still identify which
|
||||
// selection failed without leaking confidential content into terminals or CI
|
||||
// logs.
|
||||
func redactSelection(s string) string {
|
||||
const prefixRunes = 8
|
||||
runes := []rune(s)
|
||||
if len(runes) <= prefixRunes {
|
||||
return fmt.Sprintf("%d chars", len(runes))
|
||||
}
|
||||
return fmt.Sprintf("%q… %d chars total", string(runes[:prefixRunes]), len(runes))
|
||||
}
|
||||
|
||||
func parentTypeForMediaType(mediaType string) string {
|
||||
if mediaType == "file" {
|
||||
return "docx_file"
|
||||
@@ -332,19 +420,150 @@ func buildBatchUpdateData(blockID, mediaType, fileToken, alignStr, caption strin
|
||||
}
|
||||
}
|
||||
|
||||
func extractAppendTarget(rootData map[string]interface{}, fallbackBlockID string) (string, int, error) {
|
||||
func extractAppendTarget(rootData map[string]interface{}, fallbackBlockID string) (parentBlockID string, insertIndex int, children []interface{}, err error) {
|
||||
block, _ := rootData["block"].(map[string]interface{})
|
||||
if len(block) == 0 {
|
||||
return "", 0, output.Errorf(output.ExitAPI, "api_error", "failed to query document root block")
|
||||
return "", 0, nil, output.Errorf(output.ExitAPI, "api_error", "failed to query document root block")
|
||||
}
|
||||
|
||||
parentBlockID := fallbackBlockID
|
||||
parentBlockID = fallbackBlockID
|
||||
if blockID, _ := block["block_id"].(string); blockID != "" {
|
||||
parentBlockID = blockID
|
||||
}
|
||||
|
||||
children, _ := block["children"].([]interface{})
|
||||
return parentBlockID, len(children), nil
|
||||
children, _ = block["children"].([]interface{})
|
||||
return parentBlockID, len(children), children, nil
|
||||
}
|
||||
|
||||
// locateInsertIndex uses the MCP locate-doc tool to find the root-level index
|
||||
// at which to insert relative to the block matching selection. It walks the
|
||||
// parent_id chain (using single-block GET calls when needed) to resolve nested
|
||||
// blocks to their top-level ancestor in rootChildren.
|
||||
func locateInsertIndex(runtime *common.RuntimeContext, documentID string, selection string, rootChildren []interface{}, before bool) (int, error) {
|
||||
// Ask for 2 matches so we can warn when the selection is ambiguous. locate-doc
|
||||
// orders matches by document position, so matches[0] is still deterministic.
|
||||
args := map[string]interface{}{
|
||||
"doc_id": documentID,
|
||||
"selection_with_ellipsis": selection,
|
||||
"limit": 2,
|
||||
}
|
||||
result, err := common.CallMCPTool(runtime, "locate-doc", args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
matches := common.GetSlice(result, "matches")
|
||||
if len(matches) == 0 {
|
||||
return 0, output.ErrWithHint(
|
||||
output.ExitValidation,
|
||||
"no_match",
|
||||
fmt.Sprintf("locate-doc did not find any block matching selection (%s)", redactSelection(selection)),
|
||||
"check spelling or use 'start...end' syntax to narrow the selection",
|
||||
)
|
||||
}
|
||||
if len(matches) > 1 {
|
||||
// Silently picking the first match surprises users whose selection appears
|
||||
// in more than one block (e.g. the same phrase in a title and a paragraph).
|
||||
// Surface that another match exists and point at the 'start...end' disambiguator.
|
||||
fmt.Fprintf(runtime.IO().ErrOut,
|
||||
"warning: selection (%s) matched more than one block; inserting relative to the first. "+
|
||||
"Pass --selection-with-ellipsis 'start...end' to narrow.\n",
|
||||
redactSelection(selection))
|
||||
}
|
||||
|
||||
matchMap, _ := matches[0].(map[string]interface{})
|
||||
anchorBlockID := common.GetString(matchMap, "anchor_block_id")
|
||||
if anchorBlockID == "" {
|
||||
// Fall back to first block entry if anchor_block_id is absent.
|
||||
blocks := common.GetSlice(matchMap, "blocks")
|
||||
if len(blocks) > 0 {
|
||||
if b, ok := blocks[0].(map[string]interface{}); ok {
|
||||
anchorBlockID = common.GetString(b, "block_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
if anchorBlockID == "" {
|
||||
return 0, output.Errorf(output.ExitAPI, "api_error", "locate-doc response missing anchor_block_id")
|
||||
}
|
||||
parentBlockID := common.GetString(matchMap, "parent_block_id")
|
||||
|
||||
// Build root children set for O(1) lookup.
|
||||
rootSet := make(map[string]int, len(rootChildren))
|
||||
for i, c := range rootChildren {
|
||||
if id, ok := c.(string); ok {
|
||||
rootSet[id] = i
|
||||
}
|
||||
}
|
||||
|
||||
// Walk up the parent chain to the top-level ancestor in rootChildren. This
|
||||
// is serial by nature: each level's parent_id is only known after the
|
||||
// previous level's GET /blocks/{id} response arrives, so the calls cannot
|
||||
// be batched or parallelised.
|
||||
//
|
||||
// visited is the real cycle guard — it stops an A→B→A parent-id loop (seen
|
||||
// on malformed API responses) after one lap. maxDepth is belt-and-suspenders
|
||||
// in case both visited tracking and parent_id sanity simultaneously break;
|
||||
// 32 comfortably exceeds the deepest real docx nesting (~6–8 levels for
|
||||
// quote/callout/list combinations) without letting a bug run unbounded.
|
||||
cur := anchorBlockID
|
||||
nextParent := parentBlockID
|
||||
visited := map[string]bool{}
|
||||
const maxDepth = 32
|
||||
walkDepth := 0
|
||||
for depth := 0; depth < maxDepth; depth++ {
|
||||
if visited[cur] {
|
||||
break
|
||||
}
|
||||
visited[cur] = true
|
||||
|
||||
if idx, ok := rootSet[cur]; ok {
|
||||
if walkDepth > 0 {
|
||||
// The anchor was nested inside a callout / table cell / list and
|
||||
// got resolved to its top-level ancestor. Surface this so users
|
||||
// don't misread "insert before 'X'" as "insert right next to X"
|
||||
// when X is buried several levels deep.
|
||||
posLabel := "after"
|
||||
if before {
|
||||
posLabel = "before"
|
||||
}
|
||||
fmt.Fprintf(runtime.IO().ErrOut,
|
||||
"note: selection (%s) was nested %d level(s) deep; inserting %s its top-level ancestor at index %d\n",
|
||||
redactSelection(selection), walkDepth, posLabel, idx)
|
||||
}
|
||||
if before {
|
||||
return idx, nil
|
||||
}
|
||||
return idx + 1, nil
|
||||
}
|
||||
|
||||
// Advance: use the parent hint we already have, or fetch from API.
|
||||
parent := nextParent
|
||||
nextParent = "" // clear hint after first use
|
||||
if parent == "" || parent == cur {
|
||||
// Need to fetch this block to find its parent.
|
||||
data, err := runtime.CallAPI("GET",
|
||||
fmt.Sprintf("/open-apis/docx/v1/documents/%s/blocks/%s",
|
||||
validate.EncodePathSegment(documentID), validate.EncodePathSegment(cur)),
|
||||
nil, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
block := common.GetMap(data, "block")
|
||||
parent = common.GetString(block, "parent_id")
|
||||
}
|
||||
if parent == "" || parent == cur {
|
||||
break
|
||||
}
|
||||
cur = parent
|
||||
walkDepth++
|
||||
}
|
||||
|
||||
return 0, output.ErrWithHint(
|
||||
output.ExitValidation,
|
||||
"block_not_reachable",
|
||||
fmt.Sprintf("block matching selection (%s) is not reachable from document root", redactSelection(selection)),
|
||||
"try a top-level heading or paragraph as the selection",
|
||||
)
|
||||
}
|
||||
|
||||
func extractCreatedBlockTargets(createData map[string]interface{}, mediaType string) (blockID, uploadParentNode, replaceBlockID string) {
|
||||
|
||||
@@ -5,12 +5,15 @@ package doc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
@@ -222,7 +225,7 @@ func TestExtractAppendTargetUsesRootChildrenCount(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
blockID, index, err := extractAppendTarget(rootData, "fallback")
|
||||
blockID, index, children, err := extractAppendTarget(rootData, "fallback")
|
||||
if err != nil {
|
||||
t.Fatalf("extractAppendTarget() unexpected error: %v", err)
|
||||
}
|
||||
@@ -232,6 +235,365 @@ func TestExtractAppendTargetUsesRootChildrenCount(t *testing.T) {
|
||||
if index != 3 {
|
||||
t.Fatalf("extractAppendTarget() index = %d, want 3", index)
|
||||
}
|
||||
if len(children) != 3 {
|
||||
t.Fatalf("extractAppendTarget() children len = %d, want 3", len(children))
|
||||
}
|
||||
}
|
||||
|
||||
// buildLocateDocMCPResponse builds a JSON-RPC 2.0 response for a locate-doc MCP call.
|
||||
func buildLocateDocMCPResponse(matches []map[string]interface{}) map[string]interface{} {
|
||||
resultJSON, _ := json.Marshal(map[string]interface{}{"matches": matches})
|
||||
return map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "test-id",
|
||||
"result": map[string]interface{}{
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": string(resultJSON),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// registerInsertWithSelectionStubs wires the minimal stub set for the
|
||||
// --selection-with-ellipsis happy path. Returns the create-block stub so
|
||||
// callers can inspect the request body (e.g. to verify the computed index).
|
||||
func registerInsertWithSelectionStubs(reg interface {
|
||||
Register(*httpmock.Stub)
|
||||
}, docID, anchorBlockID, parentBlockID string, rootChildren []interface{}) *httpmock.Stub {
|
||||
// Root block
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": docID,
|
||||
"children": rootChildren,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
// MCP locate-doc
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "mcp.feishu.cn/mcp",
|
||||
Body: buildLocateDocMCPResponse([]map[string]interface{}{
|
||||
{"anchor_block_id": anchorBlockID, "parent_block_id": parentBlockID},
|
||||
}),
|
||||
})
|
||||
// Create block — returned so the test can inspect index in CapturedBody.
|
||||
createStub := &httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID + "/children",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"children": []interface{}{
|
||||
map[string]interface{}{"block_id": "blk_new", "block_type": 27, "image": map[string]interface{}{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(createStub)
|
||||
// Upload
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/medias/upload_all",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{"file_token": "ftok_test"},
|
||||
},
|
||||
})
|
||||
// Batch update
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "PATCH",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/batch_update",
|
||||
Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}},
|
||||
})
|
||||
return createStub
|
||||
}
|
||||
|
||||
// assertCreateBlockIndex decodes the create-block request body and asserts the
|
||||
// `index` field equals want. Fails the test if the body is missing or wrong.
|
||||
func assertCreateBlockIndex(t *testing.T, stub *httpmock.Stub, want int) {
|
||||
t.Helper()
|
||||
if stub.CapturedBody == nil {
|
||||
t.Fatalf("create-block stub captured no body")
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(stub.CapturedBody, &body); err != nil {
|
||||
t.Fatalf("decode create-block body: %v (raw: %s)", err, stub.CapturedBody)
|
||||
}
|
||||
got, _ := body["index"].(float64)
|
||||
if int(got) != want {
|
||||
t.Fatalf("create-block index = %v, want %d (body: %s)", body["index"], want, stub.CapturedBody)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexAfterModeViaExecute verifies that
|
||||
// --selection-with-ellipsis (default after-mode) places the new block
|
||||
// immediately after the matched root-level block. Uses three root children so
|
||||
// the after-index (2) differs from what --before would produce (1), and
|
||||
// inspects the create-block request body to prove the computed index actually
|
||||
// reaches the /children API.
|
||||
func TestLocateInsertIndexAfterModeViaExecute(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-after-app"))
|
||||
createStub := registerInsertWithSelectionStubs(reg, "doxcnSEL", "blk_b", "doxcnSEL",
|
||||
[]interface{}{"blk_a", "blk_b", "blk_c"})
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", "doxcnSEL",
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "Introduction",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error: %v", err)
|
||||
}
|
||||
// after blk_b (index 1) → insert at index 2, between blk_b and blk_c.
|
||||
assertCreateBlockIndex(t, createStub, 2)
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexBeforeModeViaExecute verifies that --before inserts
|
||||
// before the matched root-level block. Pairs with the after-mode test above:
|
||||
// same fixture, same anchor, but --before should flip the index from 2 to 1.
|
||||
// A regression that ignored --before would still pass the success check alone,
|
||||
// so we assert the create-block body explicitly.
|
||||
func TestLocateInsertIndexBeforeModeViaExecute(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-before-app"))
|
||||
createStub := registerInsertWithSelectionStubs(reg, "doxcnSEL2", "blk_b", "doxcnSEL2",
|
||||
[]interface{}{"blk_a", "blk_b", "blk_c"})
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", "doxcnSEL2",
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "Architecture",
|
||||
"--before",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error: %v", err)
|
||||
}
|
||||
// before blk_b (index 1) → insert at index 1, between blk_a and blk_b.
|
||||
assertCreateBlockIndex(t, createStub, 1)
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexNestedBlockViaExecute verifies that a deeply-nested
|
||||
// anchor (2+ levels below root) walks up through an intermediate block via
|
||||
// the GET /blocks/{id} API to find the root-level ancestor. This exercises
|
||||
// the fallback ancestor-walk path in locateInsertIndex — the parent_block_id
|
||||
// hint from locate-doc is only good for one level, so deeper nesting must hit
|
||||
// the block-fetch loop.
|
||||
func TestLocateInsertIndexNestedBlockViaExecute(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-nested-app"))
|
||||
|
||||
docID := "doxcnNESTED"
|
||||
// Root children: blk_section (index 0), blk_other (index 1).
|
||||
// Anchor blk_grandchild is nested two levels deep:
|
||||
// root → blk_section → blk_section_child → blk_grandchild
|
||||
// locate-doc gives us parent_block_id = blk_section_child (one level up);
|
||||
// the walk must fetch blk_section_child to discover its parent = blk_section
|
||||
// before it can land on a root child.
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": docID,
|
||||
"children": []interface{}{"blk_section", "blk_other"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "mcp.feishu.cn/mcp",
|
||||
Body: buildLocateDocMCPResponse([]map[string]interface{}{
|
||||
{"anchor_block_id": "blk_grandchild", "parent_block_id": "blk_section_child"},
|
||||
}),
|
||||
})
|
||||
// Intermediate block lookup — this is the key step that exercises the
|
||||
// fallback walk. Without this stub the test would fail.
|
||||
intermediateStub := &httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/blk_section_child",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": "blk_section_child",
|
||||
"parent_id": "blk_section",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(intermediateStub)
|
||||
createStub := &httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID + "/children",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"children": []interface{}{
|
||||
map[string]interface{}{"block_id": "blk_new", "block_type": 27, "image": map[string]interface{}{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(createStub)
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/medias/upload_all",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{"file_token": "ftok_nested"},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "PATCH",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/batch_update",
|
||||
Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}},
|
||||
})
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", docID,
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "nested content",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error: %v", err)
|
||||
}
|
||||
// Confirm the ancestor-walk actually fired — without this assertion a
|
||||
// regression that short-circuited the walk would still pass.
|
||||
if intermediateStub.CapturedBody == nil && intermediateStub.CapturedHeaders == nil {
|
||||
t.Errorf("expected GET /blocks/blk_section_child to be invoked by the parent-walk; stub was not hit")
|
||||
}
|
||||
// after blk_section (index 0) → insert at index 1, between blk_section and blk_other.
|
||||
assertCreateBlockIndex(t, createStub, 1)
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexNoMatchReturnsError verifies that when locate-doc returns
|
||||
// no matches, Execute returns a descriptive error.
|
||||
func TestLocateInsertIndexNoMatchReturnsError(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-nomatch-app"))
|
||||
|
||||
docID := "doxcnNOMATCH"
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": docID,
|
||||
"children": []interface{}{"blk_a"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "mcp.feishu.cn/mcp",
|
||||
Body: buildLocateDocMCPResponse([]map[string]interface{}{}),
|
||||
})
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", docID,
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "nonexistent text",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected no-match error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no_match") && !strings.Contains(err.Error(), "did not find") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexDryRunIncludesMCPStep verifies that the dry-run output
|
||||
// includes a locate-doc MCP step when --selection-with-ellipsis is provided.
|
||||
func TestLocateInsertIndexDryRunIncludesMCPStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := &cobra.Command{Use: "docs +media-insert"}
|
||||
cmd.Flags().String("file", "", "")
|
||||
cmd.Flags().String("doc", "", "")
|
||||
cmd.Flags().String("type", "image", "")
|
||||
cmd.Flags().String("align", "", "")
|
||||
cmd.Flags().String("caption", "", "")
|
||||
cmd.Flags().String("selection-with-ellipsis", "", "")
|
||||
cmd.Flags().Bool("before", false, "")
|
||||
_ = cmd.Flags().Set("file", "img.png")
|
||||
_ = cmd.Flags().Set("doc", "doxcnABCDEF")
|
||||
_ = cmd.Flags().Set("selection-with-ellipsis", "Introduction")
|
||||
|
||||
rt := common.TestNewRuntimeContext(cmd, docsTestConfigWithAppID("dry-run-app"))
|
||||
dryAPI := DocMediaInsert.DryRun(context.Background(), rt)
|
||||
raw, _ := json.Marshal(dryAPI)
|
||||
|
||||
var dry struct {
|
||||
Description string `json:"description"`
|
||||
API []struct {
|
||||
Desc string `json:"desc"`
|
||||
URL string `json:"url"`
|
||||
Body map[string]interface{} `json:"body"`
|
||||
} `json:"api"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &dry); err != nil {
|
||||
t.Fatalf("decode dry-run: %v", err)
|
||||
}
|
||||
|
||||
foundMCP := false
|
||||
for _, step := range dry.API {
|
||||
if strings.Contains(step.Desc, "locate-doc") {
|
||||
foundMCP = true
|
||||
}
|
||||
}
|
||||
if !foundMCP {
|
||||
t.Fatalf("dry-run should include a locate-doc step, got: %+v", dry.API)
|
||||
}
|
||||
if !strings.Contains(dry.Description, "locate-doc") {
|
||||
t.Fatalf("dry-run description should mention 'locate-doc', got: %s", dry.Description)
|
||||
}
|
||||
|
||||
// Verify create-block step shows <locate_index> not <children_len>
|
||||
for _, step := range dry.API {
|
||||
if strings.Contains(step.URL, "/children") && step.Body != nil {
|
||||
if idx, ok := step.Body["index"]; ok {
|
||||
if idx != "<locate_index>" {
|
||||
t.Fatalf("create-block index in selection mode = %q, want <locate_index>", idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCreatedBlockTargetsForImage(t *testing.T) {
|
||||
@@ -369,3 +731,256 @@ func TestDocMediaInsertValidateFileView(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexWarnsOnMultipleMatches verifies that when locate-doc
|
||||
// returns more than one match, a warning is written to stderr pointing the user
|
||||
// at the 'start...end' disambiguation syntax. Silently picking the first match
|
||||
// of an ambiguous selection is a real UX trap — users who edit documents with
|
||||
// repeated phrases (a heading that also appears in the TOC, for example) get
|
||||
// no signal that another match existed.
|
||||
func TestLocateInsertIndexWarnsOnMultipleMatches(t *testing.T) {
|
||||
f, _, stderr, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-multi-app"))
|
||||
|
||||
docID := "doxcnMULTI"
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": docID,
|
||||
"children": []interface{}{"blk_a", "blk_b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
// Two matches — same selection appears in two different root-level blocks.
|
||||
// locate-doc orders matches by document position, so matches[0] is still
|
||||
// deterministic (blk_a) even with limit=2.
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "mcp.feishu.cn/mcp",
|
||||
Body: buildLocateDocMCPResponse([]map[string]interface{}{
|
||||
{"anchor_block_id": "blk_a", "parent_block_id": docID},
|
||||
{"anchor_block_id": "blk_b", "parent_block_id": docID},
|
||||
}),
|
||||
})
|
||||
createStub := &httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID + "/children",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"children": []interface{}{
|
||||
map[string]interface{}{"block_id": "blk_new", "block_type": 27, "image": map[string]interface{}{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(createStub)
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/medias/upload_all",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{"file_token": "ftok_multi"},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "PATCH",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/batch_update",
|
||||
Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}},
|
||||
})
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", docID,
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "Repeated phrase",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error: %v", err)
|
||||
}
|
||||
|
||||
// Warning should name the ambiguity and point at 'start...end'.
|
||||
stderrOut := stderr.String()
|
||||
if !strings.Contains(stderrOut, "matched more than one block") {
|
||||
t.Errorf("stderr missing multi-match warning; got:\n%s", stderrOut)
|
||||
}
|
||||
if !strings.Contains(stderrOut, "start...end") {
|
||||
t.Errorf("stderr missing 'start...end' disambiguation hint; got:\n%s", stderrOut)
|
||||
}
|
||||
// Should still insert at the first match (blk_a at index 0) → after ⇒ 1.
|
||||
assertCreateBlockIndex(t, createStub, 1)
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexLogsNestedAnchor verifies that when the matched block is
|
||||
// nested (not a direct root child), a note is written to stderr explaining that
|
||||
// the media lands at the top-level ancestor. This protects users from being
|
||||
// surprised when selecting text inside a callout or table cell and seeing the
|
||||
// image appear outside that container.
|
||||
func TestLocateInsertIndexLogsNestedAnchor(t *testing.T) {
|
||||
f, _, stderr, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-nested-log-app"))
|
||||
|
||||
docID := "doxcnNESTEDLOG"
|
||||
// Same shape as TestLocateInsertIndexNestedBlockViaExecute: anchor is two
|
||||
// levels below root, so walkDepth == 2 when we hit the root ancestor.
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": docID,
|
||||
"children": []interface{}{"blk_section", "blk_other"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "mcp.feishu.cn/mcp",
|
||||
Body: buildLocateDocMCPResponse([]map[string]interface{}{
|
||||
{"anchor_block_id": "blk_grandchild", "parent_block_id": "blk_section_child"},
|
||||
}),
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/blk_section_child",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": "blk_section_child",
|
||||
"parent_id": "blk_section",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID + "/children",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"children": []interface{}{
|
||||
map[string]interface{}{"block_id": "blk_new", "block_type": 27, "image": map[string]interface{}{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/medias/upload_all",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{"file_token": "ftok_nested_log"},
|
||||
},
|
||||
})
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "PATCH",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/batch_update",
|
||||
Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}},
|
||||
})
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", docID,
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "nested content",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error: %v", err)
|
||||
}
|
||||
|
||||
stderrOut := stderr.String()
|
||||
if !strings.Contains(stderrOut, "nested") || !strings.Contains(stderrOut, "top-level ancestor") {
|
||||
t.Errorf("stderr missing nested-anchor note; got:\n%s", stderrOut)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocateInsertIndexCycleDetection verifies that a malformed parent chain
|
||||
// (blk_x.parent = blk_y and blk_y.parent = blk_x, neither reachable from root)
|
||||
// does not spin the locate-doc walk forever. The `visited` map must break the
|
||||
// cycle, and the user must see the "not reachable from document root" error
|
||||
// rather than the process hanging. Without this test, a regression that broke
|
||||
// cycle protection would only surface in production with a stalled CLI.
|
||||
func TestLocateInsertIndexCycleDetection(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("locate-cycle-app"))
|
||||
|
||||
docID := "doxcnCYCLE"
|
||||
// Root has unrelated children — neither blk_x nor blk_y reach root.
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/" + docID,
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": docID,
|
||||
"children": []interface{}{"blk_unrelated_a", "blk_unrelated_b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
// locate-doc hints parent_block_id = blk_y for anchor blk_x (first hop consumed).
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "mcp.feishu.cn/mcp",
|
||||
Body: buildLocateDocMCPResponse([]map[string]interface{}{
|
||||
{"anchor_block_id": "blk_x", "parent_block_id": "blk_y"},
|
||||
}),
|
||||
})
|
||||
// blk_y claims blk_x as parent — closes the cycle. The walk must land here
|
||||
// exactly once before visited[blk_x] triggers a break.
|
||||
blkYStub := &httpmock.Stub{
|
||||
Method: "GET",
|
||||
URL: "/open-apis/docx/v1/documents/" + docID + "/blocks/blk_y",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "ok",
|
||||
"data": map[string]interface{}{
|
||||
"block": map[string]interface{}{
|
||||
"block_id": "blk_y",
|
||||
"parent_id": "blk_x",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg.Register(blkYStub)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
withDocsWorkingDir(t, tmpDir)
|
||||
writeSizedDocTestFile(t, "img.png", 100)
|
||||
|
||||
err := mountAndRunDocs(t, DocMediaInsert, []string{
|
||||
"+media-insert",
|
||||
"--doc", docID,
|
||||
"--file", "img.png",
|
||||
"--selection-with-ellipsis", "cyclic anchor",
|
||||
"--as", "bot",
|
||||
}, f, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected 'block_not_reachable' error from cyclic parent chain; got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not reachable") && !strings.Contains(err.Error(), "block_not_reachable") {
|
||||
t.Fatalf("unexpected error — want cycle-bounded 'not reachable', got: %v", err)
|
||||
}
|
||||
// blk_y should be fetched exactly once. Registering just one stub for it
|
||||
// already enforces an upper bound (httpmock errors on extra calls), so if
|
||||
// the walk looped more than once the test harness would fail differently.
|
||||
if blkYStub.CapturedHeaders == nil && blkYStub.CapturedBody == nil {
|
||||
t.Errorf("expected the walk to fetch blk_y once; stub was not hit")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ func TestDocMediaUploadDryRunUsesMultipartForLargeFile(t *testing.T) {
|
||||
t.Fatalf("set --parent-node: %v", err)
|
||||
}
|
||||
|
||||
dry := decodeDocDryRun(t, MediaUpload.DryRun(context.Background(), common.TestNewRuntimeContext(cmd, nil)))
|
||||
dry := decodeDocDryRun(t, DocMediaUpload.DryRun(context.Background(), common.TestNewRuntimeContext(cmd, nil)))
|
||||
if dry.Description != "chunked media upload (files > 20MB)" {
|
||||
t.Fatalf("dry-run description = %q", dry.Description)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
var MediaUpload = common.Shortcut{
|
||||
var DocMediaUpload = common.Shortcut{
|
||||
Service: "docs",
|
||||
Command: "+media-upload",
|
||||
Description: "Upload media file (image/attachment) to a document block",
|
||||
@@ -22,8 +22,8 @@ var MediaUpload = common.Shortcut{
|
||||
AuthTypes: []string{"user", "bot"},
|
||||
Flags: []common.Flag{
|
||||
{Name: "file", Desc: "local file path (files > 20MB use multipart upload automatically)", Required: true},
|
||||
{Name: "parent-type", Desc: "parent type: docx_image | docx_file", Required: true},
|
||||
{Name: "parent-node", Desc: "parent node ID (block_id)", Required: true},
|
||||
{Name: "parent-type", Desc: "parent type: docx_image | docx_file | whiteboard", Required: true},
|
||||
{Name: "parent-node", Desc: "parent node ID (block_id for docx, board_token for whiteboard)", Required: true},
|
||||
{Name: "doc-id", Desc: "document ID (for drive_route_token)"},
|
||||
},
|
||||
DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
|
||||
@@ -5,6 +5,7 @@ package doc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
@@ -62,6 +63,9 @@ var DocsUpdate = common.Shortcut{
|
||||
if needsSelection[mode] && selEllipsis == "" && selTitle == "" {
|
||||
return common.FlagErrorf("--%s mode requires --selection-with-ellipsis or --selection-by-title", mode)
|
||||
}
|
||||
if err := validateSelectionByTitle(selTitle); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
@@ -89,12 +93,22 @@ var DocsUpdate = common.Shortcut{
|
||||
Set("mcp_tool", "update-doc").Set("args", args)
|
||||
},
|
||||
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
mode := runtime.Str("mode")
|
||||
markdown := runtime.Str("markdown")
|
||||
|
||||
// Static semantic checks run before the MCP call so users see
|
||||
// warnings even if the subsequent request fails. They never block
|
||||
// execution — the update still proceeds.
|
||||
for _, w := range docsUpdateWarnings(mode, markdown) {
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "warning: %s\n", w)
|
||||
}
|
||||
|
||||
args := map[string]interface{}{
|
||||
"doc_id": runtime.Str("doc"),
|
||||
"mode": runtime.Str("mode"),
|
||||
"mode": mode,
|
||||
}
|
||||
if v := runtime.Str("markdown"); v != "" {
|
||||
args["markdown"] = v
|
||||
if markdown != "" {
|
||||
args["markdown"] = markdown
|
||||
}
|
||||
if v := runtime.Str("selection-with-ellipsis"); v != "" {
|
||||
args["selection_with_ellipsis"] = v
|
||||
@@ -156,3 +170,17 @@ func normalizeBoardTokens(raw interface{}) []string {
|
||||
return []string{}
|
||||
}
|
||||
}
|
||||
|
||||
func validateSelectionByTitle(title string) error {
|
||||
if title == "" {
|
||||
return nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(title)
|
||||
if strings.Contains(trimmed, "\n") || strings.Contains(trimmed, "\r") {
|
||||
return common.FlagErrorf("--selection-by-title must be a single heading line (for example: '## Section')")
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
return nil
|
||||
}
|
||||
return common.FlagErrorf("--selection-by-title must include markdown heading prefix '#'. Example: --selection-by-title '## Section'")
|
||||
}
|
||||
|
||||
281
shortcuts/doc/docs_update_check.go
Normal file
281
shortcuts/doc/docs_update_check.go
Normal file
@@ -0,0 +1,281 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package doc
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// docsUpdateWarnings returns a list of human-readable warnings for a
|
||||
// `docs +update` invocation based on static analysis of the mode and
|
||||
// Markdown payload. The warnings describe CLI/MCP contract edges that
|
||||
// commonly surprise users; the update is still executed — callers
|
||||
// decide whether to stop at a warning.
|
||||
//
|
||||
// Both checks ignore fenced code blocks (```…``` and ~~~…~~~, with up
|
||||
// to 3 leading spaces per CommonMark §4.5), inline code spans, and
|
||||
// backslash-escaped emphasis markers so that literal Markdown content
|
||||
// embedded in code samples or escaped prose does not produce false
|
||||
// positives.
|
||||
//
|
||||
// Warnings emitted (current):
|
||||
//
|
||||
// 1. replace_* modes do not split blocks. A Markdown payload containing
|
||||
// a blank line (\n\n) in prose implies the caller expects multiple
|
||||
// paragraphs, but replace_range / replace_all only swap in-block
|
||||
// text. The resulting block will contain the blank line as literal
|
||||
// text and appear as a single paragraph in the UI.
|
||||
//
|
||||
// 2. Lark does not round-trip bold+italic. Six shapes are detected:
|
||||
// ***text*** ___text___
|
||||
// **_text_** __*text*__
|
||||
// _**text**_ *__text__*
|
||||
// Lark stores only one of the two emphases (usually italic), silently
|
||||
// dropping the other. The user wanted both; they will get one.
|
||||
func docsUpdateWarnings(mode, markdown string) []string {
|
||||
var warnings []string
|
||||
if w := checkDocsUpdateReplaceMultilineMarkdown(mode, markdown); w != "" {
|
||||
warnings = append(warnings, w)
|
||||
}
|
||||
if w := checkDocsUpdateBoldItalic(markdown); w != "" {
|
||||
warnings = append(warnings, w)
|
||||
}
|
||||
return warnings
|
||||
}
|
||||
|
||||
// checkDocsUpdateReplaceMultilineMarkdown flags markdown that contains a
|
||||
// blank-line paragraph break outside fenced code blocks under a replace_*
|
||||
// mode. Blank lines inside code fences are literal content and don't
|
||||
// imply paragraph semantics, so they are deliberately ignored.
|
||||
func checkDocsUpdateReplaceMultilineMarkdown(mode, markdown string) string {
|
||||
if mode != "replace_range" && mode != "replace_all" {
|
||||
return ""
|
||||
}
|
||||
// A CR/LF-robust check: both "\n\n" and "\r\n\r\n" count as paragraph
|
||||
// separators. We normalize line endings once before detection.
|
||||
normalized := strings.ReplaceAll(markdown, "\r\n", "\n")
|
||||
if !proseHasBlankLine(normalized) {
|
||||
return ""
|
||||
}
|
||||
return "--mode=" + mode + " does not split a block into multiple paragraphs; " +
|
||||
"the blank line in --markdown will render as literal text. " +
|
||||
"For multiple paragraphs, use --mode=delete_range followed by --mode=insert_before."
|
||||
}
|
||||
|
||||
// combinedEmphasisPatterns holds the six documented combined-emphasis shapes
|
||||
// that Lark downgrades to a single emphasis. Each entry pairs a regex with a
|
||||
// short shape label for the warning message. The two forms per shape (with
|
||||
// and without `[^…]*?`) are there because the lazy quantifier needs at least
|
||||
// one non-delimiter character to match; single-rune payloads (e.g. `***X***`)
|
||||
// take the second alternation.
|
||||
var combinedEmphasisPatterns = []struct {
|
||||
shape string
|
||||
re *regexp.Regexp
|
||||
}{
|
||||
// Bold+italic with a single delimiter char.
|
||||
{"***text***", regexp.MustCompile(`\*\*\*\S[^*]*?\S\*\*\*|\*\*\*\S\*\*\*`)},
|
||||
{"___text___", regexp.MustCompile(`___\S[^_]*?\S___|___\S___`)},
|
||||
|
||||
// Bold wrapping italic (asterisk outside).
|
||||
{"**_text_**", regexp.MustCompile(`\*\*_\S[^_*]*?\S_\*\*|\*\*_\S_\*\*`)},
|
||||
{"__*text*__", regexp.MustCompile(`__\*\S[^_*]*?\S\*__|__\*\S\*__`)},
|
||||
|
||||
// Italic wrapping bold (asterisk inside).
|
||||
{"_**text**_", regexp.MustCompile(`_\*\*\S[^_*]*?\S\*\*_|_\*\*\S\*\*_`)},
|
||||
{"*__text__*", regexp.MustCompile(`\*__\S[^_*]*?\S__\*|\*__\S__\*`)},
|
||||
}
|
||||
|
||||
// checkDocsUpdateBoldItalic flags Markdown emphases that attempt to
|
||||
// combine bold and italic in a way Lark cannot represent. Fenced code
|
||||
// blocks, inline code spans, and backslash-escaped emphasis markers are
|
||||
// stripped first so that literal markdown examples ("here is a
|
||||
// `***keyword***` to flag") do not trigger the warning.
|
||||
func checkDocsUpdateBoldItalic(markdown string) string {
|
||||
if markdown == "" {
|
||||
return ""
|
||||
}
|
||||
sanitized := stripEscapedEmphasisMarkers(stripMarkdownCodeRegions(markdown))
|
||||
for _, p := range combinedEmphasisPatterns {
|
||||
if p.re.MatchString(sanitized) {
|
||||
return "Lark does not support combined bold+italic markers " +
|
||||
"(e.g. ***text***, ___text___, **_text_**, _**text**_, __*text*__, *__text__*); " +
|
||||
"the emphasis will be downgraded to either bold or italic. " +
|
||||
"Split into two separate emphases or drop one of them."
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// proseHasBlankLine reports whether markdown contains a blank line outside
|
||||
// of fenced code blocks. Blank lines inside ```...``` or ~~~...~~~ fences
|
||||
// are code content, not paragraph separators, and must not trip the
|
||||
// "replace_* cannot split paragraphs" warning.
|
||||
//
|
||||
// A blank line counts only when it sits between two non-blank boundaries
|
||||
// (other prose, or a fence open/close). A trailing empty line at EOF is
|
||||
// not treated as "\n\n".
|
||||
func proseHasBlankLine(markdown string) bool {
|
||||
lines := strings.Split(markdown, "\n")
|
||||
inFence := false
|
||||
var fenceMarker string
|
||||
for i, line := range lines {
|
||||
if inFence {
|
||||
if isCodeFenceClose(line, fenceMarker) {
|
||||
inFence = false
|
||||
fenceMarker = ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
if marker := codeFenceOpenMarker(line); marker != "" {
|
||||
inFence = true
|
||||
fenceMarker = marker
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(line) == "" && i > 0 && i+1 < len(lines) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripMarkdownCodeRegions returns markdown with fenced code blocks blanked
|
||||
// out and inline code spans replaced by whitespace of equivalent length.
|
||||
// Byte offsets outside the masked regions are preserved, so follow-on
|
||||
// regex matches still point at real prose positions.
|
||||
func stripMarkdownCodeRegions(markdown string) string {
|
||||
lines := strings.Split(markdown, "\n")
|
||||
inFence := false
|
||||
var fenceMarker string
|
||||
for i, line := range lines {
|
||||
if inFence {
|
||||
if isCodeFenceClose(line, fenceMarker) {
|
||||
inFence = false
|
||||
fenceMarker = ""
|
||||
}
|
||||
lines[i] = ""
|
||||
continue
|
||||
}
|
||||
if marker := codeFenceOpenMarker(line); marker != "" {
|
||||
inFence = true
|
||||
fenceMarker = marker
|
||||
lines[i] = ""
|
||||
continue
|
||||
}
|
||||
lines[i] = maskInlineCodeSpans(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// maskInlineCodeSpans replaces the byte ranges of any inline code spans in
|
||||
// line with space characters of equal length. Uses scanInlineCodeSpans from
|
||||
// markdown_fix.go, which implements the CommonMark §6.1 matching-backtick-run
|
||||
// rule (so “ `a`b` “ is a single span).
|
||||
func maskInlineCodeSpans(line string) string {
|
||||
spans := scanInlineCodeSpans(line)
|
||||
if len(spans) == 0 {
|
||||
return line
|
||||
}
|
||||
var sb strings.Builder
|
||||
pos := 0
|
||||
for _, loc := range spans {
|
||||
sb.WriteString(line[pos:loc[0]])
|
||||
sb.WriteString(strings.Repeat(" ", loc[1]-loc[0]))
|
||||
pos = loc[1]
|
||||
}
|
||||
sb.WriteString(line[pos:])
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// stripEscapedEmphasisMarkers removes backslash-escaped '*' and '_' so the
|
||||
// bold/italic regexes don't treat literal sequences like `\***text***` as
|
||||
// real combined emphasis. CommonMark renders "\*" as a literal "*" with no
|
||||
// emphasis semantics; dropping the escape + its target from the detection
|
||||
// input keeps the heuristic aligned with what the renderer actually does.
|
||||
//
|
||||
// Known limitation: a doubled backslash escape ("\\" followed by a real
|
||||
// emphasis marker, e.g. `\\***text***`) renders as a literal backslash
|
||||
// followed by genuine combined emphasis, but this strip is not a proper
|
||||
// parser and will instead consume the second backslash as the opener for
|
||||
// another escape. That hides the real emphasis from the check, producing
|
||||
// a false negative. Practical impact is small (this shape is rare in the
|
||||
// kind of AI-Agent prompts we target) and the alternative — a full
|
||||
// CommonMark escape parser — is not worth the code surface here.
|
||||
func stripEscapedEmphasisMarkers(s string) string {
|
||||
s = strings.ReplaceAll(s, `\*`, "")
|
||||
s = strings.ReplaceAll(s, `\_`, "")
|
||||
return s
|
||||
}
|
||||
|
||||
// codeFenceOpenMarker returns the fence marker (e.g. "```" or "~~~~") if
|
||||
// line opens a fenced code block, otherwise "". Applies CommonMark §4.5
|
||||
// rules: up to 3 leading spaces are tolerated; 4+ leading spaces (or any
|
||||
// leading tab, which expands to 4 columns) make the line an indented code
|
||||
// block rather than a fence.
|
||||
func codeFenceOpenMarker(line string) string {
|
||||
body, ok := fenceIndentOK(line)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(body, "```"):
|
||||
return leadingRun(body, '`')
|
||||
case strings.HasPrefix(body, "~~~"):
|
||||
return leadingRun(body, '~')
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isCodeFenceClose reports whether line closes a fence opened with marker.
|
||||
// Per CommonMark §4.5 the closer must use the same fence character, be at
|
||||
// least as long as the opener, sit within 0..3 leading spaces, and carry
|
||||
// no info-string text.
|
||||
func isCodeFenceClose(line, marker string) bool {
|
||||
if marker == "" {
|
||||
return false
|
||||
}
|
||||
body, ok := fenceIndentOK(line)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
fenceChar := marker[0]
|
||||
run := leadingRun(body, fenceChar)
|
||||
if len(run) < len(marker) {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(body[len(run):]) == ""
|
||||
}
|
||||
|
||||
// fenceIndentOK returns (bodyWithoutLeadingSpaces, true) when line has
|
||||
// 0..3 leading spaces and no leading tab — i.e. the indentation is
|
||||
// permissible for a CommonMark fence. Returns ("", false) otherwise
|
||||
// (4+ leading spaces or any tab), meaning the line must be treated as
|
||||
// indented code block content rather than a fence boundary.
|
||||
func fenceIndentOK(line string) (string, bool) {
|
||||
for i := 0; i < len(line) && i < 4; i++ {
|
||||
switch line[i] {
|
||||
case ' ':
|
||||
continue
|
||||
case '\t':
|
||||
return "", false
|
||||
default:
|
||||
return line[i:], true
|
||||
}
|
||||
}
|
||||
// Reached index 4 without hitting a non-space character: too indented.
|
||||
if len(line) >= 4 {
|
||||
return "", false
|
||||
}
|
||||
// Line shorter than 4 chars and all spaces — still valid (empty content).
|
||||
return "", true
|
||||
}
|
||||
|
||||
// leadingRun returns the longest prefix of s made up of the byte c.
|
||||
func leadingRun(s string, c byte) string {
|
||||
i := 0
|
||||
for i < len(s) && s[i] == c {
|
||||
i++
|
||||
}
|
||||
return s[:i]
|
||||
}
|
||||
375
shortcuts/doc/docs_update_check_test.go
Normal file
375
shortcuts/doc/docs_update_check_test.go
Normal file
@@ -0,0 +1,375 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package doc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckDocsUpdateReplaceMultilineMarkdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
markdown string
|
||||
wantHint bool
|
||||
}{
|
||||
{
|
||||
name: "replace_range with blank line emits hint",
|
||||
mode: "replace_range",
|
||||
markdown: "new paragraph\n\nsecond paragraph",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "replace_all with blank line emits hint",
|
||||
mode: "replace_all",
|
||||
markdown: "first\n\nsecond",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "replace_range single paragraph is fine",
|
||||
mode: "replace_range",
|
||||
markdown: "just a single paragraph of text",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "single newline is not a paragraph break",
|
||||
mode: "replace_range",
|
||||
markdown: "line one\nline two",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "crlf paragraph break is also detected",
|
||||
mode: "replace_range",
|
||||
markdown: "first\r\n\r\nsecond",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "other modes are not flagged",
|
||||
mode: "insert_before",
|
||||
markdown: "first\n\nsecond",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "append mode is not flagged",
|
||||
mode: "append",
|
||||
markdown: "first\n\nsecond",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "empty markdown is fine",
|
||||
mode: "replace_range",
|
||||
markdown: "",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
// The check must ignore blank lines inside fenced code; otherwise
|
||||
// a user replacing one block with a legitimate code sample that
|
||||
// contains blank lines would see a spurious warning.
|
||||
name: "blank line inside backtick fenced code is not flagged",
|
||||
mode: "replace_range",
|
||||
markdown: "```\nline1\n\nline2\n```",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "blank line inside tilde fenced code is not flagged",
|
||||
mode: "replace_range",
|
||||
markdown: "~~~\ncode line one\n\ncode line two\n~~~",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
// Mixed prose + fenced code: any blank line in prose still wins,
|
||||
// even if the fenced content also contains blanks.
|
||||
name: "blank line in prose outside fence still flags even when fence has blanks",
|
||||
mode: "replace_range",
|
||||
markdown: "first paragraph\n\nsecond paragraph\n\n```\ncode\n\nmore\n```",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
// Fenced code with no blank lines inside must not trip on the
|
||||
// fence markers themselves.
|
||||
name: "fenced code with no blank lines does not flag",
|
||||
mode: "replace_range",
|
||||
markdown: "prose before\n```go\nfmt.Println(\"hi\")\n```\nprose after",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
// CommonMark §4.5: the closing fence must be ≥ opening fence length.
|
||||
// A 4-backtick close for a 3-backtick open is a legitimate way to
|
||||
// embed triple-backticks in a code sample; the check must see the
|
||||
// fence as properly closed and not treat the rest of the document
|
||||
// as still-inside-fence.
|
||||
name: "longer close marker closes fence correctly",
|
||||
mode: "replace_range",
|
||||
markdown: "```\nsome code\n````\n\nprose paragraph after",
|
||||
wantHint: true, // the blank line AFTER the fence is real prose
|
||||
},
|
||||
{
|
||||
name: "longer close marker still hides blank line inside fence",
|
||||
mode: "replace_range",
|
||||
markdown: "```\nbefore\n\nafter\n````",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
// 4+ leading spaces make the line an indented code block, not a
|
||||
// fence open. The "fence"-looking line is code content; the
|
||||
// surrounding blank must still be detected.
|
||||
name: "four-space indented fence-like line is not a fence open",
|
||||
mode: "replace_range",
|
||||
markdown: "first paragraph\n\n ```\n code\n ```",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
// A tab in the leading whitespace is always ≥4 columns and thus
|
||||
// forces indented-code-block semantics.
|
||||
name: "tab-indented fence-like line is not a fence open",
|
||||
mode: "replace_range",
|
||||
markdown: "first paragraph\n\n\t```\n\tcode\n\t```",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
// 3 leading spaces is still within the fence-tolerance window.
|
||||
name: "three-space indented fence is still a fence",
|
||||
mode: "replace_range",
|
||||
markdown: " ```\ncode\n\nmore\n ```",
|
||||
wantHint: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := checkDocsUpdateReplaceMultilineMarkdown(tt.mode, tt.markdown)
|
||||
hasHint := got != ""
|
||||
if hasHint != tt.wantHint {
|
||||
t.Fatalf("checkDocsUpdateReplaceMultilineMarkdown(%q, %q) = %q, wantHint=%v",
|
||||
tt.mode, tt.markdown, got, tt.wantHint)
|
||||
}
|
||||
if tt.wantHint && (!strings.Contains(got, "delete_range") || !strings.Contains(got, "insert_before")) {
|
||||
t.Errorf("hint should suggest delete_range/insert_before remediation, got: %s", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDocsUpdateBoldItalic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantHint bool
|
||||
}{
|
||||
{
|
||||
name: "triple asterisks flagged",
|
||||
input: "a ***key insight*** here",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "triple asterisks single char flagged",
|
||||
input: "a ***X*** here",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "bold wrapping underscore italic flagged",
|
||||
input: "note: **_important_** detail",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "underscore wrapping double asterisk flagged",
|
||||
input: "note: _**important**_ detail",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "plain bold is fine",
|
||||
input: "this is **bold** text",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "plain italic is fine",
|
||||
input: "this is *italic* or _italic_ text",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "horizontal rule is not flagged",
|
||||
input: "paragraph\n\n---\n\nnext",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "bold followed by italic with space is not flagged",
|
||||
input: "**bold** and *italic*",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "empty input is fine",
|
||||
input: "",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
// The emphasis check must not fire on literal Markdown samples
|
||||
// inside a fenced code block — the canonical use case is docs
|
||||
// authors pasting tutorials that demonstrate these exact patterns.
|
||||
name: "triple asterisks inside backtick fenced code is not flagged",
|
||||
input: "example:\n```\nthe shape ***keyword*** downgrades\n```",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "underscore-bold inside fenced code is not flagged",
|
||||
input: "example:\n```markdown\nuse **_strong italic_** carefully\n```",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "bold-underscore inside fenced code is not flagged",
|
||||
input: "example:\n~~~\n_**outside-underscore**_ is a bad shape\n~~~",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "triple asterisks inside inline code span is not flagged",
|
||||
input: "the literal `***text***` marker is just a sample",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "underscore-bold inside inline code is not flagged",
|
||||
input: "the shape `**_italic_**` would downgrade, but only if it were real",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "escaped triple asterisks rendered as literal text is not flagged",
|
||||
input: `the literal \***text*** with escaped opener`,
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "escaped bold inside underscore-italic is not flagged",
|
||||
input: `shape \*\*_text_\*\* is literal, not emphasis`,
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
// Real emphasis outside the code span must still be detected —
|
||||
// the strip step must not over-sanitize.
|
||||
name: "real triple asterisks outside inline code still flags",
|
||||
input: "real ***strong*** and literal `***keyword***` — the first one counts",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "real triple asterisks outside fenced code still flags",
|
||||
input: "real ***strong***\n\n```\nliteral ***keyword*** in code\n```",
|
||||
wantHint: true,
|
||||
},
|
||||
// --- Triple-underscore combined emphasis: ___text___ ---
|
||||
{
|
||||
name: "triple underscores flagged",
|
||||
input: "a ___key insight___ here",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "triple underscores single char flagged",
|
||||
input: "a ___X___ here",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "triple underscores inside fenced code not flagged",
|
||||
input: "sample:\n```\nuse ___keyword___ carefully\n```",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "triple underscores inside inline code not flagged",
|
||||
input: "the literal `___phrase___` marker",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "escaped triple underscores not flagged",
|
||||
input: `literal \___phrase___ with escaped opener`,
|
||||
wantHint: false,
|
||||
},
|
||||
// --- Underscore-bold wrapping asterisk-italic: __*text*__ ---
|
||||
{
|
||||
name: "underscore-bold wrapping asterisk-italic flagged",
|
||||
input: "note: __*important*__ text",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "underscore-bold wrapping asterisk-italic inside fenced code not flagged",
|
||||
input: "```\nnote: __*important*__ sample\n```",
|
||||
wantHint: false,
|
||||
},
|
||||
{
|
||||
name: "underscore-bold wrapping asterisk-italic inside inline code not flagged",
|
||||
input: "literal `__*important*__` marker",
|
||||
wantHint: false,
|
||||
},
|
||||
// --- Asterisk-italic wrapping underscore-bold: *__text__* ---
|
||||
{
|
||||
name: "asterisk-italic wrapping underscore-bold flagged",
|
||||
input: "note: *__phrase__* text",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
name: "asterisk-italic wrapping underscore-bold inside fenced code not flagged",
|
||||
input: "```md\nnote: *__phrase__* sample\n```",
|
||||
wantHint: false,
|
||||
},
|
||||
// --- Positive tests: real emphasis in prose coexisting with fake in code ---
|
||||
{
|
||||
// Underscore-variant in prose must still fire when an asterisk
|
||||
// variant appears inside a code span — verifies the strip does
|
||||
// not over-sanitize across the six regex alternatives.
|
||||
name: "real triple underscores outside inline code still flag when asterisk variant is in code",
|
||||
input: "real ___strong___ and literal `***shape***` in code",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
// Longer close fence closes properly; real ***emphasis*** after
|
||||
// the fence must fire.
|
||||
name: "real emphasis after a fence closed by longer marker still flags",
|
||||
input: "```\nliteral ***phrase*** in code\n````\n\nand then real ***phrase*** after",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
// 4-space indented "```" is an indented code block, not a fence
|
||||
// open. The fence helper should refuse it; emphasis outside the
|
||||
// (non-existent) fence must still be detected.
|
||||
name: "four-space indented fence-like line does not open a fence for the emphasis check",
|
||||
input: "prose\n\n ```\n not a fence\n ```\n\nreal ***strong*** here",
|
||||
wantHint: true,
|
||||
},
|
||||
{
|
||||
// 3-space indented fence is valid per CommonMark. Emphasis inside
|
||||
// must be sanitized away, so the check must not fire.
|
||||
name: "three-space indented fence still hides triple-asterisk inside",
|
||||
input: " ```\n literal ***text*** inside\n ```",
|
||||
wantHint: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := checkDocsUpdateBoldItalic(tt.input)
|
||||
hasHint := got != ""
|
||||
if hasHint != tt.wantHint {
|
||||
t.Fatalf("checkDocsUpdateBoldItalic(%q) = %q, wantHint=%v", tt.input, got, tt.wantHint)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocsUpdateWarningsAggregates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Both flags trigger: replace_range with blank line AND triple-asterisk.
|
||||
warnings := docsUpdateWarnings("replace_range", "***opening***\n\nsecond paragraph")
|
||||
if len(warnings) != 2 {
|
||||
t.Fatalf("expected 2 warnings, got %d: %v", len(warnings), warnings)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocsUpdateWarningsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Clean markdown in a non-replace mode produces zero warnings.
|
||||
warnings := docsUpdateWarnings("insert_before", "plain paragraph text")
|
||||
if len(warnings) != 0 {
|
||||
t.Fatalf("expected no warnings, got: %v", warnings)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,14 @@
|
||||
package doc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
func TestIsWhiteboardCreateMarkdown(t *testing.T) {
|
||||
@@ -30,6 +36,59 @@ func TestIsWhiteboardCreateMarkdown(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeBoardTokens(t *testing.T) {
|
||||
// Codecov patch includes normalizeBoardTokens in this PR's diff because
|
||||
// the PR base predates #569 where this helper landed; the previously-
|
||||
// untested string and default arms are what keep patch coverage under the
|
||||
// threshold. These cases lock the fallback paths so any future caller
|
||||
// that passes a plain string or a non-slice token bag gets a stable shape.
|
||||
|
||||
t.Run("nil raw returns empty slice", func(t *testing.T) {
|
||||
got := normalizeBoardTokens(nil)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty slice, got %#v", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("already-typed string slice passes through", func(t *testing.T) {
|
||||
in := []string{"a", "b"}
|
||||
got := normalizeBoardTokens(in)
|
||||
if !reflect.DeepEqual(got, in) {
|
||||
t.Fatalf("got %#v, want %#v", got, in)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("interface slice skips non-string and empty string items", func(t *testing.T) {
|
||||
got := normalizeBoardTokens([]interface{}{"keep", "", 42, "also"})
|
||||
want := []string{"keep", "also"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %#v, want %#v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single string wraps into one-item slice", func(t *testing.T) {
|
||||
got := normalizeBoardTokens("solo")
|
||||
want := []string{"solo"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %#v, want %#v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty string returns empty slice, not one-item slice", func(t *testing.T) {
|
||||
got := normalizeBoardTokens("")
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty slice for empty string input, got %#v", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported type falls through to empty slice", func(t *testing.T) {
|
||||
got := normalizeBoardTokens(42)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty slice for non-string/non-slice input, got %#v", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeDocsUpdateResult(t *testing.T) {
|
||||
t.Run("adds empty board_tokens when whiteboard creation response omits it", func(t *testing.T) {
|
||||
result := map[string]interface{}{
|
||||
@@ -76,3 +135,201 @@ func TestNormalizeDocsUpdateResult(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateSelectionByTitle(t *testing.T) {
|
||||
t.Run("empty title passes", func(t *testing.T) {
|
||||
if err := validateSelectionByTitle(""); err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("heading style title passes", func(t *testing.T) {
|
||||
if err := validateSelectionByTitle("## 第二章"); err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain text title fails with guidance", func(t *testing.T) {
|
||||
err := validateSelectionByTitle("第二章")
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error")
|
||||
}
|
||||
if got := err.Error(); got == "" || !containsAll(got, "selection-by-title", "heading prefix") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multi-line heading still fails", func(t *testing.T) {
|
||||
err := validateSelectionByTitle("## 第二章\n## 第三章")
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error")
|
||||
}
|
||||
if got := err.Error(); got == "" || !containsAll(got, "single heading line") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multi-line title fails", func(t *testing.T) {
|
||||
err := validateSelectionByTitle("第二章\n第三章")
|
||||
if err == nil {
|
||||
t.Fatalf("expected validation error")
|
||||
}
|
||||
if got := err.Error(); got == "" || !containsAll(got, "single heading line") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func containsAll(s string, tokens ...string) bool {
|
||||
for _, token := range tokens {
|
||||
if !strings.Contains(s, token) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestDocsUpdateValidate exercises the Validate closure directly so the new
|
||||
// --selection-by-title integration point (call site in Validate) is covered,
|
||||
// not just the underlying validateSelectionByTitle helper. Without this the
|
||||
// three lines added to the closure show up as untested in the patch coverage
|
||||
// report even though the helper itself is at 100%.
|
||||
func TestDocsUpdateValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
flags map[string]string
|
||||
boolFlag string // name of optional bool flag to set (currently unused; placeholder for future flags)
|
||||
wantErr string // substring; empty = expect nil error
|
||||
}{
|
||||
{
|
||||
// Happy path that exercises the new selection-by-title call site
|
||||
// with a valid heading — reaches the `return nil` branch.
|
||||
name: "heading-style selection-by-title passes",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "replace_range",
|
||||
"markdown": "new body",
|
||||
"selection-by-title": "## Section",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Exercises the error-return branch of the new call site.
|
||||
name: "plain-text selection-by-title is rejected with heading-prefix guidance",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "replace_range",
|
||||
"markdown": "new body",
|
||||
"selection-by-title": "第二章",
|
||||
},
|
||||
wantErr: "heading prefix",
|
||||
},
|
||||
{
|
||||
// Exercises the multi-line guard inside validateSelectionByTitle
|
||||
// through the Validate call path.
|
||||
name: "multi-line selection-by-title is rejected as not a single heading",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "replace_range",
|
||||
"markdown": "new body",
|
||||
"selection-by-title": "## a\n## b",
|
||||
},
|
||||
wantErr: "single heading line",
|
||||
},
|
||||
{
|
||||
// Invalid mode — proves the earlier mode check still fires before
|
||||
// reaching the new selection-by-title check, so the new code
|
||||
// doesn't accidentally mask pre-existing validation.
|
||||
name: "invalid mode is still rejected first",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "bogus",
|
||||
"selection-by-title": "## Section",
|
||||
},
|
||||
wantErr: "invalid --mode",
|
||||
},
|
||||
{
|
||||
// Both selection forms supplied — proves the mutual-exclusion
|
||||
// check still fires before the new selection-by-title check.
|
||||
name: "conflicting selection flags are rejected before title validation",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "replace_range",
|
||||
"markdown": "body",
|
||||
"selection-with-ellipsis": "start...end",
|
||||
"selection-by-title": "## Section",
|
||||
},
|
||||
wantErr: "mutually exclusive",
|
||||
},
|
||||
{
|
||||
// Non-delete_range modes require --markdown; this exercises the
|
||||
// pre-existing empty-markdown branch that sits between the mode
|
||||
// check and the new selection-by-title check. Covering it keeps
|
||||
// patch coverage above codecov's threshold for this closure.
|
||||
name: "non-delete_range mode without --markdown is rejected",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "replace_range",
|
||||
"selection-by-title": "## Section",
|
||||
},
|
||||
wantErr: "requires --markdown",
|
||||
},
|
||||
{
|
||||
// needsSelection[mode] is true for replace_range but neither
|
||||
// selection flag is set — covers the "requires selection" branch
|
||||
// that precedes the new call site.
|
||||
name: "replace_range without any selection flag is rejected",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "replace_range",
|
||||
"markdown": "body",
|
||||
},
|
||||
wantErr: "requires --selection-with-ellipsis or --selection-by-title",
|
||||
},
|
||||
{
|
||||
// delete_range has no markdown requirement and no selection
|
||||
// requirement when neither is supplied is actually ok under the
|
||||
// current rules (delete_range still needs selection per
|
||||
// needsSelection, but the test proves the markdown-empty guard
|
||||
// does not fire for delete_range specifically).
|
||||
name: "delete_range without --markdown but with selection passes markdown check",
|
||||
flags: map[string]string{
|
||||
"doc": "doxcnABCDEF",
|
||||
"mode": "delete_range",
|
||||
"selection-by-title": "## Section",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := &cobra.Command{Use: "docs +update"}
|
||||
cmd.Flags().String("doc", "", "")
|
||||
cmd.Flags().String("mode", "", "")
|
||||
cmd.Flags().String("markdown", "", "")
|
||||
cmd.Flags().String("selection-with-ellipsis", "", "")
|
||||
cmd.Flags().String("selection-by-title", "", "")
|
||||
cmd.Flags().String("new-title", "", "")
|
||||
for k, v := range tt.flags {
|
||||
if err := cmd.Flags().Set(k, v); err != nil {
|
||||
t.Fatalf("set --%s=%q: %v", k, v, err)
|
||||
}
|
||||
}
|
||||
|
||||
rt := common.TestNewRuntimeContext(cmd, nil)
|
||||
err := DocsUpdate.Validate(context.Background(), rt)
|
||||
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("error %q does not contain %q", err.Error(), tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ package doc
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// fixExportedMarkdown applies post-processing to Lark-exported Markdown to
|
||||
@@ -15,24 +17,29 @@ import (
|
||||
// and strips redundant ** from ATX headings. Applied only outside fenced
|
||||
// code blocks, and skips inline code spans.
|
||||
//
|
||||
// 2. fixSetextAmbiguity: inserts a blank line before any "---" that immediately
|
||||
// 2. normalizeNestedListIndentation: rewrites space-pair-indented nested list
|
||||
// markers to tab-indented markers. This avoids nested ordered list items
|
||||
// being flattened or interpreted as plain text/code on re-import.
|
||||
//
|
||||
// 3. fixSetextAmbiguity: inserts a blank line before any "---" that immediately
|
||||
// follows a non-empty line, preventing it from being parsed as a Setext H2.
|
||||
// Applied only outside fenced code blocks.
|
||||
//
|
||||
// 3. fixBlockquoteHardBreaks: inserts a blank blockquote line (">") between
|
||||
// 4. fixBlockquoteHardBreaks: inserts a blank blockquote line (">") between
|
||||
// consecutive blockquote content lines so create-doc preserves line breaks.
|
||||
// Applied only outside fenced code blocks.
|
||||
//
|
||||
// 4. fixTopLevelSoftbreaks: inserts a blank line between adjacent non-empty
|
||||
// 5. fixTopLevelSoftbreaks: inserts a blank line between adjacent non-empty
|
||||
// lines at the top level and inside content containers (callout,
|
||||
// quote-container, lark-td). Code fences are left untouched, and
|
||||
// consecutive list items / continuations are not separated.
|
||||
//
|
||||
// 5. fixCalloutEmoji: replaces named emoji aliases (e.g. emoji="warning") with
|
||||
// 6. fixCalloutEmoji: replaces named emoji aliases (e.g. emoji="warning") with
|
||||
// actual Unicode emoji characters that create-doc understands. Applied only
|
||||
// outside fenced code blocks.
|
||||
func fixExportedMarkdown(md string) string {
|
||||
md = applyOutsideCodeFences(md, fixBoldSpacing)
|
||||
md = applyOutsideCodeFences(md, normalizeNestedListIndentation)
|
||||
md = applyOutsideCodeFences(md, fixSetextAmbiguity)
|
||||
md = applyOutsideCodeFences(md, fixBlockquoteHardBreaks)
|
||||
md = fixTopLevelSoftbreaks(md)
|
||||
@@ -106,20 +113,21 @@ func fixBlockquoteHardBreaks(md string) string {
|
||||
return strings.Join(out, "\n")
|
||||
}
|
||||
|
||||
// fixBoldSpacing fixes two issues with bold markers exported by Lark:
|
||||
// fixBoldSpacing normalizes emphasis markers exported by Lark while preserving
|
||||
// inline code spans:
|
||||
//
|
||||
// 1. Trailing whitespace before closing **: "**text **" → "**text**"
|
||||
// CommonMark requires no space before a closing delimiter; otherwise the
|
||||
// ** is rendered as literal text.
|
||||
// 1. Removes leading whitespace after opening ** and * delimiters:
|
||||
// "** text**" → "**text**", "* text*" → "*text*"
|
||||
//
|
||||
// 2. Redundant bold in ATX headings: "# **text**" → "# text"
|
||||
// Headings are already bold, so the inner ** is visually redundant and
|
||||
// some renderers display the markers literally.
|
||||
// 2. Removes trailing whitespace before closing ** and * delimiters:
|
||||
// "**text **" → "**text**", "*text *" → "*text*"
|
||||
//
|
||||
// Both fixes skip inline code spans to avoid modifying literal code content.
|
||||
// 3. Removes redundant bold around an entire ATX heading:
|
||||
// "# **text**" → "# text"
|
||||
//
|
||||
// The bold and italic spacing fixes only run on non-code segments so literal
|
||||
// code content is left unchanged.
|
||||
var (
|
||||
boldTrailingSpaceRe = regexp.MustCompile(`(\*\*\S[^*]*?)\s+(\*\*)`)
|
||||
italicTrailingSpaceRe = regexp.MustCompile(`(\*\S[^*]*?)\s+(\*)`)
|
||||
// headingBoldRe uses [^*]+ (no asterisks) to avoid mismatching headings
|
||||
// that contain multiple disjoint bold spans such as "# **foo** and **bar**".
|
||||
headingBoldRe = regexp.MustCompile(`(?m)^(#{1,6})\s+\*\*([^*]+)\*\*\s*$`)
|
||||
@@ -182,38 +190,116 @@ func scanInlineCodeSpans(line string) [][2]int {
|
||||
// fixBoldSpacingLine applies bold/italic trailing-space fixes to a single line,
|
||||
// skipping content inside inline code spans to avoid corrupting literal code.
|
||||
// ATX heading lines are also skipped here because headingBoldRe in fixBoldSpacing
|
||||
// handles them separately and boldTrailingSpaceRe can misfire on headings with
|
||||
// multiple disjoint bold spans (e.g. "# **foo** and **bar**").
|
||||
// handles them separately, keeping heading-only normalization isolated from the
|
||||
// inline emphasis spacing scanner below.
|
||||
func fixBoldSpacingLine(line string) string {
|
||||
if atxHeadingRe.MatchString(line) {
|
||||
return line
|
||||
}
|
||||
spans := scanInlineCodeSpans(line)
|
||||
if len(spans) == 0 {
|
||||
line = boldTrailingSpaceRe.ReplaceAllString(line, "$1$2")
|
||||
line = italicTrailingSpaceRe.ReplaceAllString(line, "$1$2")
|
||||
return line
|
||||
return fixEmphasisSpacingSegment(line)
|
||||
}
|
||||
var sb strings.Builder
|
||||
pos := 0
|
||||
for _, loc := range spans {
|
||||
// Process the non-code segment before this inline code span.
|
||||
seg := line[pos:loc[0]]
|
||||
seg = boldTrailingSpaceRe.ReplaceAllString(seg, "$1$2")
|
||||
seg = italicTrailingSpaceRe.ReplaceAllString(seg, "$1$2")
|
||||
sb.WriteString(seg)
|
||||
sb.WriteString(fixEmphasisSpacingSegment(seg))
|
||||
// Preserve inline code span as-is.
|
||||
sb.WriteString(line[loc[0]:loc[1]])
|
||||
pos = loc[1]
|
||||
}
|
||||
// Remaining non-code segment after the last code span.
|
||||
seg := line[pos:]
|
||||
seg = boldTrailingSpaceRe.ReplaceAllString(seg, "$1$2")
|
||||
seg = italicTrailingSpaceRe.ReplaceAllString(seg, "$1$2")
|
||||
sb.WriteString(seg)
|
||||
sb.WriteString(fixEmphasisSpacingSegment(line[pos:]))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// fixEmphasisSpacingSegment trims only the whitespace immediately inside simple
|
||||
// *...* and **...** spans. It deliberately ignores runs of 3+ asterisks and
|
||||
// any candidate whose payload contains another asterisk so nested emphasis-like
|
||||
// text remains untouched. When both inner sides contain whitespace, single-rune
|
||||
// payloads are preserved as literal text (for example "* x *" and "** x **").
|
||||
func fixEmphasisSpacingSegment(seg string) string {
|
||||
if !strings.Contains(seg, "*") {
|
||||
return seg
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
pos := 0
|
||||
for pos < len(seg) {
|
||||
openStart, openEnd, ok := nextAsteriskRun(seg, pos)
|
||||
if !ok {
|
||||
sb.WriteString(seg[pos:])
|
||||
break
|
||||
}
|
||||
|
||||
sb.WriteString(seg[pos:openStart])
|
||||
|
||||
markerLen := openEnd - openStart
|
||||
if markerLen != 1 && markerLen != 2 {
|
||||
sb.WriteString(seg[openStart:openEnd])
|
||||
pos = openEnd
|
||||
continue
|
||||
}
|
||||
|
||||
closeStart, closeEnd, ok := nextAsteriskRun(seg, openEnd)
|
||||
if !ok || closeEnd-closeStart != markerLen {
|
||||
sb.WriteString(seg[openStart:openEnd])
|
||||
pos = openEnd
|
||||
continue
|
||||
}
|
||||
|
||||
payload := seg[openEnd:closeStart]
|
||||
normalized, shouldNormalize := normalizeEmphasisPayload(payload)
|
||||
if !shouldNormalize {
|
||||
sb.WriteString(seg[openStart:closeEnd])
|
||||
pos = closeEnd
|
||||
continue
|
||||
}
|
||||
|
||||
marker := seg[openStart:openEnd]
|
||||
sb.WriteString(marker)
|
||||
sb.WriteString(normalized)
|
||||
sb.WriteString(marker)
|
||||
pos = closeEnd
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func nextAsteriskRun(s string, start int) (runStart, runEnd int, ok bool) {
|
||||
for i := start; i < len(s); i++ {
|
||||
if s[i] != '*' {
|
||||
continue
|
||||
}
|
||||
j := i
|
||||
for j < len(s) && s[j] == '*' {
|
||||
j++
|
||||
}
|
||||
return i, j, true
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func normalizeEmphasisPayload(payload string) (string, bool) {
|
||||
trimmedLeft := strings.TrimLeftFunc(payload, unicode.IsSpace)
|
||||
trimmed := strings.TrimRightFunc(trimmedLeft, unicode.IsSpace)
|
||||
if trimmed == "" {
|
||||
return payload, false
|
||||
}
|
||||
|
||||
hasLeadingSpace := len(trimmedLeft) != len(payload)
|
||||
hasTrailingSpace := len(trimmed) != len(trimmedLeft)
|
||||
if !hasLeadingSpace && !hasTrailingSpace {
|
||||
return payload, true
|
||||
}
|
||||
|
||||
if hasLeadingSpace && hasTrailingSpace && utf8.RuneCountInString(trimmed) == 1 {
|
||||
return payload, false
|
||||
}
|
||||
return trimmed, true
|
||||
}
|
||||
|
||||
var setextRe = regexp.MustCompile(`(?m)^([^\n]+)\n(-{3,}\s*$)`)
|
||||
|
||||
func fixSetextAmbiguity(md string) string {
|
||||
@@ -291,6 +377,44 @@ var contentContainers = [][2]string{
|
||||
// indented (nested) items.
|
||||
var listItemRe = regexp.MustCompile(`^[ \t]*([-*+]|\d+[.)]) `)
|
||||
|
||||
// nestedListIndentRe matches nested list item markers indented with pairs of
|
||||
// spaces. We rewrite those space pairs to tabs because some downstream
|
||||
// round-trip paths treat multi-space indented ordered items as flat items or
|
||||
// literal text, while tab indentation remains nested and avoids 4-space code
|
||||
// block ambiguity.
|
||||
var nestedListIndentRe = regexp.MustCompile(`^( {2,})([-*+]|\d+[.)]) `)
|
||||
|
||||
func normalizeNestedListIndentation(md string) string {
|
||||
lines := strings.Split(md, "\n")
|
||||
for i, line := range lines {
|
||||
matches := nestedListIndentRe.FindStringSubmatch(line)
|
||||
if len(matches) != 3 {
|
||||
continue
|
||||
}
|
||||
if !hasPreviousNonBlankListItem(lines, i) {
|
||||
continue
|
||||
}
|
||||
indent := matches[1]
|
||||
if len(indent)%2 != 0 {
|
||||
continue
|
||||
}
|
||||
tabs := strings.Repeat("\t", len(indent)/2)
|
||||
lines[i] = tabs + line[len(indent):]
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func hasPreviousNonBlankListItem(lines []string, index int) bool {
|
||||
for i := index - 1; i >= 0; i-- {
|
||||
trimmed := strings.TrimSpace(lines[i])
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
return listItemRe.MatchString(lines[i])
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isListItemOrContinuation returns true for lines that are part of a list:
|
||||
// either a list item marker line or an indented continuation of a list item.
|
||||
// This is used to prevent blank lines being inserted between tight list lines,
|
||||
|
||||
287
shortcuts/doc/markdown_fix_hardening_test.go
Normal file
287
shortcuts/doc/markdown_fix_hardening_test.go
Normal file
@@ -0,0 +1,287 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package doc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFixExportedMarkdownIdempotent asserts the core promise of the exported
|
||||
// markdown pipeline: applying the fixes twice produces the same result as
|
||||
// applying them once. Round-trip formatting relies on this invariant, so any
|
||||
// transform that keeps rewriting its own output would break fetch → edit →
|
||||
// update → fetch stability.
|
||||
func TestFixExportedMarkdownIdempotent(t *testing.T) {
|
||||
fixtures := map[string]string{
|
||||
"kitchen sink": strings.Join([]string{
|
||||
"# **Title**",
|
||||
"paragraph one",
|
||||
"paragraph two",
|
||||
"**bold ** and * italic*",
|
||||
"",
|
||||
"> q1",
|
||||
"> q2",
|
||||
"",
|
||||
"1. parent",
|
||||
" 1. child",
|
||||
" 1. grandchild",
|
||||
"",
|
||||
"<callout emoji=\"warning\">",
|
||||
"callout body line 1",
|
||||
"callout body line 2",
|
||||
"</callout>",
|
||||
"",
|
||||
"some text",
|
||||
"---",
|
||||
"",
|
||||
"```go",
|
||||
"// code content with markdown-like shapes must survive as-is",
|
||||
"**foo **",
|
||||
"* hello*",
|
||||
" 1. nested",
|
||||
"> q",
|
||||
"---",
|
||||
"```",
|
||||
"",
|
||||
}, "\n"),
|
||||
|
||||
"cjk content": strings.Join([]string{
|
||||
"# **测试标题**",
|
||||
"段落一",
|
||||
"段落二",
|
||||
"**有用性 ** and * 关键 *",
|
||||
"",
|
||||
"1. 父项",
|
||||
" 1. 子项",
|
||||
"",
|
||||
}, "\n"),
|
||||
|
||||
"nested containers": strings.Join([]string{
|
||||
"<callout emoji=\"info\">",
|
||||
"line a",
|
||||
"line b",
|
||||
"</callout>",
|
||||
"",
|
||||
"<quote-container>",
|
||||
"quoted 1",
|
||||
"quoted 2",
|
||||
"</quote-container>",
|
||||
"",
|
||||
}, "\n"),
|
||||
}
|
||||
|
||||
for name, fixture := range fixtures {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
once := fixExportedMarkdown(fixture)
|
||||
twice := fixExportedMarkdown(once)
|
||||
if once != twice {
|
||||
t.Errorf("fixExportedMarkdown is not idempotent for %q\nfirst pass:\n%s\nsecond pass:\n%s",
|
||||
name, once, twice)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFixExportedMarkdownPreservesFencedCodeByteForByte packs a fenced code
|
||||
// block with content that every individual transform in the pipeline would
|
||||
// normally rewrite, and asserts the fence content comes out byte-for-byte
|
||||
// identical. This is the pipeline's strongest invariant — users' code samples
|
||||
// must never be silently modified by a formatting pass.
|
||||
func TestFixExportedMarkdownPreservesFencedCodeByteForByte(t *testing.T) {
|
||||
// Every line below is something at least one transform would touch if it
|
||||
// appeared outside a fence. None of it must change.
|
||||
dangerous := strings.Join([]string{
|
||||
"**foo **", // fixBoldSpacing — trailing space bold
|
||||
"* hello*", // fixBoldSpacing — leading space italic
|
||||
"# **heading**", // fixBoldSpacing — redundant heading bold
|
||||
"para1", // fixTopLevelSoftbreaks — adjacent paragraphs
|
||||
"para2",
|
||||
"> q1", // fixBlockquoteHardBreaks — blockquote pair
|
||||
"> q2",
|
||||
"some text", // fixSetextAmbiguity — text before ---
|
||||
"---",
|
||||
" 1. nested", // normalizeNestedListIndentation
|
||||
`<callout emoji="warning">`, // fixCalloutEmoji — emoji alias
|
||||
}, "\n")
|
||||
|
||||
// Wrap the dangerous content in a triple-backtick fence and surround with
|
||||
// content so the pipeline has adjacent regions to potentially touch.
|
||||
input := "before\n\n```\n" + dangerous + "\n```\n\nafter\n"
|
||||
|
||||
got := fixExportedMarkdown(input)
|
||||
|
||||
// Extract the fence content from the output and compare to the input fence
|
||||
// content byte-for-byte.
|
||||
gotFence, ok := extractFirstFenceContent(got)
|
||||
if !ok {
|
||||
t.Fatalf("fixExportedMarkdown output lost its fenced code block:\n%s", got)
|
||||
}
|
||||
if gotFence != dangerous {
|
||||
t.Errorf("fenced code content was modified\nwant (bytes): %q\ngot (bytes): %q",
|
||||
dangerous, gotFence)
|
||||
}
|
||||
}
|
||||
|
||||
// extractFirstFenceContent returns the inner text of the first triple-backtick
|
||||
// fenced code block it finds, or ("", false) if none is present.
|
||||
func extractFirstFenceContent(md string) (string, bool) {
|
||||
const fence = "```"
|
||||
open := strings.Index(md, fence)
|
||||
if open < 0 {
|
||||
return "", false
|
||||
}
|
||||
// Skip the fence marker and its info-string line.
|
||||
rest := md[open+len(fence):]
|
||||
lineEnd := strings.Index(rest, "\n")
|
||||
if lineEnd < 0 {
|
||||
return "", false
|
||||
}
|
||||
rest = rest[lineEnd+1:]
|
||||
close := strings.Index(rest, "\n"+fence)
|
||||
if close < 0 {
|
||||
return "", false
|
||||
}
|
||||
return rest[:close], true
|
||||
}
|
||||
|
||||
// TestFixExportedMarkdownPreservesCRLF feeds CRLF-terminated markdown (Windows
|
||||
// line endings) through the pipeline and asserts that line endings are
|
||||
// preserved AND the emphasis/heading transforms still apply — neither
|
||||
// silently-LF-normalized nor passed through unchanged.
|
||||
func TestFixExportedMarkdownPreservesCRLF(t *testing.T) {
|
||||
lf := "# **Title**\nparagraph one\nparagraph two\n**bold **\n"
|
||||
crlf := strings.ReplaceAll(lf, "\n", "\r\n")
|
||||
|
||||
got := fixExportedMarkdown(crlf)
|
||||
|
||||
// Transforms must still fire: heading bold stripped, trailing-space bold trimmed.
|
||||
if strings.Contains(got, "**Title**") {
|
||||
t.Errorf("heading bold not stripped on CRLF input:\n%q", got)
|
||||
}
|
||||
if strings.Contains(got, "**bold **") {
|
||||
t.Errorf("trailing-space bold not fixed on CRLF input:\n%q", got)
|
||||
}
|
||||
// CRLF line endings must survive — we don't want to silently normalize a
|
||||
// Windows author's document to LF.
|
||||
if !strings.Contains(got, "\r\n") {
|
||||
t.Errorf("CRLF line endings were normalized away:\n%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFixExportedMarkdownTransformInteractions covers shapes where more than
|
||||
// one transform fires on the same input. Each transform is individually tested
|
||||
// elsewhere; these cases guard against composition regressions.
|
||||
func TestFixExportedMarkdownTransformInteractions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantContains []string // substrings that must be present after fixes
|
||||
wantAbsent []string // substrings that must be absent after fixes
|
||||
}{
|
||||
{
|
||||
name: "nested list item with trailing-space bold",
|
||||
input: "1. parent\n 1. **child **\n",
|
||||
wantContains: []string{
|
||||
"\t1.", // nested indent converted to tab
|
||||
"**child**", // trailing space trimmed
|
||||
},
|
||||
wantAbsent: []string{
|
||||
" 1.", // original two-space indent gone
|
||||
"**child **", // original trailing space gone
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "paragraph followed by list",
|
||||
input: "paragraph\n- item a\n- item b\n",
|
||||
wantContains: []string{
|
||||
"paragraph\n\n- item a", // blank line inserted at text-to-list transition
|
||||
},
|
||||
wantAbsent: []string{
|
||||
"\n\n\n", // no triple newline
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "callout containing list with emphasis",
|
||||
input: "<callout emoji=\"info\">\n- **item **\n- another\n</callout>\n",
|
||||
wantContains: []string{
|
||||
"**item**", // trailing-space bold fixed inside callout
|
||||
},
|
||||
wantAbsent: []string{
|
||||
"**item **",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "heading followed by paragraph with bold",
|
||||
input: "# **Title**\nbody **text **\n",
|
||||
wantContains: []string{
|
||||
"# Title", // heading bold stripped
|
||||
"body **text**", // paragraph bold trimmed, not stripped
|
||||
},
|
||||
wantAbsent: []string{
|
||||
"# **Title**",
|
||||
"body **text **",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := fixExportedMarkdown(tt.input)
|
||||
for _, want := range tt.wantContains {
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("want substring %q not found in output:\n%s", want, got)
|
||||
}
|
||||
}
|
||||
for _, unwanted := range tt.wantAbsent {
|
||||
if strings.Contains(got, unwanted) {
|
||||
t.Errorf("unwanted substring %q still present in output:\n%s", unwanted, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeNestedListIndentationDocumentedSkips locks in the deliberate
|
||||
// "do nothing" branches of normalizeNestedListIndentation. Each case below is
|
||||
// a shape the function intentionally does not rewrite; if a future change to
|
||||
// the heuristic flips one of these, we want the regression to be visible in
|
||||
// the test diff rather than silently changing user documents.
|
||||
func TestNormalizeNestedListIndentationDocumentedSkips(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
// want is identical to input — we are asserting "no change".
|
||||
}{
|
||||
{
|
||||
name: "three-space indent (odd) under list item stays unchanged",
|
||||
input: "1. parent\n 1. child",
|
||||
},
|
||||
{
|
||||
name: "five-space indent (odd) under list item stays unchanged",
|
||||
input: "- parent\n - deep",
|
||||
},
|
||||
{
|
||||
name: "two-space indent without a parent list item stays unchanged",
|
||||
input: "plain paragraph\n - not nested",
|
||||
},
|
||||
{
|
||||
name: "blank-line-separated loose-list sibling stays unchanged",
|
||||
input: "1. a\n\n 1. b",
|
||||
},
|
||||
{
|
||||
name: "four-space indented code block under list item stays unchanged",
|
||||
input: "- parent\n\n 1. code sample",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeNestedListIndentation(tt.input)
|
||||
if got != tt.input {
|
||||
t.Errorf("normalizeNestedListIndentation unexpectedly rewrote documented-skip input\ninput: %q\ngot: %q", tt.input, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,56 @@ func TestFixBoldSpacing(t *testing.T) {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "leading space after opening bold",
|
||||
input: "** hello**",
|
||||
want: "**hello**",
|
||||
},
|
||||
{
|
||||
name: "leading space after opening italic",
|
||||
input: "* hello*",
|
||||
want: "*hello*",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing spaces inside bold are collapsed",
|
||||
input: "** hello **",
|
||||
want: "**hello**",
|
||||
},
|
||||
{
|
||||
name: "leading and trailing spaces inside italic are collapsed",
|
||||
input: "* hello *",
|
||||
want: "*hello*",
|
||||
},
|
||||
{
|
||||
name: "multiple spaced italic spans on one line are each collapsed",
|
||||
input: "* a* * b*",
|
||||
want: "*a* *b*",
|
||||
},
|
||||
{
|
||||
name: "ambiguous italic span stays literal",
|
||||
input: "2 * x * y",
|
||||
want: "2 * x * y",
|
||||
},
|
||||
{
|
||||
name: "ambiguous bold span stays literal",
|
||||
input: "2 ** x ** y",
|
||||
want: "2 ** x ** y",
|
||||
},
|
||||
{
|
||||
name: "single-rune italic with spaces on both sides stays literal",
|
||||
input: "* x *",
|
||||
want: "* x *",
|
||||
},
|
||||
{
|
||||
name: "single-rune bold with spaces on both sides stays literal",
|
||||
input: "** x **",
|
||||
want: "** x **",
|
||||
},
|
||||
{
|
||||
name: "triple-asterisk near miss stays literal",
|
||||
input: "*** hello**",
|
||||
want: "*** hello**",
|
||||
},
|
||||
{
|
||||
name: "trailing space before closing bold",
|
||||
input: "**hello **",
|
||||
@@ -54,6 +104,16 @@ func TestFixBoldSpacing(t *testing.T) {
|
||||
input: "**foo ** and `**bar **`",
|
||||
want: "**foo** and `**bar **`",
|
||||
},
|
||||
{
|
||||
name: "inline code with spaced italic stays literal while outside span is fixed",
|
||||
input: "`* hello *` and * hello *",
|
||||
want: "`* hello *` and *hello*",
|
||||
},
|
||||
{
|
||||
name: "opening space inside text tag fixed",
|
||||
input: `<text color="red">** Helpful - 有用性:**</text>`,
|
||||
want: `<text color="red">**Helpful - 有用性:**</text>`,
|
||||
},
|
||||
{
|
||||
name: "double-backtick inline code not modified",
|
||||
input: "``**hello **`` and **world **",
|
||||
@@ -222,6 +282,53 @@ func TestFixTopLevelSoftbreaks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeNestedListIndentation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nested ordered list uses tabs instead of space pairs",
|
||||
input: "1. parent\n 1. child\n 1. grandchild",
|
||||
want: "1. parent\n\t1. child\n\t\t1. grandchild",
|
||||
},
|
||||
{
|
||||
name: "nested mixed list markers use tabs instead of space pairs",
|
||||
input: "- parent\n - child\n 1. grandchild",
|
||||
want: "- parent\n\t- child\n\t\t1. grandchild",
|
||||
},
|
||||
{
|
||||
name: "top-level list unchanged",
|
||||
input: "1. parent\n2. sibling",
|
||||
want: "1. parent\n2. sibling",
|
||||
},
|
||||
{
|
||||
name: "indented top-level marker without parent list stays unchanged",
|
||||
input: "paragraph\n\n 1. item",
|
||||
want: "paragraph\n\n 1. item",
|
||||
},
|
||||
{
|
||||
name: "blank-line-separated loose-list sibling stays unchanged",
|
||||
input: "1. a\n\n 1. b",
|
||||
want: "1. a\n\n 1. b",
|
||||
},
|
||||
{
|
||||
name: "indented code block inside list item stays unchanged",
|
||||
input: "- parent\n\n 1. code",
|
||||
want: "- parent\n\n 1. code",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeNestedListIndentation(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("normalizeNestedListIndentation(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixExportedMarkdown(t *testing.T) {
|
||||
// End-to-end: all fixes applied together
|
||||
input := "# **Title**\nparagraph one\nparagraph two\n**bold **\n> q1\n> q2\nsome text\n---"
|
||||
|
||||
@@ -13,6 +13,7 @@ func Shortcuts() []common.Shortcut {
|
||||
DocsFetch,
|
||||
DocsUpdate,
|
||||
DocMediaInsert,
|
||||
DocMediaUpload,
|
||||
DocMediaPreview,
|
||||
DocMediaDownload,
|
||||
}
|
||||
|
||||
150
shortcuts/drive/drive_apply_permission.go
Normal file
150
shortcuts/drive/drive_apply_permission.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package drive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/validate"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
)
|
||||
|
||||
// permApplyTypes is the authoritative list of type values the apply-permission
|
||||
// endpoint accepts for its required `type` query parameter.
|
||||
var permApplyTypes = []string{
|
||||
"doc", "sheet", "file", "wiki", "bitable", "docx",
|
||||
"mindnote", "slides",
|
||||
}
|
||||
|
||||
// permApplyURLMarkers maps document URL path markers to the `type` value the
|
||||
// apply-permission endpoint expects. Markers are disjoint strings (each begins
|
||||
// with "/" and ends with "/"), so a simple substring scan disambiguates them.
|
||||
var permApplyURLMarkers = []struct {
|
||||
Marker string
|
||||
Type string
|
||||
}{
|
||||
{"/wiki/", "wiki"},
|
||||
{"/docx/", "docx"},
|
||||
{"/sheets/", "sheet"},
|
||||
{"/base/", "bitable"},
|
||||
{"/bitable/", "bitable"},
|
||||
{"/file/", "file"},
|
||||
{"/mindnote/", "mindnote"},
|
||||
{"/slides/", "slides"},
|
||||
{"/doc/", "doc"},
|
||||
}
|
||||
|
||||
// resolvePermApplyTarget extracts (token, type) from a user-supplied --token
|
||||
// value that may be either a bare token or a full document URL, plus an
|
||||
// optional explicit --type. Explicit --type wins over URL inference.
|
||||
func resolvePermApplyTarget(raw, explicitType string) (token, docType string, err error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", "", output.ErrValidation("--token is required")
|
||||
}
|
||||
|
||||
if strings.Contains(raw, "://") {
|
||||
for _, m := range permApplyURLMarkers {
|
||||
if tok, ok := extractURLToken(raw, m.Marker); ok {
|
||||
token = tok
|
||||
if explicitType == "" {
|
||||
docType = m.Type
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if token == "" {
|
||||
return "", "", output.ErrValidation(
|
||||
"could not infer token from URL %q: supported paths are /docx/, /sheets/, /base/, /bitable/, /file/, /wiki/, /doc/, /mindnote/, /slides/. Pass a bare token with --type instead if the URL shape is unusual",
|
||||
raw,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
token = raw
|
||||
}
|
||||
|
||||
if explicitType != "" {
|
||||
docType = explicitType
|
||||
}
|
||||
if docType == "" {
|
||||
return "", "", output.ErrValidation(
|
||||
"--type is required when --token is a bare token; accepted values: %s",
|
||||
strings.Join(permApplyTypes, ", "),
|
||||
)
|
||||
}
|
||||
return token, docType, nil
|
||||
}
|
||||
|
||||
// DriveApplyPermission applies to the document owner for view or edit access
|
||||
// on behalf of the invoking user. Matches the open-apis endpoint
|
||||
// /open-apis/drive/v1/permissions/:token/members/apply.
|
||||
//
|
||||
// The backend accepts only user_access_token for this endpoint, so the
|
||||
// shortcut declares AuthTypes: ["user"] — bot identity is rejected up-front.
|
||||
var DriveApplyPermission = common.Shortcut{
|
||||
Service: "drive",
|
||||
Command: "+apply-permission",
|
||||
Description: "Apply to the document owner for view or edit permission on a doc/sheet/file/wiki/bitable/docx/mindnote/slides",
|
||||
Risk: "write",
|
||||
Scopes: []string{"docs:permission.member:apply"},
|
||||
AuthTypes: []string{"user"},
|
||||
Flags: []common.Flag{
|
||||
{Name: "token", Desc: "target token or document URL (docx/sheets/base/file/wiki/doc/mindnote/slides)", Required: true},
|
||||
{Name: "type", Desc: "target type; auto-inferred from URL when omitted", Enum: permApplyTypes},
|
||||
{Name: "perm", Desc: "permission to request", Required: true, Enum: []string{"view", "edit"}},
|
||||
{Name: "remark", Desc: "optional note shown on the request card sent to the owner"},
|
||||
},
|
||||
Validate: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
_, _, err := resolvePermApplyTarget(runtime.Str("token"), runtime.Str("type"))
|
||||
return err
|
||||
},
|
||||
DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
token, docType, err := resolvePermApplyTarget(runtime.Str("token"), runtime.Str("type"))
|
||||
if err != nil {
|
||||
return common.NewDryRunAPI().Set("error", err.Error())
|
||||
}
|
||||
body := buildPermApplyBody(runtime)
|
||||
return common.NewDryRunAPI().
|
||||
Desc("Apply to document owner for access").
|
||||
POST("/open-apis/drive/v1/permissions/:token/members/apply").
|
||||
Params(map[string]interface{}{"type": docType}).
|
||||
Body(body).
|
||||
Set("token", token)
|
||||
},
|
||||
Execute: func(ctx context.Context, runtime *common.RuntimeContext) error {
|
||||
token, docType, err := resolvePermApplyTarget(runtime.Str("token"), runtime.Str("type"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body := buildPermApplyBody(runtime)
|
||||
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "Requesting %s access on %s %s...\n",
|
||||
runtime.Str("perm"), docType, common.MaskToken(token))
|
||||
|
||||
data, err := runtime.CallAPI("POST",
|
||||
fmt.Sprintf("/open-apis/drive/v1/permissions/%s/members/apply", validate.EncodePathSegment(token)),
|
||||
map[string]interface{}{"type": docType},
|
||||
body,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
runtime.Out(data, nil)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// buildPermApplyBody returns the request body with the caller-supplied perm
|
||||
// and optional remark. remark is omitted entirely when empty so the server
|
||||
// doesn't render an empty note on the request card.
|
||||
func buildPermApplyBody(runtime *common.RuntimeContext) map[string]interface{} {
|
||||
body := map[string]interface{}{"perm": runtime.Str("perm")}
|
||||
if s := runtime.Str("remark"); s != "" {
|
||||
body["remark"] = s
|
||||
}
|
||||
return body
|
||||
}
|
||||
238
shortcuts/drive/drive_apply_permission_test.go
Normal file
238
shortcuts/drive/drive_apply_permission_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package drive
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/httpmock"
|
||||
)
|
||||
|
||||
// ── resolvePermApplyTarget unit tests ────────────────────────────────────────
|
||||
|
||||
func TestResolvePermApplyTarget_BareTokenNeedsType(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, err := resolvePermApplyTarget("bareToken", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "--type is required") {
|
||||
t.Fatalf("expected --type required error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePermApplyTarget_BareTokenWithType(t *testing.T) {
|
||||
t.Parallel()
|
||||
token, docType, err := resolvePermApplyTarget("bareToken", "docx")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "bareToken" || docType != "docx" {
|
||||
t.Fatalf("got token=%q type=%q, want bareToken/docx", token, docType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePermApplyTarget_URLInference(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
wantTok string
|
||||
wantType string
|
||||
}{
|
||||
{"docx", "https://example.feishu.cn/docx/doxTok123?from=share", "doxTok123", "docx"},
|
||||
{"sheets", "https://example.feishu.cn/sheets/shtTok456?sheet=abc", "shtTok456", "sheet"},
|
||||
{"base", "https://example.feishu.cn/base/bscTok789", "bscTok789", "bitable"},
|
||||
{"bitable", "https://example.feishu.cn/bitable/bscTok789", "bscTok789", "bitable"},
|
||||
{"file", "https://example.feishu.cn/file/boxTok111", "boxTok111", "file"},
|
||||
{"wiki", "https://example.feishu.cn/wiki/wikTok222", "wikTok222", "wiki"},
|
||||
{"legacy doc", "https://example.feishu.cn/doc/docTok333", "docTok333", "doc"},
|
||||
{"mindnote", "https://example.feishu.cn/mindnote/mnTok444", "mnTok444", "mindnote"},
|
||||
{"slides", "https://example.feishu.cn/slides/slTok666", "slTok666", "slides"},
|
||||
}
|
||||
for _, temp := range tests {
|
||||
tt := temp
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
token, docType, err := resolvePermApplyTarget(tt.raw, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != tt.wantTok || docType != tt.wantType {
|
||||
t.Fatalf("got (%q,%q), want (%q,%q)", token, docType, tt.wantTok, tt.wantType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePermApplyTarget_ExplicitTypeOverridesURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Even though the URL marker is /docx/, an explicit --type wins.
|
||||
token, docType, err := resolvePermApplyTarget("https://example.feishu.cn/docx/doxTok123", "wiki")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "doxTok123" || docType != "wiki" {
|
||||
t.Fatalf("got (%q,%q), want (doxTok123,wiki)", token, docType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePermApplyTarget_UnrecognizedURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, err := resolvePermApplyTarget("https://example.feishu.cn/unknown/xyz", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "could not infer token") {
|
||||
t.Fatalf("expected infer error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePermApplyTarget_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, err := resolvePermApplyTarget(" ", "docx")
|
||||
if err == nil || !strings.Contains(err.Error(), "--token is required") {
|
||||
t.Fatalf("expected token required error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── shortcut integration tests ──────────────────────────────────────────────
|
||||
|
||||
func TestDriveApplyPermission_ValidateMissingToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
err := mountAndRunDrive(t, DriveApplyPermission, []string{
|
||||
"+apply-permission", "--perm", "view", "--type", "docx", "--as", "user",
|
||||
}, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), "token") {
|
||||
t.Fatalf("expected token error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveApplyPermission_ValidateRejectsBadPerm(t *testing.T) {
|
||||
t.Parallel()
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
err := mountAndRunDrive(t, DriveApplyPermission, []string{
|
||||
"+apply-permission",
|
||||
"--token", "doxTok",
|
||||
"--type", "docx",
|
||||
"--perm", "full_access",
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err == nil || !strings.Contains(err.Error(), "--perm") {
|
||||
t.Fatalf("expected perm enum error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveApplyPermission_DryRunInfersTypeFromURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
f, stdout, _, _ := cmdutil.TestFactory(t, driveTestConfig())
|
||||
err := mountAndRunDrive(t, DriveApplyPermission, []string{
|
||||
"+apply-permission",
|
||||
"--token", "https://example.feishu.cn/sheets/shtTok?sheet=abc",
|
||||
"--perm", "edit",
|
||||
"--remark", "please",
|
||||
"--dry-run", "--as", "user",
|
||||
}, f, stdout)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
out := stdout.String()
|
||||
for _, want := range []string{
|
||||
"/open-apis/drive/v1/permissions/shtTok/members/apply",
|
||||
`"POST"`,
|
||||
`"sheet"`,
|
||||
`"edit"`,
|
||||
`"please"`,
|
||||
`"shtTok"`,
|
||||
} {
|
||||
if !strings.Contains(out, want) {
|
||||
t.Fatalf("dry-run output missing %q:\n%s", want, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveApplyPermission_ExecuteSuccess(t *testing.T) {
|
||||
f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
// Stub URL includes "?type=docx" — the stub only matches when the request
|
||||
// URL contains that query, so this doubles as an assertion that the
|
||||
// shortcut emits the type query parameter.
|
||||
stub := &httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/permissions/doxTok123/members/apply?type=docx",
|
||||
Body: map[string]interface{}{
|
||||
"code": 0, "msg": "success",
|
||||
"data": map[string]interface{}{"applied": true},
|
||||
},
|
||||
}
|
||||
reg.Register(stub)
|
||||
|
||||
err := mountAndRunDrive(t, DriveApplyPermission, []string{
|
||||
"+apply-permission",
|
||||
"--token", "doxTok123",
|
||||
"--type", "docx",
|
||||
"--perm", "view",
|
||||
"--as", "user",
|
||||
}, f, stdout)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(stub.CapturedBody, &body); err != nil {
|
||||
t.Fatalf("parse body: %v", err)
|
||||
}
|
||||
if body["perm"] != "view" {
|
||||
t.Fatalf("perm = %v, want view", body["perm"])
|
||||
}
|
||||
if _, hasRemark := body["remark"]; hasRemark {
|
||||
t.Fatalf("remark should be omitted when empty, got: %v", body["remark"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveApplyPermission_ExecuteNotApplicableHint(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/permissions/doxTok/members/apply",
|
||||
Status: 400,
|
||||
Body: map[string]interface{}{
|
||||
"code": 1063007, "msg": "request not applicable",
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRunDrive(t, DriveApplyPermission, []string{
|
||||
"+apply-permission",
|
||||
"--token", "doxTok",
|
||||
"--type", "docx",
|
||||
"--perm", "view",
|
||||
"--as", "user",
|
||||
}, f, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 1063007")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not applicable") {
|
||||
t.Fatalf("expected surfaced server message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveApplyPermission_ExecuteRateLimitHint(t *testing.T) {
|
||||
f, _, _, reg := cmdutil.TestFactory(t, driveTestConfig())
|
||||
reg.Register(&httpmock.Stub{
|
||||
Method: "POST",
|
||||
URL: "/open-apis/drive/v1/permissions/doxTok/members/apply",
|
||||
Status: 429,
|
||||
Body: map[string]interface{}{
|
||||
"code": 1063006, "msg": "quota exceeded",
|
||||
},
|
||||
})
|
||||
|
||||
err := mountAndRunDrive(t, DriveApplyPermission, []string{
|
||||
"+apply-permission",
|
||||
"--token", "doxTok",
|
||||
"--type", "docx",
|
||||
"--perm", "view",
|
||||
"--as", "user",
|
||||
}, f, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 1063006")
|
||||
}
|
||||
}
|
||||
@@ -19,5 +19,6 @@ func Shortcuts() []common.Shortcut {
|
||||
DriveMove,
|
||||
DriveDelete,
|
||||
DriveTaskResult,
|
||||
DriveApplyPermission,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ func TestShortcutsIncludesExpectedCommands(t *testing.T) {
|
||||
"+move",
|
||||
"+delete",
|
||||
"+task_result",
|
||||
"+apply-permission",
|
||||
}
|
||||
|
||||
if len(got) != len(want) {
|
||||
|
||||
@@ -152,8 +152,9 @@ func ResolveSenderNames(runtime *common.RuntimeContext, messages []map[string]in
|
||||
// This API has lighter permission requirements and works with user identity
|
||||
// even when the target user is not in the app's visible range.
|
||||
// Response uses "users" (not "items") and "user_id" (not "open_id").
|
||||
// The basic_batch endpoint caps user_ids at 10 per request.
|
||||
func batchResolveByBasicContact(runtime *common.RuntimeContext, missingIDs []string, nameMap map[string]string) {
|
||||
const batchSize = 50
|
||||
const batchSize = 10
|
||||
for i := 0; i < len(missingIDs); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(missingIDs) {
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
package convertlib
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -170,6 +172,57 @@ func TestResolveSenderNames(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchResolveByBasicContactRespectsAPILimit(t *testing.T) {
|
||||
// basic_batch allows at most 10 user_ids per request. Given 25 missing IDs,
|
||||
// expect three requests with sizes 10 / 10 / 5.
|
||||
var batchSizes []int
|
||||
runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if !strings.Contains(req.URL.Path, "/open-apis/contact/v3/users/basic_batch") {
|
||||
return nil, fmt.Errorf("unexpected path: %s", req.URL.Path)
|
||||
}
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs, _ := payload["user_ids"].([]interface{})
|
||||
if len(userIDs) > 10 {
|
||||
t.Fatalf("batch exceeded API limit: size = %d", len(userIDs))
|
||||
}
|
||||
batchSizes = append(batchSizes, len(userIDs))
|
||||
|
||||
users := make([]interface{}, 0, len(userIDs))
|
||||
for _, raw := range userIDs {
|
||||
id, _ := raw.(string)
|
||||
users = append(users, map[string]interface{}{
|
||||
"user_id": id,
|
||||
"name": "name-" + id,
|
||||
})
|
||||
}
|
||||
return convertlibJSONResponse(200, map[string]interface{}{
|
||||
"code": 0,
|
||||
"data": map[string]interface{}{"users": users},
|
||||
}), nil
|
||||
}))
|
||||
|
||||
missingIDs := make([]string, 25)
|
||||
for i := range missingIDs {
|
||||
missingIDs[i] = fmt.Sprintf("ou_%02d", i)
|
||||
}
|
||||
nameMap := map[string]string{}
|
||||
batchResolveByBasicContact(runtime, missingIDs, nameMap)
|
||||
|
||||
if want := []int{10, 10, 5}; !reflect.DeepEqual(batchSizes, want) {
|
||||
t.Fatalf("batch sizes = %v, want %v", batchSizes, want)
|
||||
}
|
||||
if len(nameMap) != 25 {
|
||||
t.Fatalf("resolved name count = %d, want 25", len(nameMap))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSenderNamesAPIFailure(t *testing.T) {
|
||||
runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
switch {
|
||||
|
||||
@@ -262,7 +262,7 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) {
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
target := filepath.Join("nested", "resource.bin")
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "file_123", "file", target)
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "file_123", "file", target, true)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -308,7 +308,7 @@ func TestDownloadIMResourceToPathImageUsesSingleRequestWithoutRange(t *testing.T
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
gotPath, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_img", "img_123", "image", "image")
|
||||
gotPath, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_img", "img_123", "image", "image", true)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -342,7 +342,7 @@ func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) {
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", "out.bin")
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", "out.bin", true)
|
||||
if err == nil || !strings.Contains(err.Error(), "HTTP 403: denied") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -372,7 +372,7 @@ func TestDownloadIMResourceToPathRetriesNetworkError(t *testing.T) {
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
target := "out.bin"
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_retry", "file_retry", "file", target)
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_retry", "file_retry", "file", target, true)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -408,7 +408,7 @@ func TestDownloadIMResourceToPathRetrySecondAttemptSuccess(t *testing.T) {
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
target := "out.bin"
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_retry2", "file_retry2", "file", target)
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_retry2", "file_retry2", "file", target, true)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -444,7 +444,7 @@ func TestDownloadIMResourceToPathRetryContextCanceled(t *testing.T) {
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
target := "out.bin"
|
||||
_, _, err := downloadIMResourceToPath(ctx, runtime, "om_cancel", "file_cancel", "file", target)
|
||||
_, _, err := downloadIMResourceToPath(ctx, runtime, "om_cancel", "file_cancel", "file", target, true)
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v, want context.Canceled", err)
|
||||
}
|
||||
@@ -525,7 +525,7 @@ func TestDownloadIMResourceToPathRangeDownload(t *testing.T) {
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
target := filepath.Join("nested", "resource.bin")
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_range", "file_range", "file", target)
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_range", "file_range", "file", target, true)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -567,7 +567,7 @@ func TestDownloadIMResourceToPathInvalidContentRange(t *testing.T) {
|
||||
}))
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_bad", "file_bad", "file", "out.bin")
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_bad", "file_bad", "file", "out.bin", true)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid Content-Range header") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -596,7 +596,7 @@ func TestDownloadIMResourceToPathRangeChunkFailureCleansOutput(t *testing.T) {
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
target := "out.bin"
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_miderr", "file_miderr", "file", target)
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_miderr", "file_miderr", "file", target, true)
|
||||
if err == nil || !strings.Contains(err.Error(), "HTTP 500: chunk failed") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -622,7 +622,7 @@ func TestDownloadIMResourceToPathRangeOverflowCleansOutput(t *testing.T) {
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
target := "out.bin"
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_overflow", "file_overflow", "file", target)
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_overflow", "file_overflow", "file", target, true)
|
||||
if err == nil || !strings.Contains(err.Error(), "chunk overflow") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -658,7 +658,7 @@ func TestDownloadIMResourceToPathRangeShortChunkSizeMismatch(t *testing.T) {
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_short", "file_short", "file", "out.bin")
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_short", "file_short", "file", "out.bin", true)
|
||||
if err == nil || !strings.Contains(err.Error(), "file size mismatch") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -593,7 +593,7 @@ func TestDownloadIMResourceToPathHTTPClientError(t *testing.T) {
|
||||
return nil, errors.New("http client unavailable")
|
||||
}))
|
||||
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "img_123", "image", "out.bin")
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "img_123", "image", "out.bin", true)
|
||||
if err == nil || !strings.Contains(err.Error(), "http client unavailable") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -637,6 +637,68 @@ func TestParseTotalSize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseContentDispositionFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{name: "empty header", header: "", want: ""},
|
||||
{name: "no filename param", header: "attachment", want: ""},
|
||||
{name: "plain filename", header: `attachment; filename="report.xlsx"`, want: "report.xlsx"},
|
||||
{name: "unquoted filename", header: `attachment; filename=report.xlsx`, want: "report.xlsx"},
|
||||
{name: "RFC 5987 UTF-8 encoded", header: `attachment; filename*=UTF-8''%E5%AD%A3%E5%BA%A6%E6%8A%A5%E5%91%8A.xlsx`, want: "季度报告.xlsx"},
|
||||
{name: "RFC 5987 takes priority over plain", header: `attachment; filename="fallback.xlsx"; filename*=UTF-8''%E5%AD%A3%E5%BA%A6%E6%8A%A5%E5%91%8A.xlsx`, want: "季度报告.xlsx"},
|
||||
{name: "path traversal stripped", header: `attachment; filename="../../etc/passwd"`, want: "passwd"},
|
||||
{name: "windows path stripped", header: `attachment; filename="C:\\Windows\\evil.exe"`, want: "evil.exe"},
|
||||
{name: "control char rejected", header: "attachment; filename=\"evil\x01file.txt\"", want: ""},
|
||||
{name: "malformed header", header: "not/valid/mime; ===", want: ""},
|
||||
{name: "whitespace trimmed", header: `attachment; filename=" report.pdf "`, want: "report.pdf"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseContentDispositionFilename(tt.header); got != tt.want {
|
||||
t.Fatalf("parseContentDispositionFilename(%q) = %q, want %q", tt.header, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveIMResourceDownloadPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
safePath string
|
||||
contentType string
|
||||
contentDisposition string
|
||||
userSpecifiedOutput bool
|
||||
want string
|
||||
}{
|
||||
// safePath already has extension: always return as-is
|
||||
{name: "user path with ext, no CD", safePath: "out.xlsx", contentType: "application/pdf", userSpecifiedOutput: true, want: "out.xlsx"},
|
||||
{name: "user path with ext, CD present", safePath: "out.xlsx", contentDisposition: `attachment; filename="server.pdf"`, userSpecifiedOutput: true, want: "out.xlsx"},
|
||||
// No --output: use CD filename when present
|
||||
{name: "default path, CD filename", safePath: "file_xxx", contentDisposition: `attachment; filename="季度报告.xlsx"`, want: "季度报告.xlsx"},
|
||||
{name: "default path, CD RFC5987", safePath: "file_xxx", contentDisposition: `attachment; filename*=UTF-8''%E5%AD%A3%E5%BA%A6%E6%8A%A5%E5%91%8A.xlsx`, want: "季度报告.xlsx"},
|
||||
{name: "default path, no CD, MIME ext", safePath: "file_xxx", contentType: "application/pdf", want: "file_xxx.pdf"},
|
||||
{name: "default path, no CD, unknown MIME", safePath: "file_xxx", contentType: "application/x-unknown", want: "file_xxx"},
|
||||
{name: "default path, CD with dir component", safePath: "downloads/file_xxx", contentDisposition: `attachment; filename="report.xlsx"`, want: "downloads/report.xlsx"},
|
||||
// User --output without extension: use CD filename's extension
|
||||
{name: "user path no ext, CD with ext", safePath: "myfile", contentDisposition: `attachment; filename="server.pdf"`, userSpecifiedOutput: true, want: "myfile.pdf"},
|
||||
{name: "user path no ext, CD no ext, MIME ext", safePath: "myfile", contentDisposition: `attachment; filename="noext"`, contentType: "image/png", userSpecifiedOutput: true, want: "myfile.png"},
|
||||
{name: "user path no ext, no CD, MIME ext", safePath: "myfile", contentType: "image/jpeg", userSpecifiedOutput: true, want: "myfile.jpg"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := resolveIMResourceDownloadPath(tt.safePath, tt.contentType, tt.contentDisposition, tt.userSpecifiedOutput)
|
||||
if got != tt.want {
|
||||
t.Fatalf("resolveIMResourceDownloadPath() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShortcuts(t *testing.T) {
|
||||
var commands []string
|
||||
for _, shortcut := range Shortcuts() {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -31,7 +32,7 @@ var ImMessagesResourcesDownload = common.Shortcut{
|
||||
{Name: "message-id", Desc: "message ID (om_xxx)", Required: true},
|
||||
{Name: "file-key", Desc: "resource key (img_xxx or file_xxx)", Required: true},
|
||||
{Name: "type", Desc: "resource type (image or file)", Required: true, Enum: []string{"image", "file"}},
|
||||
{Name: "output", Desc: "local save path (relative only, no .. traversal; defaults to file_key)"},
|
||||
{Name: "output", Desc: "local save path (relative only, no .. traversal); when omitted, uses the server's Content-Disposition filename if available, otherwise file_key; extension is inferred from Content-Disposition or Content-Type if not provided"},
|
||||
},
|
||||
DryRun: func(ctx context.Context, runtime *common.RuntimeContext) *common.DryRunAPI {
|
||||
fileKey := runtime.Str("file-key")
|
||||
@@ -72,7 +73,8 @@ var ImMessagesResourcesDownload = common.Shortcut{
|
||||
return output.ErrValidation("unsafe output path: %s", err)
|
||||
}
|
||||
|
||||
finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, relPath)
|
||||
userSpecifiedOutput := runtime.Str("output") != ""
|
||||
finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, relPath, userSpecifiedOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -262,18 +264,21 @@ func initialIMResourceDownloadHeaders(fileType string) map[string]string {
|
||||
}
|
||||
}
|
||||
|
||||
func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContext, messageID, fileKey, fileType, outputPath string) (string, int64, error) {
|
||||
func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContext, messageID, fileKey, fileType, outputPath string, userSpecifiedOutput bool) (string, int64, error) {
|
||||
downloadResp, err := doIMResourceDownloadRequest(ctx, runtime, messageID, fileKey, fileType, initialIMResourceDownloadHeaders(fileType))
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if downloadResp == nil {
|
||||
return "", 0, output.ErrNetwork("download failed: empty response")
|
||||
}
|
||||
|
||||
if downloadResp.StatusCode >= 400 {
|
||||
defer downloadResp.Body.Close()
|
||||
return "", 0, downloadResponseError(downloadResp)
|
||||
}
|
||||
|
||||
finalPath := resolveIMResourceDownloadPath(outputPath, downloadResp.Header.Get("Content-Type"))
|
||||
finalPath := resolveIMResourceDownloadPath(outputPath, downloadResp.Header.Get("Content-Type"), downloadResp.Header.Get("Content-Disposition"), userSpecifiedOutput)
|
||||
|
||||
var (
|
||||
body io.ReadCloser
|
||||
@@ -316,18 +321,63 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex
|
||||
return savedPath, result.Size(), nil
|
||||
}
|
||||
|
||||
func resolveIMResourceDownloadPath(safePath, contentType string) string {
|
||||
func resolveIMResourceDownloadPath(safePath, contentType, contentDisposition string, userSpecifiedOutput bool) string {
|
||||
if filepath.Ext(safePath) != "" {
|
||||
return safePath
|
||||
}
|
||||
mimeType := strings.Split(contentType, ";")[0]
|
||||
mimeType = strings.TrimSpace(mimeType)
|
||||
if cdFilename := parseContentDispositionFilename(contentDisposition); cdFilename != "" {
|
||||
if !userSpecifiedOutput {
|
||||
// No --output flag: use the original filename from the server.
|
||||
dir := filepath.Dir(safePath)
|
||||
if dir == "." {
|
||||
return cdFilename
|
||||
}
|
||||
return filepath.Join(dir, cdFilename)
|
||||
}
|
||||
// User specified a path without extension: append the extension from the CD filename.
|
||||
if ext := filepath.Ext(cdFilename); ext != "" {
|
||||
return safePath + ext
|
||||
}
|
||||
}
|
||||
mimeType := strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
if ext, ok := imMimeToExt[mimeType]; ok {
|
||||
return safePath + ext
|
||||
}
|
||||
return safePath
|
||||
}
|
||||
|
||||
// parseContentDispositionFilename extracts and sanitizes the filename from a
|
||||
// Content-Disposition header. It handles RFC 5987 encoded filenames (filename*)
|
||||
// with priority over plain filename via the standard mime package.
|
||||
// Returns an empty string if no valid filename can be extracted.
|
||||
func parseContentDispositionFilename(header string) string {
|
||||
if header == "" {
|
||||
return ""
|
||||
}
|
||||
_, params, err := mime.ParseMediaType(header)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
name := strings.TrimSpace(params["filename"])
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
// Strip any path component (Unix or Windows style) to prevent path traversal.
|
||||
if i := strings.LastIndexAny(name, "/\\"); i >= 0 {
|
||||
name = name[i+1:]
|
||||
}
|
||||
if name == "" || name == "." || name == ".." {
|
||||
return ""
|
||||
}
|
||||
// Reject control characters (including null bytes).
|
||||
for _, r := range name {
|
||||
if r < 0x20 || r == 0x7f {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func doIMResourceDownloadRequest(ctx context.Context, runtime *common.RuntimeContext, messageID, fileKey, fileType string, headers map[string]string) (*http.Response, error) {
|
||||
query := larkcore.QueryParams{}
|
||||
query.Set("type", fileType)
|
||||
|
||||
452
shortcuts/mail/draft/large_attachment_parse.go
Normal file
452
shortcuts/mail/draft/large_attachment_parse.go
Normal file
@@ -0,0 +1,452 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package draft
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
xhtml "golang.org/x/net/html"
|
||||
)
|
||||
|
||||
// largeAttHeaderEntry is a union of the CLI and server JSON formats for
|
||||
// entries in the large attachment header.
|
||||
type largeAttHeaderEntry struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
FileKey string `json:"file_key,omitempty"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
}
|
||||
|
||||
func (e largeAttHeaderEntry) token() string {
|
||||
if e.ID != "" {
|
||||
return e.ID
|
||||
}
|
||||
return e.FileKey
|
||||
}
|
||||
|
||||
// IsLargeAttachmentHeader returns true if the header name matches either
|
||||
// the CLI-written or server-returned large attachment header.
|
||||
func IsLargeAttachmentHeader(name string) bool {
|
||||
return strings.EqualFold(name, LargeAttachmentIDsHeader) ||
|
||||
strings.EqualFold(name, ServerLargeAttachmentHeader)
|
||||
}
|
||||
|
||||
// decodeLargeAttachmentHeader decodes the base64 value and returns entries.
|
||||
func decodeLargeAttachmentHeader(value string) ([]largeAttHeaderEntry, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var items []largeAttHeaderEntry
|
||||
if err := json.Unmarshal(decoded, &items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseLargeAttachmentTokens returns the ordered list of large attachment
|
||||
// tokens from either X-Lms-Large-Attachment-Ids (CLI format) or
|
||||
// X-Lark-Large-Attachment (server format). Returns nil when neither
|
||||
// header is present or the value is malformed.
|
||||
func parseLargeAttachmentTokens(headers []Header) []string {
|
||||
for _, h := range headers {
|
||||
if !IsLargeAttachmentHeader(h.Name) {
|
||||
continue
|
||||
}
|
||||
items, err := decodeLargeAttachmentHeader(h.Value)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(items))
|
||||
for _, it := range items {
|
||||
if tok := it.token(); tok != "" {
|
||||
out = append(out, tok)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseLargeAttachmentSummariesFromHeader extracts full metadata from the
|
||||
// large attachment header. Returns non-nil only when the server-format
|
||||
// header (X-Lark-Large-Attachment) is found, since it carries file_name
|
||||
// and file_size that the CLI-format header lacks.
|
||||
func ParseLargeAttachmentSummariesFromHeader(headers []Header) []LargeAttachmentSummary {
|
||||
for _, h := range headers {
|
||||
if !strings.EqualFold(h.Name, ServerLargeAttachmentHeader) {
|
||||
continue
|
||||
}
|
||||
items, err := decodeLargeAttachmentHeader(h.Value)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]LargeAttachmentSummary, 0, len(items))
|
||||
for _, it := range items {
|
||||
tok := it.token()
|
||||
if tok == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, LargeAttachmentSummary{
|
||||
Token: tok,
|
||||
FileName: it.FileName,
|
||||
SizeBytes: it.FileSize,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseLargeAttachmentItemsFromHTML walks the HTML body looking for large
|
||||
// attachment card items (<div id="large-file-item">) and returns a map
|
||||
// from token (data-mail-token attribute value) to filename + size.
|
||||
//
|
||||
// The size is parsed best-effort from the displayed string (e.g. "25.0 MB");
|
||||
// it carries the precision of the formatted value and is not byte-exact.
|
||||
func ParseLargeAttachmentItemsFromHTML(htmlBody string) map[string]LargeAttachmentSummary {
|
||||
out := map[string]LargeAttachmentSummary{}
|
||||
if htmlBody == "" {
|
||||
return out
|
||||
}
|
||||
doc, err := xhtml.Parse(strings.NewReader(htmlBody))
|
||||
if err != nil {
|
||||
return out
|
||||
}
|
||||
var walk func(n *xhtml.Node)
|
||||
walk = func(n *xhtml.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
if n.Type == xhtml.ElementNode && n.Data == "div" && attr(n, "id") == LargeFileItemID {
|
||||
if token, meta, ok := extractItemMeta(n); ok {
|
||||
out[token] = meta
|
||||
}
|
||||
// Do not descend further: the <a> and texts have been collected.
|
||||
return
|
||||
}
|
||||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||
walk(c)
|
||||
}
|
||||
}
|
||||
walk(doc)
|
||||
return out
|
||||
}
|
||||
|
||||
// extractItemMeta collects the token, filename, and size from a large
|
||||
// attachment item node. Returns ok=false when the token is missing.
|
||||
//
|
||||
// Expected structure (see largeAttItemTpl in mail/large_attachment.go):
|
||||
//
|
||||
// <div id="large-file-item">
|
||||
// <div><img ... /></div> // icon
|
||||
// <div>
|
||||
// <div>FILENAME</div>
|
||||
// <div><span>SIZE_DISPLAY</span></div>
|
||||
// </div>
|
||||
// <a data-mail-token="TOKEN" ...>DOWNLOAD_LABEL</a>
|
||||
// </div>
|
||||
//
|
||||
// The token comes from the <a data-mail-token=...>. The first non-anchor
|
||||
// text is the filename; the next text is the size display.
|
||||
func extractItemMeta(item *xhtml.Node) (token string, meta LargeAttachmentSummary, ok bool) {
|
||||
var texts []string
|
||||
var insideAnchor bool
|
||||
|
||||
var walk func(n *xhtml.Node)
|
||||
walk = func(n *xhtml.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
if n.Type == xhtml.ElementNode && n.Data == "a" {
|
||||
if t := attr(n, LargeAttachmentTokenAttr); t != "" && token == "" {
|
||||
token = t
|
||||
}
|
||||
// Skip collecting the anchor's label (e.g. "Download" / "下载").
|
||||
prev := insideAnchor
|
||||
insideAnchor = true
|
||||
defer func() { insideAnchor = prev }()
|
||||
}
|
||||
if n.Type == xhtml.TextNode && !insideAnchor {
|
||||
if s := strings.TrimSpace(n.Data); s != "" {
|
||||
texts = append(texts, s)
|
||||
}
|
||||
}
|
||||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||
walk(c)
|
||||
}
|
||||
}
|
||||
walk(item)
|
||||
|
||||
if token == "" {
|
||||
return "", LargeAttachmentSummary{}, false
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
meta.FileName = texts[0]
|
||||
}
|
||||
if len(texts) > 1 {
|
||||
meta.SizeBytes = parseSizeDisplay(texts[1])
|
||||
}
|
||||
return token, meta, true
|
||||
}
|
||||
|
||||
func attr(n *xhtml.Node, name string) string {
|
||||
for _, a := range n.Attr {
|
||||
if a.Key == name {
|
||||
return a.Val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sizeDisplayRe matches sizes like "25.0 MB", "1 GB", "500 KB", "42 B".
|
||||
// The unit is case-insensitive and may be B / KB / MB / GB / TB.
|
||||
var sizeDisplayRe = regexp.MustCompile(`(?i)^\s*([0-9]+(?:\.[0-9]+)?)\s*(B|KB|MB|GB|TB)\s*$`)
|
||||
|
||||
// parseSizeDisplay converts a formatted size display string back into
|
||||
// an approximate byte count. Precision is limited by the display rounding
|
||||
// (e.g. "25.0 MB" round-trips to 26214400 bytes).
|
||||
// Returns 0 when the input cannot be parsed.
|
||||
func parseSizeDisplay(s string) int64 {
|
||||
m := sizeDisplayRe.FindStringSubmatch(s)
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
value, err := strconv.ParseFloat(m[1], 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
unit := strings.ToUpper(m[2])
|
||||
var mul int64
|
||||
switch unit {
|
||||
case "B":
|
||||
mul = 1
|
||||
case "KB":
|
||||
mul = 1024
|
||||
case "MB":
|
||||
mul = 1024 * 1024
|
||||
case "GB":
|
||||
mul = 1024 * 1024 * 1024
|
||||
case "TB":
|
||||
mul = 1024 * 1024 * 1024 * 1024
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
return int64(value * float64(mul))
|
||||
}
|
||||
|
||||
// removeLargeAttachment removes a large attachment by its file token.
|
||||
// It updates both representations:
|
||||
//
|
||||
// 1. X-Lms-Large-Attachment-Ids header: removes the token from the JSON
|
||||
// ID list. If the list becomes empty, the header itself is removed.
|
||||
// 2. HTML body: removes the <div id="large-file-item"> whose <a> has the
|
||||
// matching data-mail-token attribute. If the enclosing container
|
||||
// <div id="large-file-area-*"> has no remaining items, the whole
|
||||
// container is removed.
|
||||
func removeLargeAttachment(snapshot *DraftSnapshot, token string) error {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return fmt.Errorf("remove_attachment: token is empty")
|
||||
}
|
||||
if err := removeTokenFromIDsHeader(snapshot, token); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := removeTokenFromHTMLBody(snapshot, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeTokenFromIDsHeader removes the given token from whichever large
|
||||
// attachment header is present (CLI or server format). Returns an error
|
||||
// if no header is found or the token is not listed. After removal, the
|
||||
// header is re-encoded in CLI format (X-Lms-Large-Attachment-Ids) so
|
||||
// the server can process the update on upload.
|
||||
func removeTokenFromIDsHeader(snapshot *DraftSnapshot, token string) error {
|
||||
headerIdx := -1
|
||||
for i, h := range snapshot.Headers {
|
||||
if IsLargeAttachmentHeader(h.Name) {
|
||||
headerIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if headerIdx < 0 {
|
||||
return fmt.Errorf("remove_attachment: draft has no large attachment header")
|
||||
}
|
||||
items, err := decodeLargeAttachmentHeader(snapshot.Headers[headerIdx].Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove_attachment: malformed large attachment header: %w", err)
|
||||
}
|
||||
filtered := make([]largeAttHeaderEntry, 0, len(items))
|
||||
removed := false
|
||||
for _, it := range items {
|
||||
if it.token() == token {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, it)
|
||||
}
|
||||
if !removed {
|
||||
return fmt.Errorf("remove_attachment: token %q not found in large attachment header", token)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
snapshot.Headers = append(snapshot.Headers[:headerIdx], snapshot.Headers[headerIdx+1:]...)
|
||||
return nil
|
||||
}
|
||||
cliItems := make([]struct {
|
||||
ID string `json:"id"`
|
||||
}, len(filtered))
|
||||
for i, it := range filtered {
|
||||
cliItems[i].ID = it.token()
|
||||
}
|
||||
encoded, err := json.Marshal(cliItems)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove_attachment: failed to re-encode large attachment header: %w", err)
|
||||
}
|
||||
snapshot.Headers[headerIdx].Name = LargeAttachmentIDsHeader
|
||||
snapshot.Headers[headerIdx].Value = base64.StdEncoding.EncodeToString(encoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeTokenFromHTMLBody walks the HTML body, removes the single
|
||||
// large-file-item whose anchor has data-mail-token == token, and if the
|
||||
// enclosing container becomes empty (no more large-file-item children),
|
||||
// removes the whole container.
|
||||
//
|
||||
// It is not an error if the HTML body or item is missing — the header
|
||||
// removal is still considered the authoritative operation. This handles
|
||||
// cases where the HTML was already edited out but the header wasn't.
|
||||
func removeTokenFromHTMLBody(snapshot *DraftSnapshot, token string) error {
|
||||
htmlPart := FindHTMLBodyPart(snapshot.Body)
|
||||
if htmlPart == nil || len(htmlPart.Body) == 0 {
|
||||
return nil
|
||||
}
|
||||
body := string(htmlPart.Body)
|
||||
newBody, changed := RemoveLargeFileItemFromHTML(body, token)
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
htmlPart.Body = []byte(newBody)
|
||||
htmlPart.Dirty = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveLargeFileItemFromHTML parses the HTML, finds the large-file-item
|
||||
// containing an <a> whose token matches (via data-mail-token attribute or
|
||||
// href URL token= parameter), removes that item, and if the enclosing
|
||||
// large-file-area container becomes empty, removes the container as well.
|
||||
// Returns the updated HTML and a changed flag.
|
||||
func RemoveLargeFileItemFromHTML(htmlBody, token string) (string, bool) {
|
||||
doc, err := xhtml.Parse(strings.NewReader(htmlBody))
|
||||
if err != nil {
|
||||
return htmlBody, false
|
||||
}
|
||||
item := findLargeFileItemByToken(doc, token)
|
||||
if item == nil {
|
||||
return htmlBody, false
|
||||
}
|
||||
container := item.Parent
|
||||
// Detach the item from its parent.
|
||||
if container != nil {
|
||||
container.RemoveChild(item)
|
||||
}
|
||||
// If the container is a large-file-area and has no remaining
|
||||
// large-file-item children, remove the whole container.
|
||||
if container != nil && isLargeFileAreaContainer(container) && !hasLargeFileItemChild(container) {
|
||||
if grand := container.Parent; grand != nil {
|
||||
grand.RemoveChild(container)
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := xhtml.Render(&buf, doc); err != nil {
|
||||
return htmlBody, false
|
||||
}
|
||||
return stripHTMLEnvelope(buf.String()), true
|
||||
}
|
||||
|
||||
func findLargeFileItemByToken(n *xhtml.Node, token string) *xhtml.Node {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
if n.Type == xhtml.ElementNode && n.Data == "div" && attr(n, "id") == LargeFileItemID {
|
||||
if itemContainsToken(n, token) {
|
||||
return n
|
||||
}
|
||||
}
|
||||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||
if found := findLargeFileItemByToken(c, token); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func itemContainsToken(item *xhtml.Node, token string) bool {
|
||||
if item == nil {
|
||||
return false
|
||||
}
|
||||
for c := item.FirstChild; c != nil; c = c.NextSibling {
|
||||
if c.Type == xhtml.ElementNode && c.Data == "a" {
|
||||
if attr(c, LargeAttachmentTokenAttr) == token {
|
||||
return true
|
||||
}
|
||||
if hrefContainsToken(attr(c, "href"), token) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if itemContainsToken(c, token) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hrefContainsToken(href, token string) bool {
|
||||
if href == "" || token == "" {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(href)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return u.Query().Get("token") == token
|
||||
}
|
||||
|
||||
func isLargeFileAreaContainer(n *xhtml.Node) bool {
|
||||
if n == nil || n.Type != xhtml.ElementNode || n.Data != "div" {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(attr(n, "id"), LargeFileContainerIDPrefix)
|
||||
}
|
||||
|
||||
func hasLargeFileItemChild(n *xhtml.Node) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||
if c.Type == xhtml.ElementNode && c.Data == "div" && attr(c, "id") == LargeFileItemID {
|
||||
return true
|
||||
}
|
||||
if hasLargeFileItemChild(c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripHTMLEnvelope removes the <html><head></head><body>...</body></html>
|
||||
// wrapper that xhtml.Parse + xhtml.Render adds around HTML fragments.
|
||||
func stripHTMLEnvelope(s string) string {
|
||||
s = strings.TrimPrefix(s, "<html><head></head><body>")
|
||||
s = strings.TrimSuffix(s, "</body></html>")
|
||||
return s
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user