From b3fcf556119160a4894b6212f02887ed6661eff9 Mon Sep 17 00:00:00 2001 From: evandance <120630830+evandance@users.noreply.github.com> Date: Wed, 3 Jun 2026 19:20:19 +0800 Subject: [PATCH] feat(common): emit typed validation errors from shared shortcut pre-checks (#1242) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Input pre-check failures shared by every shortcut — @file/stdin input resolution, enum validation, and unsupported --dry-run — now leave the CLI as typed validation envelopes naming the offending flag, so scripts and AI agents can branch on `param` instead of parsing prose. Wire type, exit code, and message text are unchanged; the new fields are additive. The shared layer also gains typed replacements for its legacy error-producing helpers, so each business domain can migrate to typed errors without rebuilding common plumbing, and a path-scoped lint guard keeps migrated domains from sliding back. Changes: - Shared pre-check failures (input flags, enum values, dry-run support) return typed validation errors carrying the offending flag as `param`. - Every legacy error-producing helper in shortcuts/common has a typed replacement that preserves the existing message text: validation and flag-group checks, chat/user ID validation (callers name the flag so `param` is ground truth), "me" open-id resolution, safe-path checks, input-stat and save-error wrapping. Legacy helpers stay for not-yet-migrated domains, marked deprecated — including the legacy API-result classifier, whose typed route is runtime.CallAPITyped. - A new errscontract rule rejects legacy common-helper calls on migrated paths, so a migrated domain cannot silently reintroduce legacy envelopes; drive is the first locked path and its last legacy ID-helper calls are replaced. --- .../rule_no_legacy_common_helper_call.go | 138 +++++++++++++++++ lint/errscontract/rules_test.go | 120 +++++++++++++++ lint/errscontract/scan.go | 1 + shortcuts/common/common.go | 3 + shortcuts/common/runner.go | 72 +++++++-- shortcuts/common/runner_input_test.go | 5 + shortcuts/common/runner_validation_test.go | 22 +++ shortcuts/common/userids.go | 30 +++- shortcuts/common/userids_test.go | 21 +++ shortcuts/common/validate.go | 115 ++++++++++++++ shortcuts/common/validate_ids.go | 54 ++++++- shortcuts/common/validate_test.go | 142 ++++++++++++++++++ shortcuts/drive/drive_search.go | 6 +- 13 files changed, 708 insertions(+), 21 deletions(-) create mode 100644 lint/errscontract/rule_no_legacy_common_helper_call.go create mode 100644 shortcuts/common/runner_validation_test.go diff --git a/lint/errscontract/rule_no_legacy_common_helper_call.go b/lint/errscontract/rule_no_legacy_common_helper_call.go new file mode 100644 index 00000000..06886cde --- /dev/null +++ b/lint/errscontract/rule_no_legacy_common_helper_call.go @@ -0,0 +1,138 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package errscontract + +import ( + "go/ast" + "go/parser" + "go/token" + "strings" +) + +// migratedCommonHelperPaths lists source-tree prefixes whose command validation +// has migrated to typed errs.* envelopes. On these paths, calls to common's +// legacy validation/save helpers are forbidden; callers must use the typed +// common replacements or construct an errs.* typed error directly. +var migratedCommonHelperPaths = []string{ + "shortcuts/drive/", +} + +const commonImportPath = "github.com/larksuite/cli/shortcuts/common" + +var legacyCommonHelperReplacements = map[string]string{ + "FlagErrorf": "common.ValidationErrorf", + "MutuallyExclusive": "common.MutuallyExclusiveTyped", + "AtLeastOne": "common.AtLeastOneTyped", + "ExactlyOne": "common.ExactlyOneTyped", + "ValidatePageSize": "common.ValidatePageSizeTyped", + "ValidateChatID": "common.ValidateChatIDTyped", + "ValidateUserID": "common.ValidateUserIDTyped", + "ValidateSafePath": "common.ValidateSafePathTyped", + "RejectDangerousChars": "common.RejectDangerousCharsTyped", + "WrapInputStatError": "common.WrapInputStatErrorTyped", + "WrapSaveErrorByCategory": "common.WrapSaveErrorTyped", + "ResolveOpenIDs": "common.ResolveOpenIDsTyped", + "HandleApiResult": "runtime.CallAPITyped", +} + +// CheckNoLegacyCommonHelperCall flags any reference to common's legacy helper +// APIs on migrated paths — direct calls and function-value references alike, +// so `f := common.FlagErrorf; f(...)` cannot slip past the guard. These +// helpers return legacy output envelopes or bare errors, so migrated domains +// should use their typed-aware replacements. +func CheckNoLegacyCommonHelperCall(path, src string) []Violation { + if !isMigratedCommonHelperPath(path) || strings.HasSuffix(path, "_test.go") { + return nil + } + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, path, src, parser.ParseComments) + if err != nil { + return nil + } + localNames, dotImported := resolveCommonNames(file) + var out []Violation + report := func(pos token.Pos, name, replacement string) { + out = append(out, Violation{ + Rule: "no_legacy_common_helper_call", + Action: ActionReject, + File: path, + Line: fset.Position(pos).Line, + Message: "common." + name + " returns a legacy error shape and is forbidden on migrated paths", + Suggestion: "replace common." + name + " with " + replacement + " or a typed errs.* constructor", + }) + } + // Pass 1: qualified references (common.X / alias.X). Record every + // selector field so the dot-import pass below never mistakes another + // package's same-named field for a common helper. + selFields := make(map[*ast.Ident]struct{}) + ast.Inspect(file, func(n ast.Node) bool { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return true + } + selFields[sel.Sel] = struct{}{} + x, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + if _, bound := localNames[x.Name]; !bound { + return true + } + if replacement, ok := legacyCommonHelperReplacements[sel.Sel.Name]; ok { + report(sel.Pos(), sel.Sel.Name, replacement) + } + return true + }) + // Pass 2: unqualified references under a dot import. + if dotImported { + ast.Inspect(file, func(n ast.Node) bool { + ident, ok := n.(*ast.Ident) + if !ok { + return true + } + if _, isField := selFields[ident]; isField { + return true + } + if replacement, ok := legacyCommonHelperReplacements[ident.Name]; ok { + report(ident.Pos(), ident.Name, replacement) + } + return true + }) + } + return out +} + +func isMigratedCommonHelperPath(path string) bool { + p := strings.ReplaceAll(path, "\\", "/") + for _, prefix := range migratedCommonHelperPaths { + if strings.HasPrefix(p, prefix) || strings.Contains(p, "/"+prefix) { + return true + } + } + return false +} + +func resolveCommonNames(file *ast.File) (map[string]struct{}, bool) { + names := make(map[string]struct{}) + dotImported := false + for _, imp := range file.Imports { + if imp.Path == nil { + continue + } + p := strings.Trim(imp.Path.Value, "`\"") + if p != commonImportPath { + continue + } + switch { + case imp.Name == nil: + names["common"] = struct{}{} + case imp.Name.Name == ".": + dotImported = true + case imp.Name.Name == "_": + default: + names[imp.Name.Name] = struct{}{} + } + } + return names, dotImported +} diff --git a/lint/errscontract/rules_test.go b/lint/errscontract/rules_test.go index 9beed919..1579e38d 100644 --- a/lint/errscontract/rules_test.go +++ b/lint/errscontract/rules_test.go @@ -877,3 +877,123 @@ func boom(runtime *common.RuntimeContext) error { t.Errorf("test files must be skipped, got: %+v", v) } } + +func TestCheckNoLegacyCommonHelperCall_RejectsLegacyHelpersOnMigratedPath(t *testing.T) { + helpers := []string{ + "FlagErrorf", + "MutuallyExclusive", + "AtLeastOne", + "ExactlyOne", + "ValidatePageSize", + "ValidateChatID", + "ValidateUserID", + "ValidateSafePath", + "RejectDangerousChars", + "WrapInputStatError", + "WrapSaveErrorByCategory", + "ResolveOpenIDs", + "HandleApiResult", + } + for _, helper := range helpers { + t.Run(helper, func(t *testing.T) { + src := `package drive + +import "github.com/larksuite/cli/shortcuts/common" + +func boom() { + common.` + helper + `() +} +` + v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src) + if len(v) != 1 { + t.Fatalf("expected 1 violation for %s, got %d: %+v", helper, len(v), v) + } + if v[0].Action != ActionReject { + t.Errorf("action = %q, want REJECT", v[0].Action) + } + if !strings.Contains(v[0].Message, "common."+helper) { + t.Errorf("message should name helper %s: %s", helper, v[0].Message) + } + }) + } +} + +func TestCheckNoLegacyCommonHelperCall_AllowsNonMigratedPath(t *testing.T) { + src := `package im + +import "github.com/larksuite/cli/shortcuts/common" + +func boom() { + common.FlagErrorf("legacy allowed until domain migrates") +} +` + v := CheckNoLegacyCommonHelperCall("shortcuts/im/im_send.go", src) + if len(v) != 0 { + t.Errorf("non-migrated path must pass, got: %+v", v) + } +} + +func TestCheckNoLegacyCommonHelperCall_AllowsTypedHelpersOnMigratedPath(t *testing.T) { + src := `package drive + +import "github.com/larksuite/cli/shortcuts/common" + +func boom() { + common.ValidationErrorf("typed") + common.MutuallyExclusiveTyped(nil, "a", "b") + common.ValidateChatIDTyped("--chat-ids", "oc_abc") + common.ResolveOpenIDsTyped("--user-ids", nil, nil) + common.WrapSaveErrorTyped(nil) +} +` + v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src) + if len(v) != 0 { + t.Errorf("typed helpers must pass, got: %+v", v) + } +} + +func TestCheckNoLegacyCommonHelperCall_RejectsAliasedImport(t *testing.T) { + src := `package drive + +import c "github.com/larksuite/cli/shortcuts/common" + +func boom() { + c.FlagErrorf("legacy") +} +` + v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src) + if len(v) != 1 { + t.Fatalf("expected 1 violation for aliased common import, got %d: %+v", len(v), v) + } +} + +func TestCheckNoLegacyCommonHelperCall_RejectsDotImport(t *testing.T) { + src := `package drive + +import . "github.com/larksuite/cli/shortcuts/common" + +func boom() { + FlagErrorf("legacy") +} +` + v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src) + if len(v) != 1 { + t.Fatalf("expected 1 violation for dot-imported common, got %d: %+v", len(v), v) + } +} + +func TestCheckNoLegacyCommonHelperCall_RejectsFunctionValueReference(t *testing.T) { + src := `package drive + +import "github.com/larksuite/cli/shortcuts/common" + +func boom() error { + f := common.FlagErrorf + return f("legacy") +} +` + v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src) + if len(v) != 1 { + t.Fatalf("expected 1 violation for function-value reference, got %d: %+v", len(v), v) + } +} diff --git a/lint/errscontract/scan.go b/lint/errscontract/scan.go index dd4c5482..d7953ae0 100644 --- a/lint/errscontract/scan.go +++ b/lint/errscontract/scan.go @@ -108,6 +108,7 @@ func ScanRepo(root string) ([]Violation, error) { all = append(all, CheckTypedErrorCompleteness(rel, string(src))...) all = append(all, CheckNoLegacyEnvelopeLiteral(rel, string(src))...) all = append(all, CheckNoLegacyRuntimeAPICall(rel, string(src))...) + all = append(all, CheckNoLegacyCommonHelperCall(rel, string(src))...) // Typed-error invariants — self-scope to errs/ + classify.go. all = append(all, CheckNilSafeError(rel, string(src))...) all = append(all, CheckUnwrapSymmetry(rel, string(src))...) diff --git a/shortcuts/common/common.go b/shortcuts/common/common.go index eeb11f58..42ca4aff 100644 --- a/shortcuts/common/common.go +++ b/shortcuts/common/common.go @@ -164,6 +164,9 @@ func CheckApiError(w io.Writer, result interface{}, action string) bool { } // HandleApiResult checks for network/API errors and returns the "data" field. +// +// Deprecated: use RuntimeContext.CallAPITyped (or ClassifyAPIResponse for +// self-driven requests) for typed error envelopes. func HandleApiResult(result interface{}, err error, action string) (map[string]interface{}, error) { if err != nil { return nil, output.Errorf(output.ExitAPI, "api_error", "%s: %s", action, err) diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index e091754b..4c533896 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -625,6 +625,8 @@ func WrapOpenError(err error, pathMsg, readMsg string) error { // - Other errors → readMsg prefix (default "cannot read file") // // Pass an optional readMsg to override the non-path-validation message prefix. +// +// Deprecated: use WrapInputStatErrorTyped for typed error envelopes. func WrapInputStatError(err error, readMsg ...string) error { if err == nil { return nil @@ -639,9 +641,28 @@ func WrapInputStatError(err error, readMsg ...string) error { return output.ErrValidation("%s: %s", msg, err) } +// WrapInputStatErrorTyped wraps a FileIO.Stat/Open error for input file validation. +func WrapInputStatErrorTyped(err error, readMsg ...string) error { + if err == nil { + return nil + } + if errors.Is(err, fileio.ErrPathValidation) { + return errs.NewValidationError(errs.SubtypeInvalidArgument, "unsafe file path: %s", err). + WithCause(err) + } + msg := "cannot read file" + if len(readMsg) > 0 && readMsg[0] != "" { + msg = readMsg[0] + } + return errs.NewValidationError(errs.SubtypeInvalidArgument, "%s: %s", msg, err). + WithCause(err) +} + // WrapSaveErrorByCategory maps a FileIO.Save error to structured output errors, // using standardized messages and the given error category (e.g. "api_error", "io"). // Path validation errors always use ErrValidation (exit code 2). +// +// Deprecated: use WrapSaveErrorTyped for typed error envelopes. func WrapSaveErrorByCategory(err error, category string) error { if err == nil { return nil @@ -657,6 +678,28 @@ func WrapSaveErrorByCategory(err error, category string) error { } } +// WrapSaveErrorTyped maps a FileIO.Save error to typed validation/internal errors. +// Unlike WrapSaveErrorByCategory, non-path failures always emit the canonical +// "internal" wire type: call sites migrating from a custom category +// (e.g. "io", "api_error") change their envelope's type field. +func WrapSaveErrorTyped(err error) error { + if err == nil { + return nil + } + var me *fileio.MkdirError + switch { + case errors.Is(err, fileio.ErrPathValidation): + return errs.NewValidationError(errs.SubtypeInvalidArgument, "unsafe output path: %s", err). + WithCause(err) + case errors.As(err, &me): + return errs.NewInternalError(errs.SubtypeFileIO, "cannot create parent directory: %s", err). + WithCause(err) + default: + return errs.NewInternalError(errs.SubtypeFileIO, "cannot create file: %s", err). + WithCause(err) + } +} + // ValidatePath checks that path is a valid relative input path within the // working directory by delegating to FileIO.Stat. Returns nil if the path is // valid or does not exist yet; returns an error only for illegal paths @@ -1022,7 +1065,8 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error { } raw, err := rctx.Cmd.Flags().GetString(fl.Name) if err != nil { - return FlagErrorf("--%s: Input is only supported for string flags", fl.Name) + return ValidationErrorf("--%s: Input is only supported for string flags", fl.Name). + WithParam("--" + fl.Name) } if raw == "" { continue @@ -1031,15 +1075,19 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error { // stdin: - if raw == "-" { if !slices.Contains(fl.Input, Stdin) { - return FlagErrorf("--%s does not support stdin (-)", fl.Name) + return ValidationErrorf("--%s does not support stdin (-)", fl.Name). + WithParam("--" + fl.Name) } if stdinUsed { - return FlagErrorf("--%s: stdin (-) can only be used by one flag", fl.Name) + return ValidationErrorf("--%s: stdin (-) can only be used by one flag", fl.Name). + WithParam("--" + fl.Name) } stdinUsed = true data, err := io.ReadAll(rctx.IO().In) if err != nil { - return FlagErrorf("--%s: failed to read from stdin: %v", fl.Name, err) + return ValidationErrorf("--%s: failed to read from stdin: %v", fl.Name, err). + WithParam("--" + fl.Name). + WithCause(err) } rctx.Cmd.Flags().Set(fl.Name, string(data)) continue @@ -1054,15 +1102,19 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error { // file: @path if strings.HasPrefix(raw, "@") { if !slices.Contains(fl.Input, File) { - return FlagErrorf("--%s does not support file input (@path)", fl.Name) + return ValidationErrorf("--%s does not support file input (@path)", fl.Name). + WithParam("--" + fl.Name) } path := strings.TrimSpace(raw[1:]) if path == "" { - return FlagErrorf("--%s: file path cannot be empty after @", fl.Name) + return ValidationErrorf("--%s: file path cannot be empty after @", fl.Name). + WithParam("--" + fl.Name) } data, err := cmdutil.ReadInputFile(rctx.FileIO(), path) if err != nil { - return FlagErrorf("--%s: %v", fl.Name, err) + return ValidationErrorf("--%s: %v", fl.Name, err). + WithParam("--" + fl.Name). + WithCause(err) } rctx.Cmd.Flags().Set(fl.Name, string(data)) continue @@ -1088,7 +1140,8 @@ func validateEnumFlags(rctx *RuntimeContext, flags []Flag) error { } } if !valid { - return FlagErrorf("invalid value %q for --%s, allowed: %s", val, fl.Name, strings.Join(fl.Enum, ", ")) + return ValidationErrorf("invalid value %q for --%s, allowed: %s", val, fl.Name, strings.Join(fl.Enum, ", ")). + WithParam("--" + fl.Name) } } return nil @@ -1096,7 +1149,8 @@ func validateEnumFlags(rctx *RuntimeContext, flags []Flag) error { func handleShortcutDryRun(f *cmdutil.Factory, rctx *RuntimeContext, s *Shortcut) error { if s.DryRun == nil { - return FlagErrorf("--dry-run is not supported for %s %s", s.Service, s.Command) + return ValidationErrorf("--dry-run is not supported for %s %s", s.Service, s.Command). + WithParam("--dry-run") } fmt.Fprintln(f.IOStreams.ErrOut, "=== Dry Run ===") dryResult := s.DryRun(rctx.ctx, rctx) diff --git a/shortcuts/common/runner_input_test.go b/shortcuts/common/runner_input_test.go index 47a42c13..1d44023a 100644 --- a/shortcuts/common/runner_input_test.go +++ b/shortcuts/common/runner_input_test.go @@ -129,6 +129,7 @@ func TestResolveInputFlags_StdinNotSupported(t *testing.T) { if err == nil { t.Fatal("expected error for stdin not supported") } + assertValidationParam(t, err, "--data") if !strings.Contains(err.Error(), "does not support stdin") { t.Errorf("unexpected error: %v", err) } @@ -142,6 +143,7 @@ func TestResolveInputFlags_FileNotSupported(t *testing.T) { if err == nil { t.Fatal("expected error for file not supported") } + assertValidationParam(t, err, "--data") if !strings.Contains(err.Error(), "does not support file input") { t.Errorf("unexpected error: %v", err) } @@ -158,6 +160,7 @@ func TestResolveInputFlags_FileNotFound(t *testing.T) { if err == nil { t.Fatal("expected error for missing file") } + assertValidationParam(t, err, "--markdown") if !strings.Contains(err.Error(), "cannot read file") { t.Errorf("unexpected error: %v", err) } @@ -171,6 +174,7 @@ func TestResolveInputFlags_EmptyFilePath(t *testing.T) { if err == nil { t.Fatal("expected error for empty file path") } + assertValidationParam(t, err, "--markdown") if !strings.Contains(err.Error(), "file path cannot be empty after @") { t.Errorf("unexpected error: %v", err) } @@ -212,6 +216,7 @@ func TestResolveInputFlags_DuplicateStdin(t *testing.T) { if err == nil { t.Fatal("expected error for duplicate stdin usage") } + assertValidationParam(t, err, "--b") if !strings.Contains(err.Error(), "stdin (-) can only be used by one flag") { t.Errorf("unexpected error: %v", err) } diff --git a/shortcuts/common/runner_validation_test.go b/shortcuts/common/runner_validation_test.go new file mode 100644 index 00000000..c0430018 --- /dev/null +++ b/shortcuts/common/runner_validation_test.go @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import "testing" + +func TestValidateEnumFlags_ReturnsTypedValidation(t *testing.T) { + rctx := newTestRuntime(map[string]string{"mode": "delete"}) + err := validateEnumFlags(rctx, []Flag{ + {Name: "mode", Enum: []string{"append", "overwrite"}}, + }) + assertValidationParam(t, err, "--mode") +} + +func TestHandleShortcutDryRunUnsupported_ReturnsTypedValidation(t *testing.T) { + err := handleShortcutDryRun(nil, nil, &Shortcut{ + Service: "doc", + Command: "fetch", + }) + assertValidationParam(t, err, "--dry-run") +} diff --git a/shortcuts/common/userids.go b/shortcuts/common/userids.go index a43c12b1..4ce8829e 100644 --- a/shortcuts/common/userids.go +++ b/shortcuts/common/userids.go @@ -4,6 +4,7 @@ package common import ( + "fmt" "strings" "github.com/larksuite/cli/internal/output" @@ -13,9 +14,32 @@ import ( // open_id, removes duplicates case-insensitively while preserving the // first-occurrence form, and returns nil for an empty input. flagName is // used in error messages to point the user at the offending CLI flag. +// +// Deprecated: use ResolveOpenIDsTyped for typed error envelopes. func ResolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]string, error) { + out, msg := resolveOpenIDs(flagName, ids, runtime) + if msg != "" { + return nil, output.ErrValidation("%s", msg) + } + return out, nil +} + +// ResolveOpenIDsTyped expands the special identifier "me" to the current +// user's open_id, removes duplicates case-insensitively while preserving the +// first-occurrence form, and returns nil for an empty input. flagName names +// the flag being resolved (e.g. "--user-ids") and is recorded on the typed +// error. +func ResolveOpenIDsTyped(flagName string, ids []string, runtime *RuntimeContext) ([]string, error) { + out, msg := resolveOpenIDs(flagName, ids, runtime) + if msg != "" { + return nil, ValidationErrorf("%s", msg).WithParam(flagName) + } + return out, nil +} + +func resolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]string, string) { if len(ids) == 0 { - return nil, nil + return nil, "" } currentUserID := runtime.UserOpenId() seen := make(map[string]struct{}, len(ids)) @@ -23,7 +47,7 @@ func ResolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]s for _, id := range ids { if strings.EqualFold(id, "me") { if currentUserID == "" { - return nil, output.ErrValidation("%s: \"me\" requires a logged-in user with a resolvable open_id", flagName) + return nil, fmt.Sprintf("%s: \"me\" requires a logged-in user with a resolvable open_id", flagName) } id = currentUserID } @@ -34,5 +58,5 @@ func ResolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]s seen[key] = struct{}{} out = append(out, id) } - return out, nil + return out, "" } diff --git a/shortcuts/common/userids_test.go b/shortcuts/common/userids_test.go index bd9febdd..7f60c29a 100644 --- a/shortcuts/common/userids_test.go +++ b/shortcuts/common/userids_test.go @@ -75,3 +75,24 @@ func TestResolveOpenIDs_DedupIsCaseInsensitive(t *testing.T) { t.Fatalf("case-insensitive dedup failed: got %v, want [ou_abc123]", out) } } + +func TestResolveOpenIDsTyped_MeWithoutLogin_ReturnsTypedValidation(t *testing.T) { + rt := resolveOpenIDsTestRuntime("") + _, err := ResolveOpenIDsTyped("--user-ids", []string{"me"}, rt) + validationErr := assertValidationParam(t, err, "--user-ids") + if !strings.Contains(validationErr.Message, "--user-ids") { + t.Fatalf("error should mention the offending flag name; got: %v", err) + } +} + +func TestResolveOpenIDsTyped_ExpandsMeAndDedups(t *testing.T) { + rt := resolveOpenIDsTestRuntime("ou_self") + out, err := ResolveOpenIDsTyped("--user-ids", []string{"me", "ou_a", "me", "ou_a"}, rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := []string{"ou_self", "ou_a"} + if len(out) != len(want) || out[0] != want[0] || out[1] != want[1] { + t.Fatalf("got %v, want %v", out, want) + } +} diff --git a/shortcuts/common/validate.go b/shortcuts/common/validate.go index 39bbf21b..3e4c4501 100644 --- a/shortcuts/common/validate.go +++ b/shortcuts/common/validate.go @@ -8,16 +8,26 @@ import ( "strconv" "strings" + "github.com/larksuite/cli/errs" "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/output" ) // FlagErrorf returns a validation error with flag context (exit code 2). +// +// Deprecated: use ValidationErrorf for typed error envelopes. func FlagErrorf(format string, args ...any) error { return output.ErrValidation(format, args...) } +// ValidationErrorf returns a typed validation error with invalid_argument subtype. +func ValidationErrorf(format string, args ...any) *errs.ValidationError { + return errs.NewValidationError(errs.SubtypeInvalidArgument, format, args...) +} + // MutuallyExclusive checks that at most one of the given flags is set. +// +// Deprecated: use MutuallyExclusiveTyped for typed error envelopes. func MutuallyExclusive(rt *RuntimeContext, flags ...string) error { var set []string for _, f := range flags { @@ -32,7 +42,25 @@ func MutuallyExclusive(rt *RuntimeContext, flags ...string) error { return nil } +// MutuallyExclusiveTyped checks that at most one of the given flags is set. +func MutuallyExclusiveTyped(rt *RuntimeContext, flags ...string) error { + var set []string + for _, f := range flags { + val := rt.Str(f) + if val != "" { + set = append(set, "--"+f) + } + } + if len(set) > 1 { + return ValidationErrorf("%s are mutually exclusive", strings.Join(set, " and ")). + WithParams(invalidParams(set, "mutually exclusive")...) + } + return nil +} + // AtLeastOne checks that at least one of the given flags is set. +// +// Deprecated: use AtLeastOneTyped for typed error envelopes. func AtLeastOne(rt *RuntimeContext, flags ...string) error { for _, f := range flags { if rt.Str(f) != "" { @@ -46,7 +74,24 @@ func AtLeastOne(rt *RuntimeContext, flags ...string) error { return FlagErrorf("specify at least one of %s", strings.Join(names, " or ")) } +// AtLeastOneTyped checks that at least one of the given flags is set. +func AtLeastOneTyped(rt *RuntimeContext, flags ...string) error { + for _, f := range flags { + if rt.Str(f) != "" { + return nil + } + } + names := make([]string, len(flags)) + for i, f := range flags { + names[i] = "--" + f + } + return ValidationErrorf("specify at least one of %s", strings.Join(names, " or ")). + WithParams(invalidParams(names, "required; specify at least one")...) +} + // ExactlyOne checks that exactly one of the given flags is set. +// +// Deprecated: use ExactlyOneTyped for typed error envelopes. func ExactlyOne(rt *RuntimeContext, flags ...string) error { if err := AtLeastOne(rt, flags...); err != nil { return err @@ -54,8 +99,18 @@ func ExactlyOne(rt *RuntimeContext, flags ...string) error { return MutuallyExclusive(rt, flags...) } +// ExactlyOneTyped checks that exactly one of the given flags is set. +func ExactlyOneTyped(rt *RuntimeContext, flags ...string) error { + if err := AtLeastOneTyped(rt, flags...); err != nil { + return err + } + return MutuallyExclusiveTyped(rt, flags...) +} + // ValidatePageSize validates that the named flag (if set) is an integer within [minVal, maxVal]. // It returns the parsed value (or defaultVal if the flag is empty) and any validation error. +// +// Deprecated: use ValidatePageSizeTyped for typed error envelopes. func ValidatePageSize(rt *RuntimeContext, flagName string, defaultVal, minVal, maxVal int) (int, error) { s := rt.Str(flagName) if s == "" { @@ -71,6 +126,25 @@ func ValidatePageSize(rt *RuntimeContext, flagName string, defaultVal, minVal, m return n, nil } +// ValidatePageSizeTyped validates that the named flag (if set) is an integer within [minVal, maxVal]. +// It returns the parsed value (or defaultVal if the flag is empty) and any validation error. +func ValidatePageSizeTyped(rt *RuntimeContext, flagName string, defaultVal, minVal, maxVal int) (int, error) { + s := rt.Str(flagName) + param := "--" + flagName + if s == "" { + return defaultVal, nil + } + n, err := strconv.Atoi(s) + if err != nil { + return 0, ValidationErrorf("invalid --%s %q: must be an integer", flagName, s).WithParam(param) + } + if n < minVal || n > maxVal { + return 0, ValidationErrorf("invalid --%s %d: must be between %d and %d", flagName, n, minVal, maxVal). + WithParam(param) + } + return n, nil +} + // ParseIntBounded parses an int flag and clamps it to [min, max]. func ParseIntBounded(rt *RuntimeContext, name string, min, max int) int { v := rt.Int(name) @@ -87,13 +161,26 @@ func ParseIntBounded(rt *RuntimeContext, name string, min, max int) int { // working directory. It catches traversal, symlink escape, and control // characters by delegating to FileIO.ResolvePath. Works for both file and // directory paths. +// +// Deprecated: use ValidateSafePathTyped for typed error envelopes. func ValidateSafePath(fio fileio.FileIO, path string) error { _, err := fio.ResolvePath(path) return err } +// ValidateSafePathTyped ensures path resolves within the current working directory. +func ValidateSafePathTyped(fio fileio.FileIO, path string) error { + _, err := fio.ResolvePath(path) + if err != nil { + return ValidationErrorf("%s", err).WithCause(err) + } + return nil +} + // RejectDangerousChars returns an error if value contains ASCII control // characters or dangerous Unicode code points. +// +// Deprecated: use RejectDangerousCharsTyped for typed error envelopes. func RejectDangerousChars(paramName, value string) error { for _, r := range value { if r < 0x20 && r != '\t' && r != '\n' { @@ -108,3 +195,31 @@ func RejectDangerousChars(paramName, value string) error { } return nil } + +// RejectDangerousCharsTyped returns an error if value contains ASCII control +// characters or dangerous Unicode code points. +func RejectDangerousCharsTyped(paramName, value string) error { + for _, r := range value { + if r < 0x20 && r != '\t' && r != '\n' { + return ValidationErrorf("parameter %q contains control character U+%04X", paramName, r). + WithParam(paramName) + } + if r == 0x7F { + return ValidationErrorf("parameter %q contains DEL character", paramName). + WithParam(paramName) + } + if IsDangerousUnicode(r) { + return ValidationErrorf("parameter %q contains dangerous Unicode character U+%04X", paramName, r). + WithParam(paramName) + } + } + return nil +} + +func invalidParams(names []string, reason string) []errs.InvalidParam { + params := make([]errs.InvalidParam, len(names)) + for i, name := range names { + params[i] = errs.InvalidParam{Name: name, Reason: reason} + } + return params +} diff --git a/shortcuts/common/validate_ids.go b/shortcuts/common/validate_ids.go index 50697087..69914789 100644 --- a/shortcuts/common/validate_ids.go +++ b/shortcuts/common/validate_ids.go @@ -11,10 +11,31 @@ import ( // ValidateChatID checks if a chat ID has valid format (oc_ prefix). // Also extracts token from URL if provided. +// +// Deprecated: use ValidateChatIDTyped for typed error envelopes. func ValidateChatID(input string) (string, error) { + chatID, msg := normalizeChatID(input) + if msg != "" { + return "", output.ErrValidation("%s", msg) + } + return chatID, nil +} + +// ValidateChatIDTyped checks if a chat ID has valid format (oc_ prefix). +// Also extracts token from URL if provided. param names the flag being +// validated (e.g. "--chat-ids") and is recorded on the typed error. +func ValidateChatIDTyped(param, input string) (string, error) { + chatID, msg := normalizeChatID(input) + if msg != "" { + return "", ValidationErrorf("%s", msg).WithParam(param) + } + return chatID, nil +} + +func normalizeChatID(input string) (string, string) { input = strings.TrimSpace(input) if input == "" { - return "", output.ErrValidation("chat ID cannot be empty") + return "", "chat ID cannot be empty" } // Extract from URL if present if strings.Contains(input, "feishu.cn") || strings.Contains(input, "larksuite.com") { @@ -28,19 +49,40 @@ func ValidateChatID(input string) (string, error) { } } if !strings.HasPrefix(input, "oc_") { - return "", output.ErrValidation("invalid chat ID format, should start with 'oc_' (e.g., oc_abc123)") + return "", "invalid chat ID format, should start with 'oc_' (e.g., oc_abc123)" } - return input, nil + return input, "" } // ValidateUserID checks if a user ID has valid format (ou_ prefix). +// +// Deprecated: use ValidateUserIDTyped for typed error envelopes. func ValidateUserID(input string) (string, error) { + userID, msg := normalizeUserID(input) + if msg != "" { + return "", output.ErrValidation("%s", msg) + } + return userID, nil +} + +// ValidateUserIDTyped checks if a user ID has valid format (ou_ prefix). +// param names the flag being validated (e.g. "--creator-ids") and is +// recorded on the typed error. +func ValidateUserIDTyped(param, input string) (string, error) { + userID, msg := normalizeUserID(input) + if msg != "" { + return "", ValidationErrorf("%s", msg).WithParam(param) + } + return userID, nil +} + +func normalizeUserID(input string) (string, string) { input = strings.TrimSpace(input) if input == "" { - return "", output.ErrValidation("user ID cannot be empty") + return "", "user ID cannot be empty" } if !strings.HasPrefix(input, "ou_") { - return "", output.ErrValidation("invalid user ID format, should start with 'ou_' (e.g., ou_abc123)") + return "", "invalid user ID format, should start with 'ou_' (e.g., ou_abc123)" } - return input, nil + return input, "" } diff --git a/shortcuts/common/validate_test.go b/shortcuts/common/validate_test.go index 89eb72e8..4a58e9dd 100644 --- a/shortcuts/common/validate_test.go +++ b/shortcuts/common/validate_test.go @@ -4,10 +4,14 @@ package common import ( + "errors" "os" "path/filepath" + "strings" "testing" + "github.com/larksuite/cli/errs" + "github.com/larksuite/cli/extension/fileio" "github.com/larksuite/cli/internal/vfs/localfileio" "github.com/spf13/cobra" ) @@ -26,6 +30,24 @@ func newTestRuntime(flags map[string]string) *RuntimeContext { return &RuntimeContext{Cmd: cmd} } +func assertValidationParam(t *testing.T, err error, param string) *errs.ValidationError { + t.Helper() + if err == nil { + t.Fatal("expected validation error, got nil") + } + var validationErr *errs.ValidationError + if !errors.As(err, &validationErr) { + t.Fatalf("expected *errs.ValidationError, got %T: %v", err, err) + } + if validationErr.Subtype != errs.SubtypeInvalidArgument { + t.Fatalf("Subtype = %q, want %q", validationErr.Subtype, errs.SubtypeInvalidArgument) + } + if param != "" && validationErr.Param != param { + t.Fatalf("Param = %q, want %q", validationErr.Param, param) + } + return validationErr +} + func TestMutuallyExclusive(t *testing.T) { tests := []struct { name string @@ -69,6 +91,109 @@ func TestMutuallyExclusive(t *testing.T) { } } +func TestValidationErrorf_ReturnsTypedInvalidArgument(t *testing.T) { + err := ValidationErrorf("bad %s", "flag") + validationErr := assertValidationParam(t, err, "") + if validationErr.Message != "bad flag" { + t.Fatalf("Message = %q, want %q", validationErr.Message, "bad flag") + } +} + +func TestTypedFlagGroupHelpers_ReturnValidationParams(t *testing.T) { + t.Run("mutually exclusive", func(t *testing.T) { + rt := newTestRuntime(map[string]string{"a": "x", "b": "y"}) + validationErr := assertValidationParam(t, MutuallyExclusiveTyped(rt, "a", "b"), "") + if len(validationErr.Params) != 2 { + t.Fatalf("Params len = %d, want 2: %+v", len(validationErr.Params), validationErr.Params) + } + if validationErr.Params[0].Name != "--a" || validationErr.Params[1].Name != "--b" { + t.Fatalf("Params names = %+v, want --a/--b", validationErr.Params) + } + }) + + t.Run("at least one", func(t *testing.T) { + rt := newTestRuntime(map[string]string{"a": "", "b": ""}) + validationErr := assertValidationParam(t, AtLeastOneTyped(rt, "a", "b"), "") + if len(validationErr.Params) != 2 { + t.Fatalf("Params len = %d, want 2: %+v", len(validationErr.Params), validationErr.Params) + } + if !strings.Contains(validationErr.Message, "--a or --b") { + t.Fatalf("Message = %q, want flag group", validationErr.Message) + } + }) + + t.Run("exactly one", func(t *testing.T) { + rt := newTestRuntime(map[string]string{"a": "x", "b": "y"}) + validationErr := assertValidationParam(t, ExactlyOneTyped(rt, "a", "b"), "") + if len(validationErr.Params) != 2 { + t.Fatalf("Params len = %d, want 2: %+v", len(validationErr.Params), validationErr.Params) + } + }) +} + +func TestValidatePageSizeTyped_ReturnsTypedValidation(t *testing.T) { + rt := newTestRuntime(map[string]string{"page-size": "nope"}) + _, err := ValidatePageSizeTyped(rt, "page-size", 10, 1, 20) + assertValidationParam(t, err, "--page-size") + + rt = newTestRuntime(map[string]string{"page-size": "30"}) + _, err = ValidatePageSizeTyped(rt, "page-size", 10, 1, 20) + assertValidationParam(t, err, "--page-size") +} + +func TestValidateIDTyped_ReturnsTypedValidation(t *testing.T) { + chatID, err := ValidateChatIDTyped("--chat-ids", "https://example.feishu.cn/foo/oc_abc") + if err != nil { + t.Fatalf("ValidateChatIDTyped valid URL: %v", err) + } + if chatID != "oc_abc" { + t.Fatalf("chatID = %q, want oc_abc", chatID) + } + assertValidationParam(t, func() error { + _, err := ValidateChatIDTyped("--chat-ids", "bad") + return err + }(), "--chat-ids") + assertValidationParam(t, func() error { + _, err := ValidateUserIDTyped("--creator-ids", "bad") + return err + }(), "--creator-ids") +} + +func TestRejectDangerousCharsTyped_ReturnsTypedValidation(t *testing.T) { + err := RejectDangerousCharsTyped("--query", "bad\x01") + validationErr := assertValidationParam(t, err, "--query") + if !strings.Contains(validationErr.Message, "control character") { + t.Fatalf("Message = %q, want control character", validationErr.Message) + } +} + +func TestWrapInputStatErrorTyped_ReturnsTypedValidation(t *testing.T) { + cause := &fileio.PathValidationError{Err: errors.New("outside cwd")} + err := WrapInputStatErrorTyped(cause) + validationErr := assertValidationParam(t, err, "") + if !strings.Contains(validationErr.Message, "unsafe file path") { + t.Fatalf("Message = %q, want unsafe file path", validationErr.Message) + } + if !errors.Is(err, fileio.ErrPathValidation) { + t.Fatalf("expected errors.Is(fileio.ErrPathValidation) to match") + } +} + +func TestWrapSaveErrorTyped_ClassifiesPathAndFileIO(t *testing.T) { + pathErr := &fileio.PathValidationError{Err: errors.New("outside cwd")} + assertValidationParam(t, WrapSaveErrorTyped(pathErr), "") + + mkdirErr := &fileio.MkdirError{Err: errors.New("permission denied")} + err := WrapSaveErrorTyped(mkdirErr) + var internalErr *errs.InternalError + if !errors.As(err, &internalErr) { + t.Fatalf("expected *errs.InternalError, got %T: %v", err, err) + } + if internalErr.Subtype != errs.SubtypeFileIO { + t.Fatalf("Subtype = %q, want %q", internalErr.Subtype, errs.SubtypeFileIO) + } +} + func TestAtLeastOne(t *testing.T) { tests := []struct { name string @@ -246,3 +371,20 @@ func TestValidateSafePath_AllowsNonExistentPath(t *testing.T) { t.Fatalf("expected no error for non-existent path, got: %v", err) } } + +// TestValidateSafePathTyped_ReturnsTypedValidation verifies that an escaping +// path is rejected with a typed validation error and a safe path passes. +func TestValidateSafePathTyped_ReturnsTypedValidation(t *testing.T) { + outside := t.TempDir() + workDir := t.TempDir() + chdirForTest(t, workDir) + + if err := os.Symlink(outside, filepath.Join(workDir, "evil_out")); err != nil { + t.Fatalf("Symlink: %v", err) + } + assertValidationParam(t, ValidateSafePathTyped(&localfileio.LocalFileIO{}, "evil_out"), "") + + if err := ValidateSafePathTyped(&localfileio.LocalFileIO{}, "new_output_dir"); err != nil { + t.Fatalf("expected no error for safe path, got: %v", err) + } +} diff --git a/shortcuts/drive/drive_search.go b/shortcuts/drive/drive_search.go index a1787251..e50391e4 100644 --- a/shortcuts/drive/drive_search.go +++ b/shortcuts/drive/drive_search.go @@ -354,7 +354,7 @@ func parseDriveSearchPageSize(raw string) (int, error) { // server-side failure or empty result. func validateDriveSearchIDs(spec driveSearchSpec) error { for _, id := range spec.CreatorIDs { - if _, err := common.ValidateUserID(id); err != nil { + if _, err := common.ValidateUserIDTyped("--creator-ids", id); err != nil { return errs.NewValidationError(errs.SubtypeInvalidArgument, "--creator-ids %q: %s", id, err).WithParam("--creator-ids") } } @@ -362,7 +362,7 @@ func validateDriveSearchIDs(spec driveSearchSpec) error { return errs.NewValidationError(errs.SubtypeInvalidArgument, "--chat-ids: max %d values per request, got %d", driveSearchMaxChatIDs, n).WithParam("--chat-ids") } for _, id := range spec.ChatIDs { - if _, err := common.ValidateChatID(id); err != nil { + if _, err := common.ValidateChatIDTyped("--chat-ids", id); err != nil { return errs.NewValidationError(errs.SubtypeInvalidArgument, "--chat-ids %q: %s", id, err).WithParam("--chat-ids") } } @@ -370,7 +370,7 @@ func validateDriveSearchIDs(spec driveSearchSpec) error { return errs.NewValidationError(errs.SubtypeInvalidArgument, "--sharer-ids: max %d values per request, got %d", driveSearchMaxSharerIDs, n).WithParam("--sharer-ids") } for _, id := range spec.SharerIDs { - if _, err := common.ValidateUserID(id); err != nil { + if _, err := common.ValidateUserIDTyped("--sharer-ids", id); err != nil { return errs.NewValidationError(errs.SubtypeInvalidArgument, "--sharer-ids %q: %s", id, err).WithParam("--sharer-ids") } }