feat(common): emit typed validation errors from shared shortcut pre-checks (#1242)

Input pre-check failures shared by every shortcut — @file/stdin input
resolution, enum validation, and unsupported --dry-run — now leave the
CLI as typed validation envelopes naming the offending flag, so scripts
and AI agents can branch on `param` instead of parsing prose. Wire type,
exit code, and message text are unchanged; the new fields are additive.

The shared layer also gains typed replacements for its legacy
error-producing helpers, so each business domain can migrate to typed
errors without rebuilding common plumbing, and a path-scoped lint guard
keeps migrated domains from sliding back.

Changes:
- Shared pre-check failures (input flags, enum values, dry-run support)
  return typed validation errors carrying the offending flag as `param`.
- Every legacy error-producing helper in shortcuts/common has a typed
  replacement that preserves the existing message text: validation and
  flag-group checks, chat/user ID validation (callers name the flag so
  `param` is ground truth), "me" open-id resolution, safe-path checks,
  input-stat and save-error wrapping. Legacy helpers stay for
  not-yet-migrated domains, marked deprecated — including the legacy
  API-result classifier, whose typed route is runtime.CallAPITyped.
- A new errscontract rule rejects legacy common-helper calls on migrated
  paths, so a migrated domain cannot silently reintroduce legacy
  envelopes; drive is the first locked path and its last legacy
  ID-helper calls are replaced.
This commit is contained in:
evandance
2026-06-03 19:20:19 +08:00
committed by GitHub
parent 2f35ce3724
commit b3fcf55611
13 changed files with 708 additions and 21 deletions

View File

@@ -0,0 +1,138 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package errscontract
import (
"go/ast"
"go/parser"
"go/token"
"strings"
)
// migratedCommonHelperPaths lists source-tree prefixes whose command validation
// has migrated to typed errs.* envelopes. On these paths, calls to common's
// legacy validation/save helpers are forbidden; callers must use the typed
// common replacements or construct an errs.* typed error directly.
var migratedCommonHelperPaths = []string{
"shortcuts/drive/",
}
const commonImportPath = "github.com/larksuite/cli/shortcuts/common"
var legacyCommonHelperReplacements = map[string]string{
"FlagErrorf": "common.ValidationErrorf",
"MutuallyExclusive": "common.MutuallyExclusiveTyped",
"AtLeastOne": "common.AtLeastOneTyped",
"ExactlyOne": "common.ExactlyOneTyped",
"ValidatePageSize": "common.ValidatePageSizeTyped",
"ValidateChatID": "common.ValidateChatIDTyped",
"ValidateUserID": "common.ValidateUserIDTyped",
"ValidateSafePath": "common.ValidateSafePathTyped",
"RejectDangerousChars": "common.RejectDangerousCharsTyped",
"WrapInputStatError": "common.WrapInputStatErrorTyped",
"WrapSaveErrorByCategory": "common.WrapSaveErrorTyped",
"ResolveOpenIDs": "common.ResolveOpenIDsTyped",
"HandleApiResult": "runtime.CallAPITyped",
}
// CheckNoLegacyCommonHelperCall flags any reference to common's legacy helper
// APIs on migrated paths — direct calls and function-value references alike,
// so `f := common.FlagErrorf; f(...)` cannot slip past the guard. These
// helpers return legacy output envelopes or bare errors, so migrated domains
// should use their typed-aware replacements.
func CheckNoLegacyCommonHelperCall(path, src string) []Violation {
if !isMigratedCommonHelperPath(path) || strings.HasSuffix(path, "_test.go") {
return nil
}
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, path, src, parser.ParseComments)
if err != nil {
return nil
}
localNames, dotImported := resolveCommonNames(file)
var out []Violation
report := func(pos token.Pos, name, replacement string) {
out = append(out, Violation{
Rule: "no_legacy_common_helper_call",
Action: ActionReject,
File: path,
Line: fset.Position(pos).Line,
Message: "common." + name + " returns a legacy error shape and is forbidden on migrated paths",
Suggestion: "replace common." + name + " with " + replacement + " or a typed errs.* constructor",
})
}
// Pass 1: qualified references (common.X / alias.X). Record every
// selector field so the dot-import pass below never mistakes another
// package's same-named field for a common helper.
selFields := make(map[*ast.Ident]struct{})
ast.Inspect(file, func(n ast.Node) bool {
sel, ok := n.(*ast.SelectorExpr)
if !ok {
return true
}
selFields[sel.Sel] = struct{}{}
x, ok := sel.X.(*ast.Ident)
if !ok {
return true
}
if _, bound := localNames[x.Name]; !bound {
return true
}
if replacement, ok := legacyCommonHelperReplacements[sel.Sel.Name]; ok {
report(sel.Pos(), sel.Sel.Name, replacement)
}
return true
})
// Pass 2: unqualified references under a dot import.
if dotImported {
ast.Inspect(file, func(n ast.Node) bool {
ident, ok := n.(*ast.Ident)
if !ok {
return true
}
if _, isField := selFields[ident]; isField {
return true
}
if replacement, ok := legacyCommonHelperReplacements[ident.Name]; ok {
report(ident.Pos(), ident.Name, replacement)
}
return true
})
}
return out
}
func isMigratedCommonHelperPath(path string) bool {
p := strings.ReplaceAll(path, "\\", "/")
for _, prefix := range migratedCommonHelperPaths {
if strings.HasPrefix(p, prefix) || strings.Contains(p, "/"+prefix) {
return true
}
}
return false
}
func resolveCommonNames(file *ast.File) (map[string]struct{}, bool) {
names := make(map[string]struct{})
dotImported := false
for _, imp := range file.Imports {
if imp.Path == nil {
continue
}
p := strings.Trim(imp.Path.Value, "`\"")
if p != commonImportPath {
continue
}
switch {
case imp.Name == nil:
names["common"] = struct{}{}
case imp.Name.Name == ".":
dotImported = true
case imp.Name.Name == "_":
default:
names[imp.Name.Name] = struct{}{}
}
}
return names, dotImported
}

View File

@@ -877,3 +877,123 @@ func boom(runtime *common.RuntimeContext) error {
t.Errorf("test files must be skipped, got: %+v", v)
}
}
func TestCheckNoLegacyCommonHelperCall_RejectsLegacyHelpersOnMigratedPath(t *testing.T) {
helpers := []string{
"FlagErrorf",
"MutuallyExclusive",
"AtLeastOne",
"ExactlyOne",
"ValidatePageSize",
"ValidateChatID",
"ValidateUserID",
"ValidateSafePath",
"RejectDangerousChars",
"WrapInputStatError",
"WrapSaveErrorByCategory",
"ResolveOpenIDs",
"HandleApiResult",
}
for _, helper := range helpers {
t.Run(helper, func(t *testing.T) {
src := `package drive
import "github.com/larksuite/cli/shortcuts/common"
func boom() {
common.` + helper + `()
}
`
v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src)
if len(v) != 1 {
t.Fatalf("expected 1 violation for %s, got %d: %+v", helper, len(v), v)
}
if v[0].Action != ActionReject {
t.Errorf("action = %q, want REJECT", v[0].Action)
}
if !strings.Contains(v[0].Message, "common."+helper) {
t.Errorf("message should name helper %s: %s", helper, v[0].Message)
}
})
}
}
func TestCheckNoLegacyCommonHelperCall_AllowsNonMigratedPath(t *testing.T) {
src := `package im
import "github.com/larksuite/cli/shortcuts/common"
func boom() {
common.FlagErrorf("legacy allowed until domain migrates")
}
`
v := CheckNoLegacyCommonHelperCall("shortcuts/im/im_send.go", src)
if len(v) != 0 {
t.Errorf("non-migrated path must pass, got: %+v", v)
}
}
func TestCheckNoLegacyCommonHelperCall_AllowsTypedHelpersOnMigratedPath(t *testing.T) {
src := `package drive
import "github.com/larksuite/cli/shortcuts/common"
func boom() {
common.ValidationErrorf("typed")
common.MutuallyExclusiveTyped(nil, "a", "b")
common.ValidateChatIDTyped("--chat-ids", "oc_abc")
common.ResolveOpenIDsTyped("--user-ids", nil, nil)
common.WrapSaveErrorTyped(nil)
}
`
v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src)
if len(v) != 0 {
t.Errorf("typed helpers must pass, got: %+v", v)
}
}
func TestCheckNoLegacyCommonHelperCall_RejectsAliasedImport(t *testing.T) {
src := `package drive
import c "github.com/larksuite/cli/shortcuts/common"
func boom() {
c.FlagErrorf("legacy")
}
`
v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src)
if len(v) != 1 {
t.Fatalf("expected 1 violation for aliased common import, got %d: %+v", len(v), v)
}
}
func TestCheckNoLegacyCommonHelperCall_RejectsDotImport(t *testing.T) {
src := `package drive
import . "github.com/larksuite/cli/shortcuts/common"
func boom() {
FlagErrorf("legacy")
}
`
v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src)
if len(v) != 1 {
t.Fatalf("expected 1 violation for dot-imported common, got %d: %+v", len(v), v)
}
}
func TestCheckNoLegacyCommonHelperCall_RejectsFunctionValueReference(t *testing.T) {
src := `package drive
import "github.com/larksuite/cli/shortcuts/common"
func boom() error {
f := common.FlagErrorf
return f("legacy")
}
`
v := CheckNoLegacyCommonHelperCall("shortcuts/drive/drive_search.go", src)
if len(v) != 1 {
t.Fatalf("expected 1 violation for function-value reference, got %d: %+v", len(v), v)
}
}

View File

@@ -108,6 +108,7 @@ func ScanRepo(root string) ([]Violation, error) {
all = append(all, CheckTypedErrorCompleteness(rel, string(src))...)
all = append(all, CheckNoLegacyEnvelopeLiteral(rel, string(src))...)
all = append(all, CheckNoLegacyRuntimeAPICall(rel, string(src))...)
all = append(all, CheckNoLegacyCommonHelperCall(rel, string(src))...)
// Typed-error invariants — self-scope to errs/ + classify.go.
all = append(all, CheckNilSafeError(rel, string(src))...)
all = append(all, CheckUnwrapSymmetry(rel, string(src))...)

View File

@@ -164,6 +164,9 @@ func CheckApiError(w io.Writer, result interface{}, action string) bool {
}
// HandleApiResult checks for network/API errors and returns the "data" field.
//
// Deprecated: use RuntimeContext.CallAPITyped (or ClassifyAPIResponse for
// self-driven requests) for typed error envelopes.
func HandleApiResult(result interface{}, err error, action string) (map[string]interface{}, error) {
if err != nil {
return nil, output.Errorf(output.ExitAPI, "api_error", "%s: %s", action, err)

View File

@@ -625,6 +625,8 @@ func WrapOpenError(err error, pathMsg, readMsg string) error {
// - Other errors → readMsg prefix (default "cannot read file")
//
// Pass an optional readMsg to override the non-path-validation message prefix.
//
// Deprecated: use WrapInputStatErrorTyped for typed error envelopes.
func WrapInputStatError(err error, readMsg ...string) error {
if err == nil {
return nil
@@ -639,9 +641,28 @@ func WrapInputStatError(err error, readMsg ...string) error {
return output.ErrValidation("%s: %s", msg, err)
}
// WrapInputStatErrorTyped wraps a FileIO.Stat/Open error for input file validation.
func WrapInputStatErrorTyped(err error, readMsg ...string) error {
if err == nil {
return nil
}
if errors.Is(err, fileio.ErrPathValidation) {
return errs.NewValidationError(errs.SubtypeInvalidArgument, "unsafe file path: %s", err).
WithCause(err)
}
msg := "cannot read file"
if len(readMsg) > 0 && readMsg[0] != "" {
msg = readMsg[0]
}
return errs.NewValidationError(errs.SubtypeInvalidArgument, "%s: %s", msg, err).
WithCause(err)
}
// WrapSaveErrorByCategory maps a FileIO.Save error to structured output errors,
// using standardized messages and the given error category (e.g. "api_error", "io").
// Path validation errors always use ErrValidation (exit code 2).
//
// Deprecated: use WrapSaveErrorTyped for typed error envelopes.
func WrapSaveErrorByCategory(err error, category string) error {
if err == nil {
return nil
@@ -657,6 +678,28 @@ func WrapSaveErrorByCategory(err error, category string) error {
}
}
// WrapSaveErrorTyped maps a FileIO.Save error to typed validation/internal errors.
// Unlike WrapSaveErrorByCategory, non-path failures always emit the canonical
// "internal" wire type: call sites migrating from a custom category
// (e.g. "io", "api_error") change their envelope's type field.
func WrapSaveErrorTyped(err error) error {
if err == nil {
return nil
}
var me *fileio.MkdirError
switch {
case errors.Is(err, fileio.ErrPathValidation):
return errs.NewValidationError(errs.SubtypeInvalidArgument, "unsafe output path: %s", err).
WithCause(err)
case errors.As(err, &me):
return errs.NewInternalError(errs.SubtypeFileIO, "cannot create parent directory: %s", err).
WithCause(err)
default:
return errs.NewInternalError(errs.SubtypeFileIO, "cannot create file: %s", err).
WithCause(err)
}
}
// ValidatePath checks that path is a valid relative input path within the
// working directory by delegating to FileIO.Stat. Returns nil if the path is
// valid or does not exist yet; returns an error only for illegal paths
@@ -1022,7 +1065,8 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error {
}
raw, err := rctx.Cmd.Flags().GetString(fl.Name)
if err != nil {
return FlagErrorf("--%s: Input is only supported for string flags", fl.Name)
return ValidationErrorf("--%s: Input is only supported for string flags", fl.Name).
WithParam("--" + fl.Name)
}
if raw == "" {
continue
@@ -1031,15 +1075,19 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error {
// stdin: -
if raw == "-" {
if !slices.Contains(fl.Input, Stdin) {
return FlagErrorf("--%s does not support stdin (-)", fl.Name)
return ValidationErrorf("--%s does not support stdin (-)", fl.Name).
WithParam("--" + fl.Name)
}
if stdinUsed {
return FlagErrorf("--%s: stdin (-) can only be used by one flag", fl.Name)
return ValidationErrorf("--%s: stdin (-) can only be used by one flag", fl.Name).
WithParam("--" + fl.Name)
}
stdinUsed = true
data, err := io.ReadAll(rctx.IO().In)
if err != nil {
return FlagErrorf("--%s: failed to read from stdin: %v", fl.Name, err)
return ValidationErrorf("--%s: failed to read from stdin: %v", fl.Name, err).
WithParam("--" + fl.Name).
WithCause(err)
}
rctx.Cmd.Flags().Set(fl.Name, string(data))
continue
@@ -1054,15 +1102,19 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error {
// file: @path
if strings.HasPrefix(raw, "@") {
if !slices.Contains(fl.Input, File) {
return FlagErrorf("--%s does not support file input (@path)", fl.Name)
return ValidationErrorf("--%s does not support file input (@path)", fl.Name).
WithParam("--" + fl.Name)
}
path := strings.TrimSpace(raw[1:])
if path == "" {
return FlagErrorf("--%s: file path cannot be empty after @", fl.Name)
return ValidationErrorf("--%s: file path cannot be empty after @", fl.Name).
WithParam("--" + fl.Name)
}
data, err := cmdutil.ReadInputFile(rctx.FileIO(), path)
if err != nil {
return FlagErrorf("--%s: %v", fl.Name, err)
return ValidationErrorf("--%s: %v", fl.Name, err).
WithParam("--" + fl.Name).
WithCause(err)
}
rctx.Cmd.Flags().Set(fl.Name, string(data))
continue
@@ -1088,7 +1140,8 @@ func validateEnumFlags(rctx *RuntimeContext, flags []Flag) error {
}
}
if !valid {
return FlagErrorf("invalid value %q for --%s, allowed: %s", val, fl.Name, strings.Join(fl.Enum, ", "))
return ValidationErrorf("invalid value %q for --%s, allowed: %s", val, fl.Name, strings.Join(fl.Enum, ", ")).
WithParam("--" + fl.Name)
}
}
return nil
@@ -1096,7 +1149,8 @@ func validateEnumFlags(rctx *RuntimeContext, flags []Flag) error {
func handleShortcutDryRun(f *cmdutil.Factory, rctx *RuntimeContext, s *Shortcut) error {
if s.DryRun == nil {
return FlagErrorf("--dry-run is not supported for %s %s", s.Service, s.Command)
return ValidationErrorf("--dry-run is not supported for %s %s", s.Service, s.Command).
WithParam("--dry-run")
}
fmt.Fprintln(f.IOStreams.ErrOut, "=== Dry Run ===")
dryResult := s.DryRun(rctx.ctx, rctx)

View File

@@ -129,6 +129,7 @@ func TestResolveInputFlags_StdinNotSupported(t *testing.T) {
if err == nil {
t.Fatal("expected error for stdin not supported")
}
assertValidationParam(t, err, "--data")
if !strings.Contains(err.Error(), "does not support stdin") {
t.Errorf("unexpected error: %v", err)
}
@@ -142,6 +143,7 @@ func TestResolveInputFlags_FileNotSupported(t *testing.T) {
if err == nil {
t.Fatal("expected error for file not supported")
}
assertValidationParam(t, err, "--data")
if !strings.Contains(err.Error(), "does not support file input") {
t.Errorf("unexpected error: %v", err)
}
@@ -158,6 +160,7 @@ func TestResolveInputFlags_FileNotFound(t *testing.T) {
if err == nil {
t.Fatal("expected error for missing file")
}
assertValidationParam(t, err, "--markdown")
if !strings.Contains(err.Error(), "cannot read file") {
t.Errorf("unexpected error: %v", err)
}
@@ -171,6 +174,7 @@ func TestResolveInputFlags_EmptyFilePath(t *testing.T) {
if err == nil {
t.Fatal("expected error for empty file path")
}
assertValidationParam(t, err, "--markdown")
if !strings.Contains(err.Error(), "file path cannot be empty after @") {
t.Errorf("unexpected error: %v", err)
}
@@ -212,6 +216,7 @@ func TestResolveInputFlags_DuplicateStdin(t *testing.T) {
if err == nil {
t.Fatal("expected error for duplicate stdin usage")
}
assertValidationParam(t, err, "--b")
if !strings.Contains(err.Error(), "stdin (-) can only be used by one flag") {
t.Errorf("unexpected error: %v", err)
}

View File

@@ -0,0 +1,22 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package common
import "testing"
func TestValidateEnumFlags_ReturnsTypedValidation(t *testing.T) {
rctx := newTestRuntime(map[string]string{"mode": "delete"})
err := validateEnumFlags(rctx, []Flag{
{Name: "mode", Enum: []string{"append", "overwrite"}},
})
assertValidationParam(t, err, "--mode")
}
func TestHandleShortcutDryRunUnsupported_ReturnsTypedValidation(t *testing.T) {
err := handleShortcutDryRun(nil, nil, &Shortcut{
Service: "doc",
Command: "fetch",
})
assertValidationParam(t, err, "--dry-run")
}

View File

@@ -4,6 +4,7 @@
package common
import (
"fmt"
"strings"
"github.com/larksuite/cli/internal/output"
@@ -13,9 +14,32 @@ import (
// open_id, removes duplicates case-insensitively while preserving the
// first-occurrence form, and returns nil for an empty input. flagName is
// used in error messages to point the user at the offending CLI flag.
//
// Deprecated: use ResolveOpenIDsTyped for typed error envelopes.
func ResolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]string, error) {
out, msg := resolveOpenIDs(flagName, ids, runtime)
if msg != "" {
return nil, output.ErrValidation("%s", msg)
}
return out, nil
}
// ResolveOpenIDsTyped expands the special identifier "me" to the current
// user's open_id, removes duplicates case-insensitively while preserving the
// first-occurrence form, and returns nil for an empty input. flagName names
// the flag being resolved (e.g. "--user-ids") and is recorded on the typed
// error.
func ResolveOpenIDsTyped(flagName string, ids []string, runtime *RuntimeContext) ([]string, error) {
out, msg := resolveOpenIDs(flagName, ids, runtime)
if msg != "" {
return nil, ValidationErrorf("%s", msg).WithParam(flagName)
}
return out, nil
}
func resolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]string, string) {
if len(ids) == 0 {
return nil, nil
return nil, ""
}
currentUserID := runtime.UserOpenId()
seen := make(map[string]struct{}, len(ids))
@@ -23,7 +47,7 @@ func ResolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]s
for _, id := range ids {
if strings.EqualFold(id, "me") {
if currentUserID == "" {
return nil, output.ErrValidation("%s: \"me\" requires a logged-in user with a resolvable open_id", flagName)
return nil, fmt.Sprintf("%s: \"me\" requires a logged-in user with a resolvable open_id", flagName)
}
id = currentUserID
}
@@ -34,5 +58,5 @@ func ResolveOpenIDs(flagName string, ids []string, runtime *RuntimeContext) ([]s
seen[key] = struct{}{}
out = append(out, id)
}
return out, nil
return out, ""
}

View File

@@ -75,3 +75,24 @@ func TestResolveOpenIDs_DedupIsCaseInsensitive(t *testing.T) {
t.Fatalf("case-insensitive dedup failed: got %v, want [ou_abc123]", out)
}
}
func TestResolveOpenIDsTyped_MeWithoutLogin_ReturnsTypedValidation(t *testing.T) {
rt := resolveOpenIDsTestRuntime("")
_, err := ResolveOpenIDsTyped("--user-ids", []string{"me"}, rt)
validationErr := assertValidationParam(t, err, "--user-ids")
if !strings.Contains(validationErr.Message, "--user-ids") {
t.Fatalf("error should mention the offending flag name; got: %v", err)
}
}
func TestResolveOpenIDsTyped_ExpandsMeAndDedups(t *testing.T) {
rt := resolveOpenIDsTestRuntime("ou_self")
out, err := ResolveOpenIDsTyped("--user-ids", []string{"me", "ou_a", "me", "ou_a"}, rt)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"ou_self", "ou_a"}
if len(out) != len(want) || out[0] != want[0] || out[1] != want[1] {
t.Fatalf("got %v, want %v", out, want)
}
}

View File

@@ -8,16 +8,26 @@ import (
"strconv"
"strings"
"github.com/larksuite/cli/errs"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
)
// FlagErrorf returns a validation error with flag context (exit code 2).
//
// Deprecated: use ValidationErrorf for typed error envelopes.
func FlagErrorf(format string, args ...any) error {
return output.ErrValidation(format, args...)
}
// ValidationErrorf returns a typed validation error with invalid_argument subtype.
func ValidationErrorf(format string, args ...any) *errs.ValidationError {
return errs.NewValidationError(errs.SubtypeInvalidArgument, format, args...)
}
// MutuallyExclusive checks that at most one of the given flags is set.
//
// Deprecated: use MutuallyExclusiveTyped for typed error envelopes.
func MutuallyExclusive(rt *RuntimeContext, flags ...string) error {
var set []string
for _, f := range flags {
@@ -32,7 +42,25 @@ func MutuallyExclusive(rt *RuntimeContext, flags ...string) error {
return nil
}
// MutuallyExclusiveTyped checks that at most one of the given flags is set.
func MutuallyExclusiveTyped(rt *RuntimeContext, flags ...string) error {
var set []string
for _, f := range flags {
val := rt.Str(f)
if val != "" {
set = append(set, "--"+f)
}
}
if len(set) > 1 {
return ValidationErrorf("%s are mutually exclusive", strings.Join(set, " and ")).
WithParams(invalidParams(set, "mutually exclusive")...)
}
return nil
}
// AtLeastOne checks that at least one of the given flags is set.
//
// Deprecated: use AtLeastOneTyped for typed error envelopes.
func AtLeastOne(rt *RuntimeContext, flags ...string) error {
for _, f := range flags {
if rt.Str(f) != "" {
@@ -46,7 +74,24 @@ func AtLeastOne(rt *RuntimeContext, flags ...string) error {
return FlagErrorf("specify at least one of %s", strings.Join(names, " or "))
}
// AtLeastOneTyped checks that at least one of the given flags is set.
func AtLeastOneTyped(rt *RuntimeContext, flags ...string) error {
for _, f := range flags {
if rt.Str(f) != "" {
return nil
}
}
names := make([]string, len(flags))
for i, f := range flags {
names[i] = "--" + f
}
return ValidationErrorf("specify at least one of %s", strings.Join(names, " or ")).
WithParams(invalidParams(names, "required; specify at least one")...)
}
// ExactlyOne checks that exactly one of the given flags is set.
//
// Deprecated: use ExactlyOneTyped for typed error envelopes.
func ExactlyOne(rt *RuntimeContext, flags ...string) error {
if err := AtLeastOne(rt, flags...); err != nil {
return err
@@ -54,8 +99,18 @@ func ExactlyOne(rt *RuntimeContext, flags ...string) error {
return MutuallyExclusive(rt, flags...)
}
// ExactlyOneTyped checks that exactly one of the given flags is set.
func ExactlyOneTyped(rt *RuntimeContext, flags ...string) error {
if err := AtLeastOneTyped(rt, flags...); err != nil {
return err
}
return MutuallyExclusiveTyped(rt, flags...)
}
// ValidatePageSize validates that the named flag (if set) is an integer within [minVal, maxVal].
// It returns the parsed value (or defaultVal if the flag is empty) and any validation error.
//
// Deprecated: use ValidatePageSizeTyped for typed error envelopes.
func ValidatePageSize(rt *RuntimeContext, flagName string, defaultVal, minVal, maxVal int) (int, error) {
s := rt.Str(flagName)
if s == "" {
@@ -71,6 +126,25 @@ func ValidatePageSize(rt *RuntimeContext, flagName string, defaultVal, minVal, m
return n, nil
}
// ValidatePageSizeTyped validates that the named flag (if set) is an integer within [minVal, maxVal].
// It returns the parsed value (or defaultVal if the flag is empty) and any validation error.
func ValidatePageSizeTyped(rt *RuntimeContext, flagName string, defaultVal, minVal, maxVal int) (int, error) {
s := rt.Str(flagName)
param := "--" + flagName
if s == "" {
return defaultVal, nil
}
n, err := strconv.Atoi(s)
if err != nil {
return 0, ValidationErrorf("invalid --%s %q: must be an integer", flagName, s).WithParam(param)
}
if n < minVal || n > maxVal {
return 0, ValidationErrorf("invalid --%s %d: must be between %d and %d", flagName, n, minVal, maxVal).
WithParam(param)
}
return n, nil
}
// ParseIntBounded parses an int flag and clamps it to [min, max].
func ParseIntBounded(rt *RuntimeContext, name string, min, max int) int {
v := rt.Int(name)
@@ -87,13 +161,26 @@ func ParseIntBounded(rt *RuntimeContext, name string, min, max int) int {
// working directory. It catches traversal, symlink escape, and control
// characters by delegating to FileIO.ResolvePath. Works for both file and
// directory paths.
//
// Deprecated: use ValidateSafePathTyped for typed error envelopes.
func ValidateSafePath(fio fileio.FileIO, path string) error {
_, err := fio.ResolvePath(path)
return err
}
// ValidateSafePathTyped ensures path resolves within the current working directory.
func ValidateSafePathTyped(fio fileio.FileIO, path string) error {
_, err := fio.ResolvePath(path)
if err != nil {
return ValidationErrorf("%s", err).WithCause(err)
}
return nil
}
// RejectDangerousChars returns an error if value contains ASCII control
// characters or dangerous Unicode code points.
//
// Deprecated: use RejectDangerousCharsTyped for typed error envelopes.
func RejectDangerousChars(paramName, value string) error {
for _, r := range value {
if r < 0x20 && r != '\t' && r != '\n' {
@@ -108,3 +195,31 @@ func RejectDangerousChars(paramName, value string) error {
}
return nil
}
// RejectDangerousCharsTyped returns an error if value contains ASCII control
// characters or dangerous Unicode code points.
func RejectDangerousCharsTyped(paramName, value string) error {
for _, r := range value {
if r < 0x20 && r != '\t' && r != '\n' {
return ValidationErrorf("parameter %q contains control character U+%04X", paramName, r).
WithParam(paramName)
}
if r == 0x7F {
return ValidationErrorf("parameter %q contains DEL character", paramName).
WithParam(paramName)
}
if IsDangerousUnicode(r) {
return ValidationErrorf("parameter %q contains dangerous Unicode character U+%04X", paramName, r).
WithParam(paramName)
}
}
return nil
}
func invalidParams(names []string, reason string) []errs.InvalidParam {
params := make([]errs.InvalidParam, len(names))
for i, name := range names {
params[i] = errs.InvalidParam{Name: name, Reason: reason}
}
return params
}

View File

@@ -11,10 +11,31 @@ import (
// ValidateChatID checks if a chat ID has valid format (oc_ prefix).
// Also extracts token from URL if provided.
//
// Deprecated: use ValidateChatIDTyped for typed error envelopes.
func ValidateChatID(input string) (string, error) {
chatID, msg := normalizeChatID(input)
if msg != "" {
return "", output.ErrValidation("%s", msg)
}
return chatID, nil
}
// ValidateChatIDTyped checks if a chat ID has valid format (oc_ prefix).
// Also extracts token from URL if provided. param names the flag being
// validated (e.g. "--chat-ids") and is recorded on the typed error.
func ValidateChatIDTyped(param, input string) (string, error) {
chatID, msg := normalizeChatID(input)
if msg != "" {
return "", ValidationErrorf("%s", msg).WithParam(param)
}
return chatID, nil
}
func normalizeChatID(input string) (string, string) {
input = strings.TrimSpace(input)
if input == "" {
return "", output.ErrValidation("chat ID cannot be empty")
return "", "chat ID cannot be empty"
}
// Extract from URL if present
if strings.Contains(input, "feishu.cn") || strings.Contains(input, "larksuite.com") {
@@ -28,19 +49,40 @@ func ValidateChatID(input string) (string, error) {
}
}
if !strings.HasPrefix(input, "oc_") {
return "", output.ErrValidation("invalid chat ID format, should start with 'oc_' (e.g., oc_abc123)")
return "", "invalid chat ID format, should start with 'oc_' (e.g., oc_abc123)"
}
return input, nil
return input, ""
}
// ValidateUserID checks if a user ID has valid format (ou_ prefix).
//
// Deprecated: use ValidateUserIDTyped for typed error envelopes.
func ValidateUserID(input string) (string, error) {
userID, msg := normalizeUserID(input)
if msg != "" {
return "", output.ErrValidation("%s", msg)
}
return userID, nil
}
// ValidateUserIDTyped checks if a user ID has valid format (ou_ prefix).
// param names the flag being validated (e.g. "--creator-ids") and is
// recorded on the typed error.
func ValidateUserIDTyped(param, input string) (string, error) {
userID, msg := normalizeUserID(input)
if msg != "" {
return "", ValidationErrorf("%s", msg).WithParam(param)
}
return userID, nil
}
func normalizeUserID(input string) (string, string) {
input = strings.TrimSpace(input)
if input == "" {
return "", output.ErrValidation("user ID cannot be empty")
return "", "user ID cannot be empty"
}
if !strings.HasPrefix(input, "ou_") {
return "", output.ErrValidation("invalid user ID format, should start with 'ou_' (e.g., ou_abc123)")
return "", "invalid user ID format, should start with 'ou_' (e.g., ou_abc123)"
}
return input, nil
return input, ""
}

View File

@@ -4,10 +4,14 @@
package common
import (
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/larksuite/cli/errs"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/vfs/localfileio"
"github.com/spf13/cobra"
)
@@ -26,6 +30,24 @@ func newTestRuntime(flags map[string]string) *RuntimeContext {
return &RuntimeContext{Cmd: cmd}
}
func assertValidationParam(t *testing.T, err error, param string) *errs.ValidationError {
t.Helper()
if err == nil {
t.Fatal("expected validation error, got nil")
}
var validationErr *errs.ValidationError
if !errors.As(err, &validationErr) {
t.Fatalf("expected *errs.ValidationError, got %T: %v", err, err)
}
if validationErr.Subtype != errs.SubtypeInvalidArgument {
t.Fatalf("Subtype = %q, want %q", validationErr.Subtype, errs.SubtypeInvalidArgument)
}
if param != "" && validationErr.Param != param {
t.Fatalf("Param = %q, want %q", validationErr.Param, param)
}
return validationErr
}
func TestMutuallyExclusive(t *testing.T) {
tests := []struct {
name string
@@ -69,6 +91,109 @@ func TestMutuallyExclusive(t *testing.T) {
}
}
func TestValidationErrorf_ReturnsTypedInvalidArgument(t *testing.T) {
err := ValidationErrorf("bad %s", "flag")
validationErr := assertValidationParam(t, err, "")
if validationErr.Message != "bad flag" {
t.Fatalf("Message = %q, want %q", validationErr.Message, "bad flag")
}
}
func TestTypedFlagGroupHelpers_ReturnValidationParams(t *testing.T) {
t.Run("mutually exclusive", func(t *testing.T) {
rt := newTestRuntime(map[string]string{"a": "x", "b": "y"})
validationErr := assertValidationParam(t, MutuallyExclusiveTyped(rt, "a", "b"), "")
if len(validationErr.Params) != 2 {
t.Fatalf("Params len = %d, want 2: %+v", len(validationErr.Params), validationErr.Params)
}
if validationErr.Params[0].Name != "--a" || validationErr.Params[1].Name != "--b" {
t.Fatalf("Params names = %+v, want --a/--b", validationErr.Params)
}
})
t.Run("at least one", func(t *testing.T) {
rt := newTestRuntime(map[string]string{"a": "", "b": ""})
validationErr := assertValidationParam(t, AtLeastOneTyped(rt, "a", "b"), "")
if len(validationErr.Params) != 2 {
t.Fatalf("Params len = %d, want 2: %+v", len(validationErr.Params), validationErr.Params)
}
if !strings.Contains(validationErr.Message, "--a or --b") {
t.Fatalf("Message = %q, want flag group", validationErr.Message)
}
})
t.Run("exactly one", func(t *testing.T) {
rt := newTestRuntime(map[string]string{"a": "x", "b": "y"})
validationErr := assertValidationParam(t, ExactlyOneTyped(rt, "a", "b"), "")
if len(validationErr.Params) != 2 {
t.Fatalf("Params len = %d, want 2: %+v", len(validationErr.Params), validationErr.Params)
}
})
}
func TestValidatePageSizeTyped_ReturnsTypedValidation(t *testing.T) {
rt := newTestRuntime(map[string]string{"page-size": "nope"})
_, err := ValidatePageSizeTyped(rt, "page-size", 10, 1, 20)
assertValidationParam(t, err, "--page-size")
rt = newTestRuntime(map[string]string{"page-size": "30"})
_, err = ValidatePageSizeTyped(rt, "page-size", 10, 1, 20)
assertValidationParam(t, err, "--page-size")
}
func TestValidateIDTyped_ReturnsTypedValidation(t *testing.T) {
chatID, err := ValidateChatIDTyped("--chat-ids", "https://example.feishu.cn/foo/oc_abc")
if err != nil {
t.Fatalf("ValidateChatIDTyped valid URL: %v", err)
}
if chatID != "oc_abc" {
t.Fatalf("chatID = %q, want oc_abc", chatID)
}
assertValidationParam(t, func() error {
_, err := ValidateChatIDTyped("--chat-ids", "bad")
return err
}(), "--chat-ids")
assertValidationParam(t, func() error {
_, err := ValidateUserIDTyped("--creator-ids", "bad")
return err
}(), "--creator-ids")
}
func TestRejectDangerousCharsTyped_ReturnsTypedValidation(t *testing.T) {
err := RejectDangerousCharsTyped("--query", "bad\x01")
validationErr := assertValidationParam(t, err, "--query")
if !strings.Contains(validationErr.Message, "control character") {
t.Fatalf("Message = %q, want control character", validationErr.Message)
}
}
func TestWrapInputStatErrorTyped_ReturnsTypedValidation(t *testing.T) {
cause := &fileio.PathValidationError{Err: errors.New("outside cwd")}
err := WrapInputStatErrorTyped(cause)
validationErr := assertValidationParam(t, err, "")
if !strings.Contains(validationErr.Message, "unsafe file path") {
t.Fatalf("Message = %q, want unsafe file path", validationErr.Message)
}
if !errors.Is(err, fileio.ErrPathValidation) {
t.Fatalf("expected errors.Is(fileio.ErrPathValidation) to match")
}
}
func TestWrapSaveErrorTyped_ClassifiesPathAndFileIO(t *testing.T) {
pathErr := &fileio.PathValidationError{Err: errors.New("outside cwd")}
assertValidationParam(t, WrapSaveErrorTyped(pathErr), "")
mkdirErr := &fileio.MkdirError{Err: errors.New("permission denied")}
err := WrapSaveErrorTyped(mkdirErr)
var internalErr *errs.InternalError
if !errors.As(err, &internalErr) {
t.Fatalf("expected *errs.InternalError, got %T: %v", err, err)
}
if internalErr.Subtype != errs.SubtypeFileIO {
t.Fatalf("Subtype = %q, want %q", internalErr.Subtype, errs.SubtypeFileIO)
}
}
func TestAtLeastOne(t *testing.T) {
tests := []struct {
name string
@@ -246,3 +371,20 @@ func TestValidateSafePath_AllowsNonExistentPath(t *testing.T) {
t.Fatalf("expected no error for non-existent path, got: %v", err)
}
}
// TestValidateSafePathTyped_ReturnsTypedValidation verifies that an escaping
// path is rejected with a typed validation error and a safe path passes.
func TestValidateSafePathTyped_ReturnsTypedValidation(t *testing.T) {
outside := t.TempDir()
workDir := t.TempDir()
chdirForTest(t, workDir)
if err := os.Symlink(outside, filepath.Join(workDir, "evil_out")); err != nil {
t.Fatalf("Symlink: %v", err)
}
assertValidationParam(t, ValidateSafePathTyped(&localfileio.LocalFileIO{}, "evil_out"), "")
if err := ValidateSafePathTyped(&localfileio.LocalFileIO{}, "new_output_dir"); err != nil {
t.Fatalf("expected no error for safe path, got: %v", err)
}
}

View File

@@ -354,7 +354,7 @@ func parseDriveSearchPageSize(raw string) (int, error) {
// server-side failure or empty result.
func validateDriveSearchIDs(spec driveSearchSpec) error {
for _, id := range spec.CreatorIDs {
if _, err := common.ValidateUserID(id); err != nil {
if _, err := common.ValidateUserIDTyped("--creator-ids", id); err != nil {
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--creator-ids %q: %s", id, err).WithParam("--creator-ids")
}
}
@@ -362,7 +362,7 @@ func validateDriveSearchIDs(spec driveSearchSpec) error {
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--chat-ids: max %d values per request, got %d", driveSearchMaxChatIDs, n).WithParam("--chat-ids")
}
for _, id := range spec.ChatIDs {
if _, err := common.ValidateChatID(id); err != nil {
if _, err := common.ValidateChatIDTyped("--chat-ids", id); err != nil {
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--chat-ids %q: %s", id, err).WithParam("--chat-ids")
}
}
@@ -370,7 +370,7 @@ func validateDriveSearchIDs(spec driveSearchSpec) error {
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--sharer-ids: max %d values per request, got %d", driveSearchMaxSharerIDs, n).WithParam("--sharer-ids")
}
for _, id := range spec.SharerIDs {
if _, err := common.ValidateUserID(id); err != nil {
if _, err := common.ValidateUserIDTyped("--sharer-ids", id); err != nil {
return errs.NewValidationError(errs.SubtypeInvalidArgument, "--sharer-ids %q: %s", id, err).WithParam("--sharer-ids")
}
}