From 4a4c3344c8846c0d70a4ee81f0e4173be9cd868e Mon Sep 17 00:00:00 2001 From: "guokexin.02" <264159873+Tantanz20020918@users.noreply.github.com> Date: Wed, 17 Jun 2026 17:41:48 +0800 Subject: [PATCH] fix: align api success envelopes (#1489) --- cmd/api/api.go | 64 ++-- cmd/api/api_test.go | 312 ++++++++++++++++++- cmd/service/service.go | 49 ++- cmd/service/service_test.go | 373 ++++++++++++++++++++++- errs/subtypes.go | 1 + internal/client/client.go | 23 +- internal/client/client_test.go | 58 +++- internal/client/pagination.go | 28 -- internal/client/response.go | 30 +- internal/client/response_test.go | 57 +++- internal/output/emit.go | 18 +- internal/output/emit_test.go | 17 +- internal/output/envelope_success.go | 58 ++++ internal/output/envelope_success_test.go | 173 +++++++++++ shortcuts/common/runner.go | 5 +- 15 files changed, 1163 insertions(+), 103 deletions(-) create mode 100644 internal/output/envelope_success.go create mode 100644 internal/output/envelope_success_test.go diff --git a/cmd/api/api.go b/cmd/api/api.go index 072117d7..58b4b5d1 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -233,7 +233,7 @@ func apiRun(opts *APIOptions) error { } if opts.PageAll { - return apiPaginate(opts.Ctx, ac, request, format, opts.JqExpr, out, f.IOStreams.ErrOut, + return apiPaginate(opts.Ctx, ac, request, format, opts.JqExpr, out, f.IOStreams.ErrOut, opts.Cmd.CommandPath(), client.PaginationOptions{PageLimit: opts.PageLimit, PageDelay: opts.PageDelay}) } @@ -272,24 +272,13 @@ func apiDryRun(f *cmdutil.Factory, request client.RawApiRequest, config *core.Cl return cmdutil.PrintDryRun(f.IOStreams.Out, request, config, format) } -func apiPaginate(ctx context.Context, ac *client.APIClient, request client.RawApiRequest, format output.Format, jqExpr string, out, errOut io.Writer, pagOpts client.PaginationOptions) error { +func apiPaginate(ctx context.Context, ac *client.APIClient, request client.RawApiRequest, format output.Format, jqExpr string, out, errOut io.Writer, commandPath string, pagOpts client.PaginationOptions) error { if pagOpts.Identity == "" { pagOpts.Identity = request.As } // When jq is set, always aggregate all pages then filter. if jqExpr != "" { - if err := client.PaginateWithJq(ctx, ac, request, jqExpr, out, pagOpts, ac.CheckResponse); err != nil { - return output.MarkRaw(err) - } - return nil - } - - switch format { - case output.FormatNDJSON, output.FormatTable, output.FormatCSV: - pf := output.NewPaginatedFormatter(out, format) - result, hasItems, err := ac.StreamPages(ctx, request, func(items []interface{}) { - pf.FormatPage(items) - }, pagOpts) + result, err := ac.PaginateAll(ctx, request, pagOpts) if err != nil { return output.MarkRaw(err) } @@ -297,9 +286,46 @@ func apiPaginate(ctx context.Context, ac *client.APIClient, request client.RawAp output.FormatValue(out, result, output.FormatJSON) return output.MarkRaw(apiErr) } + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: commandPath, + Identity: string(pagOpts.Identity), + JqExpr: jqExpr, + Out: out, + ErrOut: errOut, + }) + } + + switch format { + case output.FormatNDJSON, output.FormatTable, output.FormatCSV: + pf := output.NewPaginatedFormatter(out, format) + result, hasItems, err := ac.StreamPages(ctx, request, func(items []interface{}) error { + // Streaming formats intentionally emit each page after that page has + // passed safety scanning. A later page may still fail, so callers + // must use the exit code to distinguish complete vs partial output. + scanResult := output.ScanForSafety(commandPath, items, errOut) + if scanResult.Blocked { + return scanResult.BlockErr + } + if scanResult.Alert != nil { + output.WriteAlertWarning(errOut, scanResult.Alert) + } + pf.FormatPage(items) + return nil + }, pagOpts) + if err != nil { + return output.MarkRaw(err) + } + if apiErr := ac.CheckResponse(result, pagOpts.Identity); apiErr != nil { + return output.MarkRaw(apiErr) + } if !hasItems { fmt.Fprintf(errOut, "warning: this API does not return a list, format %q is not supported, falling back to json\n", format) - output.FormatValue(out, result, output.FormatJSON) + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: commandPath, + Identity: string(pagOpts.Identity), + Out: out, + ErrOut: errOut, + }) } return nil default: @@ -311,7 +337,11 @@ func apiPaginate(ctx context.Context, ac *client.APIClient, request client.RawAp output.FormatValue(out, result, output.FormatJSON) return output.MarkRaw(apiErr) } - output.FormatValue(out, result, format) - return nil + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: commandPath, + Identity: string(pagOpts.Identity), + Out: out, + ErrOut: errOut, + }) } } diff --git a/cmd/api/api_test.go b/cmd/api/api_test.go index 393e2542..fa91d891 100644 --- a/cmd/api/api_test.go +++ b/cmd/api/api_test.go @@ -4,6 +4,8 @@ package api import ( + "context" + "encoding/json" "errors" "os" "sort" @@ -11,6 +13,7 @@ import ( "testing" "github.com/larksuite/cli/errs" + extcs "github.com/larksuite/cli/extension/contentsafety" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/httpmock" @@ -101,8 +104,19 @@ func TestApiCmd_BotMode(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !strings.Contains(stdout.String(), "success") { - t.Error("expected 'success' in output") + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + if got["ok"] != true || got["identity"] != "bot" { + t.Fatalf("unexpected envelope: %#v", got) + } + if _, hasCode := got["code"]; hasCode { + t.Fatalf("success envelope leaked outer code: %s", stdout.String()) + } + data, ok := got["data"].(map[string]interface{}) + if !ok || data["result"] != "success" { + t.Fatalf("data = %#v, want result=success", got["data"]) } } @@ -328,8 +342,16 @@ func TestApiCmd_PageAll_NonBatchAPI_FallbackToJSON(t *testing.T) { t.Error("expected 'falling back to json' in stderr") } // Should output JSON result to stdout - if !strings.Contains(stdout.String(), "u123") { - t.Error("expected user_id in JSON output") + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + data, ok := got["data"].(map[string]interface{}) + if got["ok"] != true || got["identity"] != "bot" || !ok || data["user_id"] != "u123" { + t.Fatalf("unexpected fallback envelope: %#v", got) + } + if _, hasCode := got["code"]; hasCode { + t.Fatalf("fallback success envelope leaked outer code: %s", stdout.String()) } } @@ -342,7 +364,7 @@ func TestApiCmd_PageAll_NonBatchAPI_ErrorStillOutputsJSON(t *testing.T) { reg.Register(&httpmock.Stub{ URL: "/open-apis/im/v1/chats/oc_xxx/announcement", Body: map[string]interface{}{ - "code": 230001, "msg": "no permission", + "code": 230027, "msg": "user not authorized", }, }) @@ -354,12 +376,20 @@ func TestApiCmd_PageAll_NonBatchAPI_ErrorStillOutputsJSON(t *testing.T) { t.Fatal("expected an error for non-zero code") } // Should still output the response body so user can see the error details - if !strings.Contains(stdout.String(), "230001") { + if !strings.Contains(stdout.String(), "230027") { t.Errorf("expected error response in stdout, got: %s", stdout.String()) } - if !strings.Contains(stdout.String(), "no permission") { + if !strings.Contains(stdout.String(), "user not authorized") { t.Errorf("expected error message in stdout, got: %s", stdout.String()) } + if strings.Contains(stdout.String(), `"ok": true`) || strings.Contains(stdout.String(), `"ok":true`) { + t.Fatalf("unexpected success envelope on error path: %s", stdout.String()) + } + requireProblem(t, err, errs.CategoryAuthorization, errs.SubtypeUserUnauthorized, 230027) + var permErr *errs.PermissionError + if !errors.As(err, &permErr) { + t.Fatalf("expected PermissionError, got %T: %v", err, err) + } } func TestApiCmd_PageAll_BatchAPI_StreamsItems(t *testing.T) { @@ -395,6 +425,274 @@ func TestApiCmd_PageAll_BatchAPI_StreamsItems(t *testing.T) { } } +func TestApiCmd_PageAll_StreamBusinessErrorDoesNotDumpJSON(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-pageall-stream-err", AppSecret: "test-secret-pageall-stream-err", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "safe-page"}}, + "has_more": true, + "page_token": "next", + }, + }, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 230027, "msg": "user not authorized", + }, + }) + + cmd := NewCmdApi(f, nil) + cmd.SetArgs([]string{"GET", "/open-apis/contact/v3/users", "--as", "bot", "--page-all", "--format", "ndjson"}) + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for non-zero code on later page") + } + requireProblem(t, err, errs.CategoryAuthorization, errs.SubtypeUserUnauthorized, 230027) + out := stdout.String() + if !strings.Contains(out, "safe-page") { + t.Fatalf("expected earlier successful page to remain streamed, got: %s", out) + } + if strings.Contains(out, "230027") || strings.Contains(out, "user not authorized") { + t.Fatalf("streaming stdout should not contain raw error JSON, got: %s", out) + } + if strings.Contains(out, "\n \"code\"") { + t.Fatalf("streaming stdout should not contain indented JSON error dump, got: %s", out) + } +} + +func TestApiCmd_PageAll_BatchAPI_DefaultJSONEnvelope(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-pageall-json", AppSecret: "test-secret-pageall-json", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": false, + }, + }, + }) + + cmd := NewCmdApi(f, nil) + cmd.SetArgs([]string{"GET", "/open-apis/contact/v3/users", "--as", "bot", "--page-all"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + data, ok := got["data"].(map[string]interface{}) + if got["ok"] != true || got["identity"] != "bot" || !ok { + t.Fatalf("unexpected envelope: %#v", got) + } + if _, hasCode := got["code"]; hasCode { + t.Fatalf("success envelope leaked outer code: %s", stdout.String()) + } + items, ok := data["items"].([]interface{}) + if !ok || len(items) != 1 { + t.Fatalf("data.items = %#v, want one item", data["items"]) + } +} + +type apiContentSafetyProvider struct { + called bool + path string + data interface{} + match string +} + +func (p *apiContentSafetyProvider) Name() string { return "api-test" } + +func (p *apiContentSafetyProvider) Scan(_ context.Context, req extcs.ScanRequest) (*extcs.Alert, error) { + p.called = true + p.path = req.Path + p.data = req.Data + if p.match != "" { + b, _ := json.Marshal(req.Data) + if !strings.Contains(string(b), p.match) { + return nil, nil + } + } + return &extcs.Alert{Provider: "api-test", MatchedRules: []string{"pagination"}}, nil +} + +func TestApiCmd_PageAll_DefaultJSONRunsContentSafety(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + provider := &apiContentSafetyProvider{} + extcs.Register(provider) + t.Cleanup(func() { extcs.Register(nil) }) + + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-pageall-safety", AppSecret: "test-secret-pageall-safety", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": false, + }, + }, + }) + + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand(NewCmdApi(f, nil)) + root.SetArgs([]string{"api", "GET", "/open-apis/contact/v3/users", "--as", "bot", "--page-all"}) + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !provider.called { + t.Fatal("expected content safety provider to scan paginated output") + } + if provider.path != "api" { + t.Fatalf("scan path = %q, want api", provider.path) + } + data, ok := provider.data.(map[string]interface{}) + if !ok { + t.Fatalf("scanned data type = %T, want map", provider.data) + } + if _, hasCode := data["code"]; hasCode { + t.Fatalf("scanned data should be business data only, got %#v", data) + } + + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + alert, ok := got["_content_safety_alert"].(map[string]interface{}) + if !ok || alert["provider"] != "api-test" { + t.Fatalf("missing content safety alert in envelope: %#v", got) + } +} + +func TestApiCmd_PageAll_StreamFormatRunsContentSafety(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + provider := &apiContentSafetyProvider{} + extcs.Register(provider) + t.Cleanup(func() { extcs.Register(nil) }) + + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-pageall-stream-safety", AppSecret: "test-secret-pageall-stream-safety", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": false, + }, + }, + }) + + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand(NewCmdApi(f, nil)) + root.SetArgs([]string{"api", "GET", "/open-apis/contact/v3/users", "--as", "bot", "--page-all", "--format", "ndjson"}) + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !provider.called { + t.Fatal("expected content safety provider to scan streamed paginated output") + } + if provider.path != "api" { + t.Fatalf("scan path = %q, want api", provider.path) + } + items, ok := provider.data.([]interface{}) + if !ok || len(items) != 1 { + t.Fatalf("scanned data = %#v, want one streamed item", provider.data) + } + if !strings.Contains(stderr.String(), "warning: content safety alert from api-test") { + t.Fatalf("expected content safety warning on stderr, got: %s", stderr.String()) + } + if !strings.Contains(stdout.String(), `"id":"1"`) { + t.Fatalf("expected streamed ndjson output, got: %s", stdout.String()) + } +} + +func TestApiCmd_PageAll_StreamFormatBlockSkipsBlockedPage(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block") + provider := &apiContentSafetyProvider{match: "blocked"} + extcs.Register(provider) + t.Cleanup(func() { extcs.Register(nil) }) + + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-pageall-stream-block", AppSecret: "test-secret-pageall-stream-block", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "safe-page"}}, + "has_more": true, + "page_token": "next", + }, + }, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/contact/v3/users", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "blocked-page"}}, + "has_more": false, + }, + }, + }) + + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand(NewCmdApi(f, nil)) + root.SetArgs([]string{"api", "GET", "/open-apis/contact/v3/users", "--as", "bot", "--page-all", "--format", "ndjson"}) + err := root.Execute() + if err == nil { + t.Fatal("expected content safety block error") + } + var safetyErr *errs.ContentSafetyError + if !errors.As(err, &safetyErr) { + t.Fatalf("expected ContentSafetyError, got %T: %v", err, err) + } + if safetyErr.Category != errs.CategoryPolicy || safetyErr.Subtype != errs.SubtypeContentSafety { + t.Fatalf("problem = %s/%s, want %s/%s", safetyErr.Category, safetyErr.Subtype, errs.CategoryPolicy, errs.SubtypeContentSafety) + } + if len(safetyErr.Rules) != 1 || safetyErr.Rules[0] != "pagination" { + t.Fatalf("rules = %v, want [pagination]", safetyErr.Rules) + } + out := stdout.String() + if !strings.Contains(out, "safe-page") { + t.Fatalf("expected earlier safe page to remain streamed, got: %s", out) + } + if strings.Contains(out, "blocked-page") { + t.Fatalf("blocked page was written before safety block: %s", out) + } +} + +func requireProblem(t *testing.T, err error, category errs.Category, subtype errs.Subtype, code int) { + t.Helper() + p, ok := errs.ProblemOf(err) + if !ok { + t.Fatalf("expected typed error, got %T: %v", err, err) + } + if p.Category != category || p.Subtype != subtype || p.Code != code { + t.Fatalf("problem = %s/%s/%d, want %s/%s/%d", p.Category, p.Subtype, p.Code, category, subtype, code) + } +} + func TestNormalisePath_StripsQueryAndFragment(t *testing.T) { for _, tt := range []struct { name string diff --git a/cmd/service/service.go b/cmd/service/service.go index ccec5d97..e6344dd7 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -387,7 +387,7 @@ func serviceMethodRun(opts *ServiceMethodOptions) error { checkErr := ac.CheckResponse if opts.PageAll { - return servicePaginate(opts.Ctx, ac, request, format, opts.JqExpr, out, f.IOStreams.ErrOut, + return servicePaginate(opts.Ctx, ac, request, format, opts.JqExpr, out, f.IOStreams.ErrOut, opts.Cmd.CommandPath(), client.PaginationOptions{PageLimit: opts.PageLimit, PageDelay: opts.PageDelay}, checkErr) } @@ -627,20 +627,45 @@ func serviceDryRun(f *cmdutil.Factory, request client.RawApiRequest, config *cor return cmdutil.PrintDryRun(f.IOStreams.Out, request, config, format) } -func servicePaginate(ctx context.Context, ac *client.APIClient, request client.RawApiRequest, format output.Format, jqExpr string, out, errOut io.Writer, pagOpts client.PaginationOptions, checkErr func(interface{}, core.Identity) error) error { +func servicePaginate(ctx context.Context, ac *client.APIClient, request client.RawApiRequest, format output.Format, jqExpr string, out, errOut io.Writer, commandPath string, pagOpts client.PaginationOptions, checkErr func(interface{}, core.Identity) error) error { if pagOpts.Identity == "" { pagOpts.Identity = request.As } // When jq is set, always aggregate all pages then filter. if jqExpr != "" { - return client.PaginateWithJq(ctx, ac, request, jqExpr, out, pagOpts, checkErr) + result, err := ac.PaginateAll(ctx, request, pagOpts) + if err != nil { + return err + } + if apiErr := checkErr(result, pagOpts.Identity); apiErr != nil { + output.FormatValue(out, result, output.FormatJSON) + return apiErr + } + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: commandPath, + Identity: string(pagOpts.Identity), + JqExpr: jqExpr, + Out: out, + ErrOut: errOut, + }) } switch format { case output.FormatNDJSON, output.FormatTable, output.FormatCSV: pf := output.NewPaginatedFormatter(out, format) - result, hasItems, err := ac.StreamPages(ctx, request, func(items []interface{}) { + result, hasItems, err := ac.StreamPages(ctx, request, func(items []interface{}) error { + // Streaming formats intentionally emit each page after that page has + // passed safety scanning. A later page may still fail, so callers + // must use the exit code to distinguish complete vs partial output. + scanResult := output.ScanForSafety(commandPath, items, errOut) + if scanResult.Blocked { + return scanResult.BlockErr + } + if scanResult.Alert != nil { + output.WriteAlertWarning(errOut, scanResult.Alert) + } pf.FormatPage(items) + return nil }, pagOpts) if err != nil { return err @@ -650,7 +675,12 @@ func servicePaginate(ctx context.Context, ac *client.APIClient, request client.R } if !hasItems { fmt.Fprintf(errOut, "warning: this API does not return a list, format %q is not supported, falling back to json\n", format) - output.FormatValue(out, result, output.FormatJSON) + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: commandPath, + Identity: string(pagOpts.Identity), + Out: out, + ErrOut: errOut, + }) } return nil default: @@ -659,9 +689,14 @@ func servicePaginate(ctx context.Context, ac *client.APIClient, request client.R return err } if apiErr := checkErr(result, pagOpts.Identity); apiErr != nil { + output.FormatValue(out, result, output.FormatJSON) return apiErr } - output.FormatValue(out, result, format) - return nil + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: commandPath, + Identity: string(pagOpts.Identity), + Out: out, + ErrOut: errOut, + }) } } diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go index f1f48ea5..70d5658d 100644 --- a/cmd/service/service_test.go +++ b/cmd/service/service_test.go @@ -4,10 +4,15 @@ package service import ( + "context" + "encoding/json" + "errors" "os" "strings" "testing" + "github.com/larksuite/cli/errs" + extcs "github.com/larksuite/cli/extension/contentsafety" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/httpmock" @@ -407,8 +412,19 @@ func TestServiceMethod_BotMode_Success(t *testing.T) { if err := cmd.Execute(); err != nil { t.Fatalf("unexpected error: %v", err) } - if !strings.Contains(stdout.String(), "success") { - t.Errorf("expected 'success' in output, got:\n%s", stdout.String()) + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + if got["ok"] != true || got["identity"] != "bot" { + t.Fatalf("unexpected envelope: %#v", got) + } + if _, hasCode := got["code"]; hasCode { + t.Fatalf("success envelope leaked outer code: %s", stdout.String()) + } + data, ok := got["data"].(map[string]interface{}) + if !ok || data["result"] != "success" { + t.Fatalf("data = %#v, want result=success", got["data"]) } } @@ -436,8 +452,312 @@ func TestServiceMethod_BotMode_PageAll_JSON(t *testing.T) { if err := cmd.Execute(); err != nil { t.Fatalf("unexpected error: %v", err) } - if !strings.Contains(stdout.String(), `"id"`) { - t.Errorf("expected items in output, got:\n%s", stdout.String()) + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + data, ok := got["data"].(map[string]interface{}) + if got["ok"] != true || got["identity"] != "bot" || !ok { + t.Fatalf("unexpected envelope: %#v", got) + } + if _, hasCode := got["code"]; hasCode { + t.Fatalf("success envelope leaked outer code: %s", stdout.String()) + } + items, ok := data["items"].([]interface{}) + if !ok || len(items) != 1 { + t.Fatalf("data.items = %#v, want one item", data["items"]) + } +} + +type serviceContentSafetyProvider struct { + called bool + path string + data interface{} + match string +} + +func (p *serviceContentSafetyProvider) Name() string { return "service-test" } + +func (p *serviceContentSafetyProvider) Scan(_ context.Context, req extcs.ScanRequest) (*extcs.Alert, error) { + p.called = true + p.path = req.Path + p.data = req.Data + if p.match != "" { + b, _ := json.Marshal(req.Data) + if !strings.Contains(string(b), p.match) { + return nil, nil + } + } + return &extcs.Alert{Provider: "service-test", MatchedRules: []string{"pagination"}}, nil +} + +func TestServiceMethod_PageAll_DefaultJSONRunsContentSafety(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + provider := &serviceContentSafetyProvider{} + extcs.Register(provider) + t.Cleanup(func() { extcs.Register(nil) }) + + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-service-safety", AppSecret: "test-secret-service-safety", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": false, + }, + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand(NewCmdServiceMethod(f, spec, method, "list", "items", nil)) + root.SetArgs([]string{"list", "--as", "bot", "--page-all"}) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !provider.called { + t.Fatal("expected content safety provider to scan paginated output") + } + if provider.path != "list" { + t.Fatalf("scan path = %q, want list", provider.path) + } + data, ok := provider.data.(map[string]interface{}) + if !ok { + t.Fatalf("scanned data type = %T, want map", provider.data) + } + if _, hasCode := data["code"]; hasCode { + t.Fatalf("scanned data should be business data only, got %#v", data) + } + + var got map[string]interface{} + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, stdout.String()) + } + alert, ok := got["_content_safety_alert"].(map[string]interface{}) + if !ok || alert["provider"] != "service-test" { + t.Fatalf("missing content safety alert in envelope: %#v", got) + } +} + +func TestServiceMethod_PageAll_StreamFormatRunsContentSafety(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + provider := &serviceContentSafetyProvider{} + extcs.Register(provider) + t.Cleanup(func() { extcs.Register(nil) }) + + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-service-stream-safety", AppSecret: "test-secret-service-stream-safety", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": false, + }, + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand(NewCmdServiceMethod(f, spec, method, "list", "items", nil)) + root.SetArgs([]string{"list", "--as", "bot", "--page-all", "--format", "ndjson"}) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !provider.called { + t.Fatal("expected content safety provider to scan streamed paginated output") + } + if provider.path != "list" { + t.Fatalf("scan path = %q, want list", provider.path) + } + items, ok := provider.data.([]interface{}) + if !ok || len(items) != 1 { + t.Fatalf("scanned data = %#v, want one streamed item", provider.data) + } + if !strings.Contains(stderr.String(), "warning: content safety alert from service-test") { + t.Fatalf("expected content safety warning on stderr, got: %s", stderr.String()) + } + if !strings.Contains(stdout.String(), `"id":"1"`) { + t.Fatalf("expected streamed ndjson output, got: %s", stdout.String()) + } +} + +func TestServiceMethod_PageAll_StreamFormatBlockSkipsBlockedPage(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block") + provider := &serviceContentSafetyProvider{match: "blocked"} + extcs.Register(provider) + t.Cleanup(func() { extcs.Register(nil) }) + + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-service-stream-block", AppSecret: "test-secret-service-stream-block", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "safe-page"}}, + "has_more": true, + "page_token": "next", + }, + }, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "blocked-page"}}, + "has_more": false, + }, + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + root := &cobra.Command{Use: "lark-cli"} + root.AddCommand(NewCmdServiceMethod(f, spec, method, "list", "items", nil)) + root.SetArgs([]string{"list", "--as", "bot", "--page-all", "--format", "ndjson"}) + + err := root.Execute() + if err == nil { + t.Fatal("expected content safety block error") + } + var safetyErr *errs.ContentSafetyError + if !errors.As(err, &safetyErr) { + t.Fatalf("expected ContentSafetyError, got %T: %v", err, err) + } + if safetyErr.Category != errs.CategoryPolicy || safetyErr.Subtype != errs.SubtypeContentSafety { + t.Fatalf("problem = %s/%s, want %s/%s", safetyErr.Category, safetyErr.Subtype, errs.CategoryPolicy, errs.SubtypeContentSafety) + } + if len(safetyErr.Rules) != 1 || safetyErr.Rules[0] != "pagination" { + t.Fatalf("rules = %v, want [pagination]", safetyErr.Rules) + } + out := stdout.String() + if !strings.Contains(out, "safe-page") { + t.Fatalf("expected earlier safe page to remain streamed, got: %s", out) + } + if strings.Contains(out, "blocked-page") { + t.Fatalf("blocked page was written before safety block: %s", out) + } +} + +func TestServiceMethod_BusinessErrorReturnsTypedErrorWithoutSuccessEnvelope(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-service-err", AppSecret: "test-secret-service-err", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 230027, "msg": "user not authorized", + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + cmd := NewCmdServiceMethod(f, spec, method, "list", "items", nil) + cmd.SetArgs([]string{"--as", "bot"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for non-zero code") + } + requireProblem(t, err, errs.CategoryAuthorization, errs.SubtypeUserUnauthorized, 230027) + var permErr *errs.PermissionError + if !errors.As(err, &permErr) { + t.Fatalf("expected PermissionError, got %T: %v", err, err) + } + if strings.Contains(stdout.String(), `"ok": true`) || strings.Contains(stdout.String(), `"ok":true`) { + t.Fatalf("unexpected success envelope on error path: %s", stdout.String()) + } +} + +func TestServiceMethod_PageAll_DefaultBusinessErrorOutputsRawResponse(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-service-pageall-err", AppSecret: "test-secret-service-pageall-err", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 230027, "msg": "user not authorized", + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + cmd := NewCmdServiceMethod(f, spec, method, "list", "items", nil) + cmd.SetArgs([]string{"--as", "bot", "--page-all"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for non-zero code") + } + requireProblem(t, err, errs.CategoryAuthorization, errs.SubtypeUserUnauthorized, 230027) + if !strings.Contains(stdout.String(), "230027") || !strings.Contains(stdout.String(), "user not authorized") { + t.Fatalf("expected raw error response on stdout, got: %s", stdout.String()) + } + if strings.Contains(stdout.String(), `"ok": true`) || strings.Contains(stdout.String(), `"ok":true`) { + t.Fatalf("unexpected success envelope on error path: %s", stdout.String()) + } +} + +func TestServiceMethod_PageAll_StreamBusinessErrorDoesNotDumpJSON(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-service-pageall-stream-err", AppSecret: "test-secret-service-pageall-stream-err", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "safe-page"}}, + "has_more": true, + "page_token": "next", + }, + }, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 230027, + "msg": "user not authorized", + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + cmd := NewCmdServiceMethod(f, spec, method, "list", "items", nil) + cmd.SetArgs([]string{"--as", "bot", "--page-all", "--format", "ndjson"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for non-zero code") + } + requireProblem(t, err, errs.CategoryAuthorization, errs.SubtypeUserUnauthorized, 230027) + out := stdout.String() + if !strings.Contains(out, "safe-page") { + t.Fatalf("expected earlier successful page to remain streamed, got: %s", out) + } + if strings.Contains(out, "230027") || strings.Contains(out, "user not authorized") { + t.Fatalf("streaming stdout should not contain raw error JSON, got: %s", out) + } + if strings.Contains(out, "\n \"code\"") { + t.Fatalf("streaming stdout should not contain indented JSON error dump, got: %s", out) } } @@ -629,6 +949,51 @@ func TestServiceMethod_PageAll_WithJq(t *testing.T) { } } +func TestServiceMethod_PageAll_WithJqBusinessErrorOutputsRawResponse(t *testing.T) { + f, stdout, _, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app-spjq-err", AppSecret: "test-secret-spjq-err", Brand: core.BrandFeishu, + }) + + reg.Register(&httpmock.Stub{ + URL: "/open-apis/svc/v1/items", + Body: map[string]interface{}{ + "code": 230027, "msg": "user not authorized", + }, + }) + + spec := meta.ServiceFromMap(map[string]interface{}{"name": "svc", "servicePath": "/open-apis/svc/v1"}) + method := meta.FromMap(map[string]interface{}{"path": "items", "httpMethod": "GET", "parameters": map[string]interface{}{}}) + cmd := NewCmdServiceMethod(f, spec, method, "list", "items", nil) + cmd.SetArgs([]string{"--as", "bot", "--page-all", "--jq", ".data.items[].id"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error for non-zero code") + } + requireProblem(t, err, errs.CategoryAuthorization, errs.SubtypeUserUnauthorized, 230027) + var permErr *errs.PermissionError + if !errors.As(err, &permErr) { + t.Fatalf("expected PermissionError, got %T: %v", err, err) + } + if !strings.Contains(stdout.String(), "230027") || !strings.Contains(stdout.String(), "user not authorized") { + t.Fatalf("expected raw error response on stdout, got: %s", stdout.String()) + } + if strings.Contains(stdout.String(), `"ok": true`) || strings.Contains(stdout.String(), `"ok":true`) { + t.Fatalf("unexpected success envelope on error path: %s", stdout.String()) + } +} + +func requireProblem(t *testing.T, err error, category errs.Category, subtype errs.Subtype, code int) { + t.Helper() + p, ok := errs.ProblemOf(err) + if !ok { + t.Fatalf("expected typed error, got %T: %v", err, err) + } + if p.Category != category || p.Subtype != subtype || p.Code != code { + t.Fatalf("problem = %s/%s/%d, want %s/%s/%d", p.Category, p.Subtype, p.Code, category, subtype, code) + } +} + // ── file upload ── func imImageMethod() meta.Method { diff --git a/errs/subtypes.go b/errs/subtypes.go index 913170d6..df0cf8ab 100644 --- a/errs/subtypes.go +++ b/errs/subtypes.go @@ -73,6 +73,7 @@ const ( const ( SubtypeChallengeRequired Subtype = "challenge_required" // user must complete browser challenge / MFA SubtypeAccessDenied Subtype = "access_denied" // policy denies access outright + SubtypeContentSafety Subtype = "content_safety" // content-safety scanner blocked output in block mode ) // CategoryInternal subtypes diff --git a/internal/client/client.go b/internal/client/client.go index dc4a0e89..be3acbc4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -350,7 +350,7 @@ func (c *APIClient) CallAPI(ctx context.Context, request RawApiRequest) (interfa // paginateLoop runs the core pagination loop. For each successful page (code == 0), // it calls onResult if non-nil. It always accumulates and returns all raw page results. -func (c *APIClient) paginateLoop(ctx context.Context, request RawApiRequest, opts PaginationOptions, onResult func(interface{})) ([]interface{}, error) { +func (c *APIClient) paginateLoop(ctx context.Context, request RawApiRequest, opts PaginationOptions, onResult func(interface{}) error) ([]interface{}, error) { var allResults []interface{} var pageToken string page := 0 @@ -399,7 +399,9 @@ func (c *APIClient) paginateLoop(ctx context.Context, request RawApiRequest, opt } if onResult != nil { - onResult(result) + if err := onResult(result); err != nil { + return allResults, err + } } allResults = append(allResults, result) @@ -452,28 +454,31 @@ func (c *APIClient) PaginateAll(ctx context.Context, request RawApiRequest, opts // StreamPages fetches all pages and streams each page's list items via onItems. // Returns the last page result (for error checking), whether any list items were found, // and any network error. Use this for streaming formats (ndjson, table, csv). -func (c *APIClient) StreamPages(ctx context.Context, request RawApiRequest, onItems func([]interface{}), opts PaginationOptions) (result interface{}, hasItems bool, err error) { +func (c *APIClient) StreamPages(ctx context.Context, request RawApiRequest, onItems func([]interface{}) error, opts PaginationOptions) (result interface{}, hasItems bool, err error) { totalItems := 0 - results, loopErr := c.paginateLoop(ctx, request, opts, func(r interface{}) { + results, loopErr := c.paginateLoop(ctx, request, opts, func(r interface{}) error { resultMap, ok := r.(map[string]interface{}) if !ok { - return + return nil } data, ok := resultMap["data"].(map[string]interface{}) if !ok { - return + return nil } arrayField := output.FindArrayField(data) if arrayField == "" { - return + return nil } items, ok := data[arrayField].([]interface{}) if !ok { - return + return nil } totalItems += len(items) - onItems(items) + if err := onItems(items); err != nil { + return err + } hasItems = true + return nil }) if loopErr != nil { return nil, false, loopErr diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 8cf38d95..cdc1d81c 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -124,8 +124,9 @@ func TestStreamPages_NonBatchAPI_NoArrayField(t *testing.T) { Method: "GET", URL: "/open-apis/contact/v3/users/u123", As: "bot", - }, func(items []interface{}) { + }, func(items []interface{}) error { t.Error("onItems should not be called for non-batch API") + return nil }, PaginationOptions{}) if err != nil { @@ -168,8 +169,9 @@ func TestStreamPages_BatchAPI_WithArrayField(t *testing.T) { Method: "GET", URL: "/open-apis/contact/v3/users", As: "bot", - }, func(items []interface{}) { + }, func(items []interface{}) error { streamedItems = append(streamedItems, items...) + return nil }, PaginationOptions{}) if err != nil { @@ -189,6 +191,58 @@ func TestStreamPages_BatchAPI_WithArrayField(t *testing.T) { } } +func TestStreamPages_OnItemsErrorStopsPagination(t *testing.T) { + apiCalls := 0 + rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { + apiCalls++ + if apiCalls == 1 { + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": true, + "page_token": "next", + }, + }), nil + } + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "2"}}, + "has_more": false, + }, + }), nil + }) + + ac, _ := newTestAPIClient(t, rt) + sentinel := errors.New("stop streaming") + var streamedItems []interface{} + result, hasItems, err := ac.StreamPages(context.Background(), RawApiRequest{ + Method: "GET", + URL: "/open-apis/contact/v3/users", + As: "bot", + }, func(items []interface{}) error { + streamedItems = append(streamedItems, items...) + return sentinel + }, PaginationOptions{PageDelay: 0}) + + if !errors.Is(err, sentinel) { + t.Fatalf("err = %v, want sentinel", err) + } + if result != nil { + t.Fatalf("result = %#v, want nil when callback stops pagination", result) + } + if hasItems { + t.Fatal("hasItems = true, want false when callback stops before returning") + } + if apiCalls != 1 { + t.Fatalf("apiCalls = %d, want early stop after first page", apiCalls) + } + if len(streamedItems) != 1 { + t.Fatalf("streamedItems = %d, want first page only", len(streamedItems)) + } +} + func TestPaginateAll_PageLimitStopsPagination(t *testing.T) { apiCalls := 0 rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { diff --git a/internal/client/pagination.go b/internal/client/pagination.go index 66b064b4..91dc067e 100644 --- a/internal/client/pagination.go +++ b/internal/client/pagination.go @@ -4,7 +4,6 @@ package client import ( - "context" "fmt" "io" @@ -19,33 +18,6 @@ type PaginationOptions struct { Identity core.Identity // identity passed to checkErr; defaults to AsUser when empty } -// PaginateWithJq aggregates all pages, checks for API errors, then applies a jq filter. -// If checkErr detects an error, the raw result is printed as JSON before returning the error. -func PaginateWithJq(ctx context.Context, ac *APIClient, request RawApiRequest, - jqExpr string, out io.Writer, pagOpts PaginationOptions, - checkErr func(interface{}, core.Identity) error) error { - result, err := ac.PaginateAll(ctx, request, pagOpts) - if err != nil { - return err - } - // Identity resolution honors pagOpts.Identity first, then the request's - // own identity, and only falls back to AsUser when neither caller - // supplied one. Without checking request.As, bot/auto requests would - // always be classified as user identity for checkErr. - identity := pagOpts.Identity - if identity == "" { - identity = request.As - } - if identity == "" || identity == core.AsAuto { - identity = core.AsUser - } - if apiErr := checkErr(result, identity); apiErr != nil { - output.FormatValue(out, result, output.FormatJSON) - return apiErr - } - return output.JqFilter(out, result, jqExpr) -} - func mergePagedResults(w io.Writer, results []interface{}) interface{} { if len(results) == 0 { return map[string]interface{}{} diff --git a/internal/client/response.go b/internal/client/response.go index fbd88f7c..742ad148 100644 --- a/internal/client/response.go +++ b/internal/client/response.go @@ -89,23 +89,37 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error { if apiErr := check(result, identity); apiErr != nil { return apiErr } - // Content safety scanning - scanResult := output.ScanForSafety(opts.CommandPath, result, opts.ErrOut) - if scanResult.Blocked { - return scanResult.BlockErr - } if opts.OutputPath != "" { + // File downloads keep the existing raw-response scan path because the + // saved payload is the API response body, not the success envelope. + scanResult := output.ScanForSafety(opts.CommandPath, result, opts.ErrOut) + if scanResult.Blocked { + return scanResult.BlockErr + } if scanResult.Alert != nil { output.WriteAlertWarning(opts.ErrOut, scanResult.Alert) } return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out) } + + if opts.JqExpr != "" || opts.Format == output.FormatJSON { + return output.WriteSuccessEnvelope(output.SuccessEnvelopeData(result), output.SuccessEnvelopeOptions{ + CommandPath: opts.CommandPath, + Identity: string(identity), + JqExpr: opts.JqExpr, + Out: opts.Out, + ErrOut: opts.ErrOut, + }) + } + + // Content safety scanning for non-JSON presentation formats. + scanResult := output.ScanForSafety(opts.CommandPath, result, opts.ErrOut) + if scanResult.Blocked { + return scanResult.BlockErr + } if scanResult.Alert != nil { output.WriteAlertWarning(opts.ErrOut, scanResult.Alert) } - if opts.JqExpr != "" { - return output.JqFilter(opts.Out, result, opts.JqExpr) - } output.FormatValue(opts.Out, result, opts.Format) return nil } diff --git a/internal/client/response_test.go b/internal/client/response_test.go index 0902e555..ac249833 100644 --- a/internal/client/response_test.go +++ b/internal/client/response_test.go @@ -5,6 +5,7 @@ package client import ( "bytes" + "encoding/json" "errors" "io" "net/http" @@ -16,6 +17,7 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/vfs/localfileio" ) @@ -207,15 +209,54 @@ func TestHandleResponse_JSON(t *testing.T) { var out bytes.Buffer var errOut bytes.Buffer err := HandleResponse(resp, ResponseOptions{ - Out: &out, - ErrOut: &errOut, - FileIO: &localfileio.LocalFileIO{}, + Identity: core.AsBot, + Out: &out, + ErrOut: &errOut, + FileIO: &localfileio.LocalFileIO{}, }) if err != nil { t.Fatalf("HandleResponse failed: %v", err) } - if !bytes.Contains(out.Bytes(), []byte(`"code"`)) { - t.Errorf("expected JSON output, got: %s", out.String()) + var got map[string]interface{} + if err := json.Unmarshal(out.Bytes(), &got); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, out.String()) + } + if got["ok"] != true { + t.Fatalf("ok = %v, want true; output: %s", got["ok"], out.String()) + } + if got["identity"] != "bot" { + t.Fatalf("identity = %v, want bot; output: %s", got["identity"], out.String()) + } + if _, hasCode := got["code"]; hasCode { + t.Fatalf("success envelope leaked outer code field: %s", out.String()) + } + data, ok := got["data"].(map[string]interface{}) + if !ok { + t.Fatalf("data = %T, want object; output: %s", got["data"], out.String()) + } + if data["id"] != "1" { + t.Fatalf("data.id = %v, want 1; output: %s", data["id"], out.String()) + } +} + +func TestHandleResponse_JSONWithJqUsesSuccessEnvelope(t *testing.T) { + body := []byte(`{"code":0,"msg":"ok","data":{"id":"1"}}`) + resp := newApiResp(body, map[string]string{"Content-Type": "application/json"}) + + var out bytes.Buffer + var errOut bytes.Buffer + err := HandleResponse(resp, ResponseOptions{ + Identity: core.AsBot, + JqExpr: ".data.id", + Out: &out, + ErrOut: &errOut, + FileIO: &localfileio.LocalFileIO{}, + }) + if err != nil { + t.Fatalf("HandleResponse failed: %v", err) + } + if strings.TrimSpace(out.String()) != "1" { + t.Fatalf("jq output = %q, want %q", out.String(), "1") } } @@ -233,6 +274,12 @@ func TestHandleResponse_JSONWithError(t *testing.T) { if err == nil { t.Error("expected error for non-zero code") } + if _, ok := errs.ProblemOf(err); !ok { + t.Fatalf("expected typed error, got %T: %v", err, err) + } + if strings.Contains(out.String(), `"ok": true`) || strings.Contains(out.String(), `"ok":true`) { + t.Fatalf("unexpected success envelope on error path: %s", out.String()) + } } func TestHandleResponse_BinaryAutoSave(t *testing.T) { diff --git a/internal/output/emit.go b/internal/output/emit.go index dfc4598b..80206ebe 100644 --- a/internal/output/emit.go +++ b/internal/output/emit.go @@ -9,6 +9,7 @@ import ( "io" "strings" + "github.com/larksuite/cli/errs" extcs "github.com/larksuite/cli/extension/contentsafety" ) @@ -35,19 +36,16 @@ func ScanForSafety(cmdPath string, data any, errOut io.Writer) ScanResult { return ScanResult{Alert: alert} } -// wrapBlockError creates an ExitError for content-safety block. +// wrapBlockError creates a typed error for content-safety block. func wrapBlockError(alert *extcs.Alert) error { - rules := "" + var matchedRules []string if alert != nil { - rules = strings.Join(alert.MatchedRules, ", ") - } - return &ExitError{ - Code: ExitContentSafety, - Detail: &ErrDetail{ - Type: "content_safety_blocked", - Message: fmt.Sprintf("content safety violation detected (rules: %s)", rules), - }, + matchedRules = alert.MatchedRules } + return errs.NewContentSafetyError(errs.SubtypeContentSafety, + "content safety violation detected (rules: %s)", strings.Join(matchedRules, ", ")). + WithRules(matchedRules...). + WithCause(errBlocked) } // WriteAlertWarning writes a human-readable content-safety warning to w. diff --git a/internal/output/emit_test.go b/internal/output/emit_test.go index a25c1e62..b81a7fad 100644 --- a/internal/output/emit_test.go +++ b/internal/output/emit_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/larksuite/cli/errs" extcs "github.com/larksuite/cli/extension/contentsafety" ) @@ -72,12 +73,18 @@ func TestScanForSafety_ModeBlock_WithAlert(t *testing.T) { if result.BlockErr == nil { t.Error("block mode with alert should have BlockErr") } - var exitErr *ExitError - if !errors.As(result.BlockErr, &exitErr) { - t.Fatalf("BlockErr should be *ExitError, got %T", result.BlockErr) + var safetyErr *errs.ContentSafetyError + if !errors.As(result.BlockErr, &safetyErr) { + t.Fatalf("BlockErr should be *ContentSafetyError, got %T", result.BlockErr) } - if exitErr.Code != ExitContentSafety { - t.Errorf("exit code = %d, want %d", exitErr.Code, ExitContentSafety) + if safetyErr.Category != errs.CategoryPolicy || safetyErr.Subtype != errs.SubtypeContentSafety { + t.Errorf("problem = %s/%s, want %s/%s", safetyErr.Category, safetyErr.Subtype, errs.CategoryPolicy, errs.SubtypeContentSafety) + } + if len(safetyErr.Rules) != 1 || safetyErr.Rules[0] != "r1" { + t.Errorf("rules = %v, want [r1]", safetyErr.Rules) + } + if !errors.Is(result.BlockErr, errBlocked) { + t.Error("BlockErr should preserve errBlocked cause") } } diff --git a/internal/output/envelope_success.go b/internal/output/envelope_success.go new file mode 100644 index 00000000..08649c33 --- /dev/null +++ b/internal/output/envelope_success.go @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package output + +import "io" + +// SuccessEnvelopeOptions configures the shortcut-compatible success envelope. +type SuccessEnvelopeOptions struct { + CommandPath string + Identity string + JqExpr string + Out io.Writer + ErrOut io.Writer +} + +// SuccessEnvelopeData extracts the business payload for the standard success +// envelope from a Lark API response. Outer code/msg fields are transport +// protocol details and are intentionally not exposed as business data. +func SuccessEnvelopeData(result interface{}) interface{} { + m, ok := result.(map[string]interface{}) + if !ok { + return map[string]interface{}{} + } + data, ok := m["data"] + if !ok || data == nil { + return map[string]interface{}{} + } + return data +} + +// WriteSuccessEnvelope emits the standard success envelope used by shortcuts. +// JSON output carries content-safety alerts inside the envelope. When jq is +// applied, the alert may be filtered away, so warn mode also writes stderr. +func WriteSuccessEnvelope(data interface{}, opts SuccessEnvelopeOptions) error { + scanResult := ScanForSafety(opts.CommandPath, data, opts.ErrOut) + if scanResult.Blocked { + return scanResult.BlockErr + } + + env := Envelope{ + OK: true, + Identity: opts.Identity, + Data: data, + Notice: GetNotice(), + } + if scanResult.Alert != nil { + env.ContentSafetyAlert = scanResult.Alert + } + if opts.JqExpr != "" { + if scanResult.Alert != nil && opts.ErrOut != nil { + WriteAlertWarning(opts.ErrOut, scanResult.Alert) + } + return JqFilter(opts.Out, env, opts.JqExpr) + } + PrintJson(opts.Out, env) + return nil +} diff --git a/internal/output/envelope_success_test.go b/internal/output/envelope_success_test.go new file mode 100644 index 00000000..dac7c17f --- /dev/null +++ b/internal/output/envelope_success_test.go @@ -0,0 +1,173 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package output + +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/larksuite/cli/errs" + extcs "github.com/larksuite/cli/extension/contentsafety" +) + +func TestSuccessEnvelopeData_ExtractsBusinessData(t *testing.T) { + result := map[string]interface{}{ + "code": float64(0), + "msg": "ok", + "data": map[string]interface{}{"id": "1"}, + } + + got := SuccessEnvelopeData(result) + m, ok := got.(map[string]interface{}) + if !ok { + t.Fatalf("business data type = %T, want map", got) + } + if m["id"] != "1" { + t.Fatalf("id = %v, want 1", m["id"]) + } + if _, ok := m["code"]; ok { + t.Fatal("business data must not contain outer code") + } +} + +func TestSuccessEnvelopeData_MissingDataUsesEmptyObject(t *testing.T) { + got := SuccessEnvelopeData(map[string]interface{}{"code": float64(0), "msg": "ok"}) + m, ok := got.(map[string]interface{}) + if !ok { + t.Fatalf("business data type = %T, want map", got) + } + if len(m) != 0 { + t.Fatalf("business data = %#v, want empty object", m) + } +} + +func TestSuccessEnvelopeData_NilDataUsesEmptyObject(t *testing.T) { + got := SuccessEnvelopeData(map[string]interface{}{"code": float64(0), "msg": "ok", "data": nil}) + m, ok := got.(map[string]interface{}) + if !ok { + t.Fatalf("business data type = %T, want map", got) + } + if len(m) != 0 { + t.Fatalf("business data = %#v, want empty object", m) + } +} + +func TestWriteSuccessEnvelope_PrintsShortcutCompatibleEnvelope(t *testing.T) { + var out strings.Builder + + err := WriteSuccessEnvelope(map[string]interface{}{"id": "1"}, SuccessEnvelopeOptions{ + Identity: "bot", + Out: &out, + }) + if err != nil { + t.Fatalf("WriteSuccessEnvelope() error = %v", err) + } + + var env map[string]interface{} + if err := json.Unmarshal([]byte(out.String()), &env); err != nil { + t.Fatalf("invalid JSON output: %v\n%s", err, out.String()) + } + if env["ok"] != true || env["identity"] != "bot" { + t.Fatalf("unexpected envelope: %#v", env) + } + data, ok := env["data"].(map[string]interface{}) + if !ok || data["id"] != "1" { + t.Fatalf("unexpected data payload: %#v", env["data"]) + } + if _, ok := env["code"]; ok { + t.Fatalf("output leaked protocol field code: %#v", env) + } + if _, ok := env["msg"]; ok { + t.Fatalf("output leaked protocol field msg: %#v", env) + } + if _, ok := env["_content_safety_alert"]; ok { + t.Fatalf("output should omit empty content-safety alert: %#v", env) + } +} + +func TestWriteSuccessEnvelope_JqUsesEnvelope(t *testing.T) { + var out strings.Builder + + err := WriteSuccessEnvelope(map[string]interface{}{"id": "1"}, SuccessEnvelopeOptions{ + Identity: "bot", + JqExpr: ".data.id", + Out: &out, + }) + if err != nil { + t.Fatalf("WriteSuccessEnvelope() error = %v", err) + } + if strings.TrimSpace(out.String()) != "1" { + t.Fatalf("jq output = %q, want %q", out.String(), "1") + } +} + +func TestWriteSuccessEnvelope_JqWarnsWhenSafetyAlertFiltered(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "warn") + extcs.Register(&mockProvider{ + name: "mock", + alert: &extcs.Alert{Provider: "mock", MatchedRules: []string{"r1"}}, + }) + t.Cleanup(func() { extcs.Register(nil) }) + + var out strings.Builder + var errOut strings.Builder + err := WriteSuccessEnvelope(map[string]interface{}{"id": "1"}, SuccessEnvelopeOptions{ + CommandPath: "lark-cli im +test", + Identity: "bot", + JqExpr: ".data.id", + Out: &out, + ErrOut: &errOut, + }) + if err != nil { + t.Fatalf("WriteSuccessEnvelope() error = %v", err) + } + if strings.TrimSpace(out.String()) != "1" { + t.Fatalf("jq output = %q, want %q", out.String(), "1") + } + if !strings.Contains(errOut.String(), "warning: content safety alert from mock") { + t.Fatalf("expected content safety warning on stderr, got: %s", errOut.String()) + } + if !strings.Contains(errOut.String(), "r1") { + t.Fatalf("expected rule in stderr warning, got: %s", errOut.String()) + } +} + +func TestWriteSuccessEnvelope_BlockModeReturnsTypedErrorWithoutStdout(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONTENT_SAFETY_MODE", "block") + extcs.Register(&mockProvider{ + name: "mock", + alert: &extcs.Alert{Provider: "mock", MatchedRules: []string{"r1"}}, + }) + t.Cleanup(func() { extcs.Register(nil) }) + + var out strings.Builder + var errOut strings.Builder + err := WriteSuccessEnvelope(map[string]interface{}{"id": "1"}, SuccessEnvelopeOptions{ + CommandPath: "lark-cli im +test", + Identity: "bot", + Out: &out, + ErrOut: &errOut, + }) + if err == nil { + t.Fatal("expected content safety block error") + } + var safetyErr *errs.ContentSafetyError + if !errors.As(err, &safetyErr) { + t.Fatalf("expected ContentSafetyError, got %T: %v", err, err) + } + if safetyErr.Category != errs.CategoryPolicy || safetyErr.Subtype != errs.SubtypeContentSafety { + t.Fatalf("problem = %s/%s, want %s/%s", safetyErr.Category, safetyErr.Subtype, errs.CategoryPolicy, errs.SubtypeContentSafety) + } + if len(safetyErr.Rules) != 1 || safetyErr.Rules[0] != "r1" { + t.Fatalf("rules = %v, want [r1]", safetyErr.Rules) + } + if !errors.Is(err, errBlocked) { + t.Fatal("content safety error should preserve errBlocked cause") + } + if out.String() != "" { + t.Fatalf("stdout should stay empty on block, got: %s", out.String()) + } +} diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index ae3cbd19..bfecaa19 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -413,7 +413,10 @@ func (ctx *RuntimeContext) StreamPages(method, url string, params map[string]int return nil, false, err } req := ctx.buildRequest(method, url, params, data) - return ac.StreamPages(ctx.ctx, req, onItems, opts) + return ac.StreamPages(ctx.ctx, req, func(items []interface{}) error { + onItems(items) + return nil + }, opts) } func (ctx *RuntimeContext) buildRequest(method, url string, params map[string]interface{}, data interface{}) client.RawApiRequest {