refactor: migrate common/client/im to FileIO and add localfileio tests (#322)

* refactor: migrate common/client/im to FileIO and add localfileio tests

- runner resolveInputFlags: replace validate.SafeInputPath + vfs.ReadFile
  with FileIO.Open + io.ReadAll
- SaveResponse: delegate to FileIO.Save + ResolvePath
- cmd/api, cmd/service: pass FileIO to ResponseOptions
- im: replace validate.SafeLocalFlagPath with RuntimeContext.ValidatePath,
  migrate download/upload to FileIO.Save/Open/Stat
- Add path_test.go and atomicwrite_test.go for localfileio
- Add validate_media_test.go for im media flag validation
- Adapt test mocks to fileio.FileInfo interface
This commit is contained in:
tuxedomm
2026-04-08 17:31:21 +08:00
committed by GitHub
parent adef52ada5
commit f5a8fbf8f1
21 changed files with 1174 additions and 286 deletions

View File

@@ -207,6 +207,7 @@ func apiRun(opts *APIOptions) error {
JqExpr: opts.JqExpr,
Out: out,
ErrOut: f.IOStreams.ErrOut,
FileIO: f.ResolveFileIO(opts.Ctx),
})
// MarkRaw tells root error handler to skip enrichPermissionError,
// preserving the original API error detail (log_id, troubleshooter, etc.).

View File

@@ -250,6 +250,7 @@ func serviceMethodRun(opts *ServiceMethodOptions) error {
JqExpr: opts.JqExpr,
Out: out,
ErrOut: f.IOStreams.ErrOut,
FileIO: f.ResolveFileIO(opts.Ctx),
CheckError: checkErr,
})
}

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package fileio
import "errors"
// ErrPathValidation indicates the path failed security validation
// (traversal, absolute, control chars, symlink escape, etc.).
var ErrPathValidation = errors.New("path validation failed")
// PathValidationError wraps a path validation error.
// errors.Is(err, ErrPathValidation) returns true.
// errors.Is(err, <original OS error>) also works via the chain.
type PathValidationError struct {
Err error // original error
}
func (e *PathValidationError) Error() string { return e.Err.Error() }
func (e *PathValidationError) Unwrap() []error {
return []error{ErrPathValidation, e.Err}
}
// MkdirError indicates parent directory creation failed.
// Use errors.As(err, &fileio.MkdirError{}) to match.
type MkdirError struct {
Err error
}
func (e *MkdirError) Error() string { return e.Err.Error() }
func (e *MkdirError) Unwrap() error { return e.Err }
// WriteError indicates file write failed.
// Use errors.As(err, &fileio.WriteError{}) to match.
type WriteError struct {
Err error
}
func (e *WriteError) Error() string { return e.Err.Error() }
func (e *WriteError) Unwrap() error { return e.Err }

View File

@@ -6,18 +6,17 @@ package client
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"path/filepath"
"strings"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/util"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
)
// ── Response routing ──
@@ -29,6 +28,7 @@ type ResponseOptions struct {
JqExpr string // if set, apply jq filter instead of Format
Out io.Writer // stdout
ErrOut io.Writer // stderr
FileIO fileio.FileIO // file transfer abstraction; required when saving files (--output or binary response)
// CheckError is called on parsed JSON results. Nil defaults to CheckLarkResponse.
CheckError func(interface{}) error
}
@@ -61,7 +61,7 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error {
return apiErr
}
if opts.OutputPath != "" {
return saveAndPrint(resp, opts.OutputPath, opts.Out)
return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out)
}
if opts.JqExpr != "" {
return output.JqFilter(opts.Out, result, opts.JqExpr)
@@ -75,11 +75,11 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error {
return output.ErrValidation("--jq requires a JSON response (got Content-Type: %s)", ct)
}
if opts.OutputPath != "" {
return saveAndPrint(resp, opts.OutputPath, opts.Out)
return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out)
}
// No --output: auto-save with derived filename.
meta, err := SaveResponse(resp, ResolveFilename(resp))
meta, err := SaveResponse(opts.FileIO, resp, ResolveFilename(resp))
if err != nil {
return output.Errorf(output.ExitInternal, "file_error", "%s", err)
}
@@ -88,8 +88,8 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error {
return nil
}
func saveAndPrint(resp *larkcore.ApiResp, path string, w io.Writer) error {
meta, err := SaveResponse(resp, path)
func saveAndPrint(fio fileio.FileIO, resp *larkcore.ApiResp, path string, w io.Writer) error {
meta, err := SaveResponse(fio, resp, path)
if err != nil {
return output.Errorf(output.ExitInternal, "file_error", "%s", err)
}
@@ -119,23 +119,34 @@ func ParseJSONResponse(resp *larkcore.ApiResp) (interface{}, error) {
// ── File saving ──
// SaveResponse writes an API response body to the given outputPath and returns metadata.
func SaveResponse(resp *larkcore.ApiResp, outputPath string) (map[string]interface{}, error) {
safePath, err := validate.SafeOutputPath(outputPath)
// It delegates to FileIO.Save for path validation and atomic write; fio must not be nil.
func SaveResponse(fio fileio.FileIO, resp *larkcore.ApiResp, outputPath string) (map[string]interface{}, error) {
result, err := fio.Save(outputPath, fileio.SaveOptions{
ContentType: resp.Header.Get("Content-Type"),
ContentLength: int64(len(resp.RawBody)),
}, bytes.NewReader(resp.RawBody))
if err != nil {
return nil, fmt.Errorf("unsafe output path: %s", err)
var me *fileio.MkdirError
var we *fileio.WriteError
switch {
case errors.Is(err, fileio.ErrPathValidation):
return nil, fmt.Errorf("unsafe output path: %s", err)
case errors.As(err, &me):
return nil, fmt.Errorf("create directory: %s", err)
case errors.As(err, &we):
return nil, fmt.Errorf("cannot write file: %s", err)
default:
return nil, fmt.Errorf("cannot write file: %s", err)
}
}
if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
return nil, fmt.Errorf("create directory: %s", err)
resolvedPath, err := fio.ResolvePath(outputPath)
if err != nil || resolvedPath == "" {
resolvedPath = outputPath
}
if err := validate.AtomicWrite(safePath, resp.RawBody, 0644); err != nil {
return nil, fmt.Errorf("cannot write file: %s", err)
}
return map[string]interface{}{
"saved_path": safePath,
"size_bytes": len(resp.RawBody),
"saved_path": resolvedPath,
"size_bytes": result.Size(),
"content_type": resp.Header.Get("Content-Type"),
}, nil
}

View File

@@ -16,6 +16,7 @@ import (
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/vfs/localfileio"
)
func newApiResp(body []byte, headers map[string]string) *larkcore.ApiResp {
@@ -162,11 +163,11 @@ func TestSaveResponse(t *testing.T) {
body := []byte("hello binary data")
resp := newApiResp(body, map[string]string{"Content-Type": "application/octet-stream"})
meta, err := SaveResponse(resp, "test_output.bin")
meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "test_output.bin")
if err != nil {
t.Fatalf("SaveResponse failed: %v", err)
}
if meta["size_bytes"] != len(body) {
if meta["size_bytes"] != int64(len(body)) {
t.Errorf("expected size_bytes=%d, got %v", len(body), meta["size_bytes"])
}
@@ -188,7 +189,7 @@ func TestSaveResponse_CreatesDir(t *testing.T) {
resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"})
meta, err := SaveResponse(resp, filepath.Join("sub", "deep", "out.bin"))
meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, filepath.Join("sub", "deep", "out.bin"))
if err != nil {
t.Fatalf("SaveResponse with nested dir failed: %v", err)
}
@@ -207,6 +208,7 @@ func TestHandleResponse_JSON(t *testing.T) {
err := HandleResponse(resp, ResponseOptions{
Out: &out,
ErrOut: &errOut,
FileIO: &localfileio.LocalFileIO{},
})
if err != nil {
t.Fatalf("HandleResponse failed: %v", err)
@@ -225,6 +227,7 @@ func TestHandleResponse_JSONWithError(t *testing.T) {
err := HandleResponse(resp, ResponseOptions{
Out: &out,
ErrOut: &errOut,
FileIO: &localfileio.LocalFileIO{},
})
if err == nil {
t.Error("expected error for non-zero code")
@@ -275,6 +278,7 @@ func TestHandleResponse_BinaryAutoSave(t *testing.T) {
err := HandleResponse(resp, ResponseOptions{
Out: &out,
ErrOut: &errOut,
FileIO: &localfileio.LocalFileIO{},
})
if err != nil {
t.Fatalf("HandleResponse binary failed: %v", err)
@@ -298,6 +302,7 @@ func TestHandleResponse_BinaryWithOutput(t *testing.T) {
OutputPath: "out.png",
Out: &out,
ErrOut: &errOut,
FileIO: &localfileio.LocalFileIO{},
})
if err != nil {
t.Fatalf("HandleResponse with output path failed: %v", err)
@@ -312,7 +317,7 @@ func TestHandleResponse_NonJSONError_404(t *testing.T) {
resp := newApiRespWithStatus(404, []byte("404 page not found"), map[string]string{"Content-Type": "text/plain"})
var out, errOut bytes.Buffer
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
if err == nil {
t.Fatal("expected error for 404 text/plain")
}
@@ -330,7 +335,7 @@ func TestHandleResponse_NonJSONError_502(t *testing.T) {
resp := newApiRespWithStatus(502, []byte("<html>Bad Gateway</html>"), map[string]string{"Content-Type": "text/html"})
var out, errOut bytes.Buffer
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
if err == nil {
t.Fatal("expected error for 502 text/html")
}
@@ -353,7 +358,7 @@ func TestHandleResponse_200TextPlain_SavesFile(t *testing.T) {
resp := newApiRespWithStatus(200, []byte("plain text file content"), map[string]string{"Content-Type": "text/plain"})
var out, errOut bytes.Buffer
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
if err != nil {
t.Fatalf("expected no error for 200 text/plain, got: %v", err)
}
@@ -379,12 +384,53 @@ func TestHandleResponse_BinaryWithJq_RejectsNonJSON(t *testing.T) {
}
}
func TestSaveResponse_RejectsPathTraversal(t *testing.T) {
dir := t.TempDir()
origWd, _ := os.Getwd()
os.Chdir(dir)
defer os.Chdir(origWd)
resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"})
_, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "../../evil.txt")
if err == nil {
t.Fatal("expected error for path traversal")
}
if !strings.Contains(err.Error(), "unsafe output path") {
t.Errorf("expected 'unsafe output path' wrapper, got: %v", err)
}
}
func TestSaveResponse_RejectsAbsolutePath(t *testing.T) {
resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"})
_, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "/tmp/evil.txt")
if err == nil {
t.Fatal("expected error for absolute path")
}
}
func TestSaveResponse_MetadataContainsAbsolutePath(t *testing.T) {
dir := t.TempDir()
origWd, _ := os.Getwd()
os.Chdir(dir)
defer os.Chdir(origWd)
resp := newApiResp([]byte("x"), map[string]string{"Content-Type": "text/plain"})
meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "rel.txt")
if err != nil {
t.Fatalf("SaveResponse failed: %v", err)
}
savedPath, _ := meta["saved_path"].(string)
if !filepath.IsAbs(savedPath) {
t.Errorf("saved_path should be absolute, got %q", savedPath)
}
}
func TestHandleResponse_403JSON_CheckLarkResponse(t *testing.T) {
body := []byte(`{"code":99991400,"msg":"invalid token"}`)
resp := newApiRespWithStatus(403, body, map[string]string{"Content-Type": "application/json"})
var out, errOut bytes.Buffer
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
if err == nil {
t.Fatal("expected error for 403 JSON with non-zero code")
}

View File

@@ -11,13 +11,26 @@ import (
"testing"
_ "github.com/larksuite/cli/extension/credential/env"
"github.com/larksuite/cli/extension/fileio"
exttransport "github.com/larksuite/cli/extension/transport"
internalauth "github.com/larksuite/cli/internal/auth"
"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/credential"
"github.com/larksuite/cli/internal/envvars"
"github.com/larksuite/cli/internal/vfs/localfileio"
)
type countingFileIOProvider struct {
resolveCalls int
}
func (p *countingFileIOProvider) Name() string { return "counting" }
func (p *countingFileIOProvider) ResolveFileIO(context.Context) fileio.FileIO {
p.resolveCalls++
return &localfileio.LocalFileIO{}
}
func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) {
t.Setenv(envvars.CliAppID, "")
t.Setenv(envvars.CliAppSecret, "")
@@ -198,6 +211,28 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin
}
}
func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing.T) {
prev := fileio.GetProvider()
provider := &countingFileIOProvider{}
fileio.Register(provider)
t.Cleanup(func() { fileio.Register(prev) })
f := NewDefault(InvocationContext{})
if f.FileIOProvider != provider {
t.Fatalf("NewDefault() provider = %T, want %T", f.FileIOProvider, provider)
}
if provider.resolveCalls != 0 {
t.Fatalf("ResolveFileIO() calls after NewDefault() = %d, want 0", provider.resolveCalls)
}
if got := f.ResolveFileIO(context.Background()); got == nil {
t.Fatal("ResolveFileIO() = nil, want non-nil")
}
if provider.resolveCalls != 1 {
t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls)
}
}
type stubTransportProvider struct {
interceptor exttransport.Interceptor
}

View File

@@ -0,0 +1,146 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package localfileio
import (
"os"
"path/filepath"
"runtime"
"sync"
"testing"
)
func TestAtomicWrite_WritesContentAndPermissionCorrectly(t *testing.T) {
// GIVEN: a target path in a temp directory
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
data := []byte(`{"key":"value"}`)
// WHEN: AtomicWrite writes data with 0644 permission
if err := AtomicWrite(path, data, 0644); err != nil {
t.Fatalf("AtomicWrite failed: %v", err)
}
// THEN: file content matches exactly
got, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile failed: %v", err)
}
if string(got) != string(data) {
t.Errorf("content = %q, want %q", got, data)
}
}
func TestAtomicWrite_SetsRestrictivePermission(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission test not reliable on Windows")
}
// GIVEN: a target path
dir := t.TempDir()
path := filepath.Join(dir, "secret.json")
// WHEN: AtomicWrite writes with 0600 permission
if err := AtomicWrite(path, []byte("secret"), 0600); err != nil {
t.Fatalf("AtomicWrite failed: %v", err)
}
// THEN: file permission is exactly 0600 (owner read-write only)
info, _ := os.Stat(path)
if perm := info.Mode().Perm(); perm != 0600 {
t.Errorf("permission = %04o, want 0600", perm)
}
}
func TestAtomicWrite_OverwritesExistingFile(t *testing.T) {
// GIVEN: an existing file with old content
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
AtomicWrite(path, []byte("old"), 0644)
// WHEN: AtomicWrite overwrites with new content
if err := AtomicWrite(path, []byte("new"), 0644); err != nil {
t.Fatalf("second write failed: %v", err)
}
// THEN: file contains new content
got, _ := os.ReadFile(path)
if string(got) != "new" {
t.Errorf("content = %q, want %q", got, "new")
}
}
func TestAtomicWrite_LeavesNoResidualTempFileOnError(t *testing.T) {
// GIVEN: a target path in a non-existent nested directory
path := filepath.Join(t.TempDir(), "nonexistent", "subdir", "file.txt")
// WHEN: AtomicWrite fails (parent directory doesn't exist)
err := AtomicWrite(path, []byte("data"), 0644)
// THEN: the write fails
if err == nil {
t.Fatal("expected error writing to nonexistent dir")
}
// THEN: no .tmp files are left behind
parentDir := filepath.Dir(filepath.Dir(path))
entries, _ := os.ReadDir(parentDir)
for _, e := range entries {
if filepath.Ext(e.Name()) == ".tmp" {
t.Errorf("residual temp file found: %s", e.Name())
}
}
}
func TestAtomicWrite_PreservesOriginalFileOnFailure(t *testing.T) {
// GIVEN: an existing file with known content
dir := t.TempDir()
original := []byte("original content")
path := filepath.Join(dir, "file.json")
if err := AtomicWrite(path, original, 0644); err != nil {
t.Fatal(err)
}
// WHEN: AtomicWrite targets a non-existent directory (guaranteed to fail even as root)
badPath := filepath.Join(dir, "no", "such", "dir", "file.json")
err := AtomicWrite(badPath, []byte("new"), 0644)
// THEN: write fails
if err == nil {
t.Fatal("expected error writing to non-existent dir")
}
// THEN: the original file at the valid path is untouched
got, _ := os.ReadFile(path)
if string(got) != string(original) {
t.Errorf("original file corrupted: got %q, want %q", got, original)
}
}
func TestAtomicWrite_HandlesCorrectlyUnderConcurrentWrites(t *testing.T) {
// GIVEN: a target file that will be written by 20 concurrent goroutines
dir := t.TempDir()
path := filepath.Join(dir, "concurrent.json")
// WHEN: 20 goroutines write simultaneously
var wg sync.WaitGroup
for i := range 20 {
wg.Add(1)
go func(n int) {
defer wg.Done()
data := []byte(`{"n":` + string(rune('0'+n%10)) + `}`)
AtomicWrite(path, data, 0644)
}(i)
}
wg.Wait()
// THEN: file exists and is valid (not corrupted by interleaved writes)
got, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile failed: %v", err)
}
if len(got) == 0 {
t.Error("file is empty after concurrent writes")
}
}

View File

@@ -34,7 +34,7 @@ type LocalFileIO struct{}
func (l *LocalFileIO) Open(name string) (fileio.File, error) {
safePath, err := SafeInputPath(name)
if err != nil {
return nil, err
return nil, &fileio.PathValidationError{Err: err}
}
return vfs.Open(safePath)
}
@@ -43,7 +43,7 @@ func (l *LocalFileIO) Open(name string) (fileio.File, error) {
func (l *LocalFileIO) Stat(name string) (fileio.FileInfo, error) {
safePath, err := SafeInputPath(name)
if err != nil {
return nil, err
return nil, &fileio.PathValidationError{Err: err}
}
return vfs.Stat(safePath)
}
@@ -55,7 +55,11 @@ func (r *saveResult) Size() int64 { return r.size }
// ResolvePath returns the validated absolute path for the given output path.
func (l *LocalFileIO) ResolvePath(path string) (string, error) {
return SafeOutputPath(path)
resolved, err := SafeOutputPath(path)
if err != nil {
return "", &fileio.PathValidationError{Err: err}
}
return resolved, nil
}
// Save writes body to path atomically after validating the output path.
@@ -64,14 +68,14 @@ func (l *LocalFileIO) ResolvePath(path string) (string, error) {
func (l *LocalFileIO) Save(path string, _ fileio.SaveOptions, body io.Reader) (fileio.SaveResult, error) {
safePath, err := SafeOutputPath(path)
if err != nil {
return nil, err
return nil, &fileio.PathValidationError{Err: err}
}
if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
return nil, err
return nil, &fileio.MkdirError{Err: err}
}
n, err := AtomicWriteFromReader(safePath, body, 0600)
if err != nil {
return nil, err
return nil, &fileio.WriteError{Err: err}
}
return &saveResult{size: n}, nil
}

View File

@@ -0,0 +1,306 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package localfileio
import (
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/larksuite/cli/extension/fileio"
)
// testChdir temporarily changes the working directory for a test.
// Not compatible with t.Parallel().
func testChdir(t *testing.T, dir string) {
t.Helper()
orig, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
if err := os.Chdir(dir); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { os.Chdir(orig) })
}
// ── Provider ──
func TestProvider_Name(t *testing.T) {
p := &Provider{}
if got := p.Name(); got != "local" {
t.Errorf("Provider.Name() = %q, want %q", got, "local")
}
}
func TestProvider_ResolveFileIO(t *testing.T) {
p := &Provider{}
fio := p.ResolveFileIO(nil)
if fio == nil {
t.Fatal("Provider.ResolveFileIO returned nil")
}
if _, ok := fio.(*LocalFileIO); !ok {
t.Errorf("expected *LocalFileIO, got %T", fio)
}
}
// ── Open ──
func TestLocalFileIO_Open_ValidFile(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
content := []byte("hello world")
os.WriteFile("test.txt", content, 0644)
fio := &LocalFileIO{}
f, err := fio.Open("test.txt")
if err != nil {
t.Fatalf("Open failed: %v", err)
}
defer f.Close()
got, err := io.ReadAll(f)
if err != nil {
t.Fatalf("ReadAll failed: %v", err)
}
if string(got) != string(content) {
t.Errorf("content = %q, want %q", got, content)
}
}
func TestLocalFileIO_Open_RejectsTraversal(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
_, err := fio.Open("../../etc/passwd")
if err == nil {
t.Error("expected error for path traversal")
}
}
func TestLocalFileIO_Open_RejectsAbsolutePath(t *testing.T) {
fio := &LocalFileIO{}
_, err := fio.Open("/etc/passwd")
if err == nil {
t.Error("expected error for absolute path")
}
if err != nil && !strings.Contains(err.Error(), "relative path") {
t.Errorf("error should mention relative path, got: %v", err)
}
}
func TestLocalFileIO_Open_NonexistentFile(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
_, err := fio.Open("nonexistent.txt")
if err == nil {
t.Error("expected error for nonexistent file")
}
}
// ── Stat ──
func TestLocalFileIO_Stat_ValidFile(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
os.WriteFile("stat.txt", []byte("12345"), 0644)
fio := &LocalFileIO{}
info, err := fio.Stat("stat.txt")
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
if info.Size() != 5 {
t.Errorf("Size() = %d, want 5", info.Size())
}
if info.IsDir() {
t.Error("expected IsDir() = false")
}
}
func TestLocalFileIO_Stat_RejectsTraversal(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
_, err := fio.Stat("../../etc/passwd")
if err == nil {
t.Error("expected error for path traversal")
}
if err != nil && os.IsNotExist(err) {
t.Error("traversal should not be os.IsNotExist, should be a validation error")
}
}
func TestLocalFileIO_Stat_NonexistentReturnsIsNotExist(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
_, err := fio.Stat("nope.txt")
if err == nil {
t.Error("expected error for nonexistent file")
}
if !os.IsNotExist(err) {
t.Errorf("expected os.IsNotExist, got: %v", err)
}
}
// ── Save ──
func TestLocalFileIO_Save_WritesContent(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
body := strings.NewReader("saved content")
result, err := fio.Save("output.bin", fileio.SaveOptions{}, body)
if err != nil {
t.Fatalf("Save failed: %v", err)
}
if result.Size() != int64(len("saved content")) {
t.Errorf("Size() = %d, want %d", result.Size(), len("saved content"))
}
got, _ := os.ReadFile(filepath.Join(dir, "output.bin"))
if string(got) != "saved content" {
t.Errorf("file content = %q, want %q", got, "saved content")
}
}
func TestLocalFileIO_Save_CreatesParentDirs(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
body := strings.NewReader("nested")
_, err := fio.Save(filepath.Join("a", "b", "c.txt"), fileio.SaveOptions{}, body)
if err != nil {
t.Fatalf("Save with nested dir failed: %v", err)
}
got, _ := os.ReadFile(filepath.Join(dir, "a", "b", "c.txt"))
if string(got) != "nested" {
t.Errorf("file content = %q, want %q", got, "nested")
}
}
func TestLocalFileIO_Save_RejectsTraversal(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
_, err := fio.Save("../../evil.txt", fileio.SaveOptions{}, strings.NewReader("bad"))
if err == nil {
t.Error("expected error for path traversal in Save")
}
}
func TestLocalFileIO_Save_RejectsAbsolutePath(t *testing.T) {
fio := &LocalFileIO{}
_, err := fio.Save("/tmp/evil.txt", fileio.SaveOptions{}, strings.NewReader("bad"))
if err == nil {
t.Error("expected error for absolute path in Save")
}
}
// ── ResolvePath ──
func TestLocalFileIO_ResolvePath_ReturnsAbsolute(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
resolved, err := fio.ResolvePath("file.txt")
if err != nil {
t.Fatalf("ResolvePath failed: %v", err)
}
if !filepath.IsAbs(resolved) {
t.Errorf("expected absolute path, got %q", resolved)
}
if filepath.Base(resolved) != "file.txt" {
t.Errorf("expected base name file.txt, got %q", filepath.Base(resolved))
}
}
func TestLocalFileIO_ResolvePath_RejectsTraversal(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
_, err := fio.ResolvePath("../../etc/passwd")
if err == nil {
t.Error("expected error for path traversal in ResolvePath")
}
}
func TestLocalFileIO_ResolvePath_RejectsAbsolute(t *testing.T) {
fio := &LocalFileIO{}
_, err := fio.ResolvePath("/etc/passwd")
if err == nil {
t.Error("expected error for absolute path in ResolvePath")
}
}
// ── Error message consistency ──
func TestLocalFileIO_ErrorMessages_ContainCorrectFlagName(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
// Open/Stat use SafeInputPath → errors should mention "--file"
_, err := fio.Open("/absolute/path")
if err == nil || !strings.Contains(err.Error(), "--file") {
t.Errorf("Open absolute path error should mention --file, got: %v", err)
}
_, err = fio.Stat("/absolute/path")
if err == nil || !strings.Contains(err.Error(), "--file") {
t.Errorf("Stat absolute path error should mention --file, got: %v", err)
}
// Save/ResolvePath use SafeOutputPath → errors should mention "--output"
_, err = fio.Save("/absolute/path", fileio.SaveOptions{}, strings.NewReader(""))
if err == nil || !strings.Contains(err.Error(), "--output") {
t.Errorf("Save absolute path error should mention --output, got: %v", err)
}
_, err = fio.ResolvePath("/absolute/path")
if err == nil || !strings.Contains(err.Error(), "--output") {
t.Errorf("ResolvePath absolute path error should mention --output, got: %v", err)
}
}
// ── Control character / Unicode rejection ──
func TestLocalFileIO_RejectsControlCharsInPath(t *testing.T) {
dir := t.TempDir()
testChdir(t, dir)
fio := &LocalFileIO{}
paths := []string{
"file\x00name.txt", // null byte
"file\x1fname.txt", // control char
"file\u200Bname.txt", // zero-width space
"file\u202Ename.txt", // bidi override
}
for _, p := range paths {
if _, err := fio.Open(p); err == nil {
t.Errorf("Open(%q) should reject control/dangerous chars", p)
}
if _, err := fio.Save(p, fileio.SaveOptions{}, strings.NewReader("")); err == nil {
t.Errorf("Save(%q) should reject control/dangerous chars", p)
}
}
}

View File

@@ -0,0 +1,245 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package localfileio
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestSafeOutputPath_RejectsPathTraversalAndDangerousInput(t *testing.T) {
for _, tt := range []struct {
name string
input string
wantErr bool
}{
// ── GIVEN: normal relative paths → THEN: allowed ──
{"normal file", "report.xlsx", false},
{"subdir file", "output/report.xlsx", false},
{"current dir explicit", "./file.txt", false},
{"nested subdir", "a/b/c/file.txt", false},
{"dot in name", "my.report.v2.xlsx", false},
{"space in name", "my file.txt", false},
{"unicode normal", "报告.xlsx", false},
{"dot-dot resolves to cwd", "subdir/..", false},
// ── GIVEN: path traversal via .. → THEN: rejected ──
{"dot-dot escape", "../../.ssh/authorized_keys", true},
{"dot-dot mid path", "subdir/../../etc/passwd", true},
{"triple dot-dot", "../../../etc/shadow", true},
// ── GIVEN: absolute paths → THEN: rejected ──
{"absolute path unix", "/etc/passwd", true},
{"absolute path root", "/tmp/evil", true},
// ── GIVEN: control characters in path → THEN: rejected ──
{"null byte", "file\x00.txt", true},
{"carriage return", "file\r.txt", true},
{"bell char", "file\x07.txt", true},
// ── GIVEN: dangerous Unicode in path → THEN: rejected ──
{"bidi RLO", "file\u202Ename.txt", true},
{"zero width space", "file\u200Bname.txt", true},
{"BOM char", "file\uFEFFname.txt", true},
{"line separator", "file\u2028name.txt", true},
{"bidi LRI", "file\u2066name.txt", true},
// ── GIVEN: looks dangerous but is actually safe → THEN: allowed ──
{"literal percent 2e", "%2e%2e/etc/passwd", false},
{"tilde path", "~/file.txt", false},
} {
t.Run(tt.name, func(t *testing.T) {
// WHEN: SafeOutputPath validates the path
_, err := SafeOutputPath(tt.input)
// THEN: error matches expectation
if (err != nil) != tt.wantErr {
t.Errorf("SafeOutputPath(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}
func TestSafeOutputPath_ReturnsCanonicalAbsolutePath(t *testing.T) {
// GIVEN: a clean temp directory as CWD
dir := t.TempDir()
dir, _ = filepath.EvalSymlinks(dir)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(dir)
// WHEN: SafeOutputPath validates a relative path
got, err := SafeOutputPath("output/file.txt")
// THEN: returns the canonical absolute path for subsequent I/O
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := filepath.Join(dir, "output", "file.txt")
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestSafeOutputPath_RejectsSymlinkEscapingCWD(t *testing.T) {
// GIVEN: a symlink in CWD pointing to /etc (outside CWD)
dir := t.TempDir()
dir, _ = filepath.EvalSymlinks(dir)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(dir)
os.Symlink("/etc", filepath.Join(dir, "link-to-etc"))
// WHEN: SafeOutputPath validates a path through the symlink
_, err := SafeOutputPath("link-to-etc/passwd")
// THEN: rejected because the resolved path is outside CWD
if err == nil {
t.Error("expected error for symlink escaping CWD, got nil")
}
}
func TestSafeOutputPath_AllowsSymlinkWithinCWD(t *testing.T) {
// GIVEN: a symlink in CWD pointing to a subdirectory within CWD
dir := t.TempDir()
dir, _ = filepath.EvalSymlinks(dir)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(dir)
os.MkdirAll(filepath.Join(dir, "real"), 0755)
os.Symlink(filepath.Join(dir, "real"), filepath.Join(dir, "link"))
// WHEN: SafeOutputPath validates a path through the internal symlink
got, err := SafeOutputPath("link/file.txt")
// THEN: allowed, resolved to the real path within CWD
if err != nil {
t.Fatalf("symlink within CWD should be allowed: %v", err)
}
want := filepath.Join(dir, "real", "file.txt")
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestSafeOutputPath_ResolvesAncestorSymlinkWhenParentMissing(t *testing.T) {
// GIVEN: CWD contains a symlink "escape" → /etc, and the target path
// goes through "escape/sub/file.txt" where "sub" does not exist.
// The old code failed to resolve the symlink because the immediate
// parent ("escape/sub") didn't exist, leaving resolved un-anchored.
dir := t.TempDir()
dir, _ = filepath.EvalSymlinks(dir)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(dir)
os.Symlink("/etc", filepath.Join(dir, "escape"))
// WHEN: SafeOutputPath validates a path through the symlink with missing intermediate dirs
_, err := SafeOutputPath("escape/nonexistent/file.txt")
// THEN: rejected — the resolved path is under /etc, outside CWD
if err == nil {
t.Error("expected error for symlink escaping CWD via non-existent parent, got nil")
}
}
func TestSafeOutputPath_DeepNonExistentPathStaysInCWD(t *testing.T) {
// GIVEN: a deeply nested non-existent path with no symlinks
dir := t.TempDir()
dir, _ = filepath.EvalSymlinks(dir)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(dir)
// WHEN: SafeOutputPath validates "a/b/c/d/file.txt" (none of a/b/c/d exist)
got, err := SafeOutputPath("a/b/c/d/file.txt")
// THEN: allowed, resolved to canonical path under CWD
if err != nil {
t.Fatalf("deep non-existent path within CWD should be allowed: %v", err)
}
want := filepath.Join(dir, "a", "b", "c", "d", "file.txt")
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestSafeUploadPath_AllowsTempFileAbsolutePath(t *testing.T) {
// GIVEN: a real temp file (absolute path under os.TempDir())
f, err := os.CreateTemp("", "upload-test-*.bin")
if err != nil {
t.Fatalf("CreateTemp: %v", err)
}
tmpPath := f.Name()
f.Close()
t.Cleanup(func() { os.Remove(tmpPath) })
// WHEN: SafeUploadPath validates the absolute temp path
_, err = SafeInputPath(tmpPath)
// THEN: absolute paths are rejected even in temp dir
if err == nil {
t.Fatal("expected error for absolute temp path, got nil")
}
}
func TestSafeUploadPath_RejectsNonTempAbsolutePath(t *testing.T) {
// GIVEN: an absolute path outside the temp directory
// WHEN / THEN: SafeUploadPath rejects it
_, err := SafeInputPath("/etc/passwd")
if err == nil {
t.Error("expected error for absolute non-temp path, got nil")
}
}
func TestSafeUploadPath_AcceptsRelativePath(t *testing.T) {
// GIVEN: a clean temp CWD with a real file
dir := t.TempDir()
dir, _ = filepath.EvalSymlinks(dir)
orig, _ := os.Getwd()
defer os.Chdir(orig)
os.Chdir(dir)
os.WriteFile(filepath.Join(dir, "upload.bin"), []byte("data"), 0600)
// WHEN: SafeUploadPath validates a relative path to an existing file
got, err := SafeInputPath("upload.bin")
// THEN: accepted and returned as absolute canonical path
if err != nil {
t.Fatalf("SafeUploadPath(relative) error = %v", err)
}
want := filepath.Join(dir, "upload.bin")
if got != want {
t.Errorf("SafeUploadPath(relative) = %q, want %q", got, want)
}
}
func Test_SafeInputPath_ErrorMessageContainsCorrectFlagName(t *testing.T) {
// GIVEN: an absolute path
// WHEN: SafeInputPath rejects it
_, err := SafeInputPath("/etc/passwd")
// THEN: error message mentions --file (not --output)
if err == nil {
t.Fatal("expected error for absolute path")
}
if !strings.Contains(err.Error(), "--file") {
t.Errorf("error should mention --file, got: %s", err.Error())
}
// WHEN: SafeOutputPath rejects it
_, err = SafeOutputPath("/etc/passwd")
// THEN: error message mentions --output (not --file)
if err == nil {
t.Fatal("expected error for absolute path")
}
if !strings.Contains(err.Error(), "--output") {
t.Errorf("error should mention --output, got: %s", err.Error())
}
}

View File

@@ -25,8 +25,6 @@ import (
"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/credential"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
"github.com/spf13/cobra"
)
@@ -335,6 +333,39 @@ func (ctx *RuntimeContext) ResolveSavePath(path string) (string, error) {
return resolved, nil
}
// WrapSaveError matches a FileIO.Save error against known categories and wraps
// it with the caller-provided message prefix, preserving backward-compatible
// error text per shortcut.
func WrapSaveError(err error, pathMsg, mkdirMsg, writeMsg string) error {
if err == nil {
return nil
}
var me *fileio.MkdirError
var we *fileio.WriteError
switch {
case errors.Is(err, fileio.ErrPathValidation):
return fmt.Errorf("%s: %w", pathMsg, err)
case errors.As(err, &me):
return fmt.Errorf("%s: %w", mkdirMsg, err)
case errors.As(err, &we):
return fmt.Errorf("%s: %w", writeMsg, err)
default:
return fmt.Errorf("%s: %w", writeMsg, err)
}
}
// WrapOpenError matches a FileIO.Open/Stat error and wraps it with the
// caller-provided message prefix.
func WrapOpenError(err error, pathMsg, readMsg string) error {
if err == nil {
return nil
}
if errors.Is(err, fileio.ErrPathValidation) {
return fmt.Errorf("%s: %w", pathMsg, err)
}
return fmt.Errorf("%s: %w", readMsg, err)
}
// ValidatePath checks that path is a valid relative input path within the
// working directory by delegating to FileIO.Stat. Returns nil if the path is
// valid or does not exist yet; returns an error only for illegal paths
@@ -634,11 +665,15 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error {
if path == "" {
return FlagErrorf("--%s: file path cannot be empty after @", fl.Name)
}
safePath, err := validate.SafeInputPath(path)
f, err := rctx.FileIO().Open(path)
if err != nil {
return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err)
if errors.Is(err, fileio.ErrPathValidation) {
return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err)
}
return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err)
}
data, err := vfs.ReadFile(safePath)
data, err := io.ReadAll(f)
f.Close()
if err != nil {
return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err)
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"github.com/larksuite/cli/internal/cmdutil"
_ "github.com/larksuite/cli/internal/vfs/localfileio"
"github.com/spf13/cobra"
)

View File

@@ -13,6 +13,7 @@ import (
lark "github.com/larksuite/oapi-sdk-go/v3"
"github.com/spf13/cobra"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/cmdutil"
"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/output"
@@ -102,6 +103,48 @@ func TestRuntimeContext_Out_WithJq_InvalidExpr_WritesStderr(t *testing.T) {
}
}
type testResolvedFileIO struct{}
func (testResolvedFileIO) Open(string) (fileio.File, error) { return nil, nil }
func (testResolvedFileIO) Stat(string) (fileio.FileInfo, error) { return nil, nil }
func (testResolvedFileIO) ResolvePath(path string) (string, error) { return path, nil }
func (testResolvedFileIO) Save(string, fileio.SaveOptions, io.Reader) (fileio.SaveResult, error) {
return nil, nil
}
type capturingFileIOProvider struct {
gotCtx context.Context
fileIO fileio.FileIO
}
func (p *capturingFileIOProvider) Name() string { return "capture" }
func (p *capturingFileIOProvider) ResolveFileIO(ctx context.Context) fileio.FileIO {
p.gotCtx = ctx
return p.fileIO
}
func TestRuntimeContext_FileIO_UsesExecutionContext(t *testing.T) {
execCtx := context.WithValue(context.Background(), "key", "value")
resolved := testResolvedFileIO{}
provider := &capturingFileIOProvider{fileIO: resolved}
rctx := &RuntimeContext{
ctx: execCtx,
Factory: &cmdutil.Factory{
FileIOProvider: provider,
},
}
got := rctx.FileIO()
if got != resolved {
t.Fatalf("FileIO() returned %T, want %T", got, resolved)
}
if provider.gotCtx != execCtx {
t.Fatal("ResolveFileIO() did not receive the runtime execution context")
}
}
func newTestShortcutCmd(s *Shortcut) *cobra.Command {
cmd := &cobra.Command{Use: "test-shortcut"}
cmd.SetContext(context.Background())
@@ -119,7 +162,8 @@ func newTestFactory() *cmdutil.Factory {
LarkClient: func() (*lark.Client, error) {
return lark.NewClient("test", "test"), nil
},
IOStreams: &cmdutil.IOStreams{Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}},
IOStreams: &cmdutil.IOStreams{Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}},
FileIOProvider: fileio.GetProvider(),
}
}

View File

@@ -16,6 +16,8 @@ import (
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/spf13/cobra"
"github.com/larksuite/cli/internal/cmdutil"
)
func TestSanitizeURLForDisplay(t *testing.T) {
@@ -404,39 +406,36 @@ func TestBuildSearchChatBodyAdditionalBranches(t *testing.T) {
func TestParseMediaDurationSuccess(t *testing.T) {
t.Run("mp4", func(t *testing.T) {
f, err := os.CreateTemp("", "im-duration-*.mp4")
if err != nil {
t.Fatalf("CreateTemp() error = %v", err)
cmdutil.TestChdir(t, t.TempDir())
fname := "im-duration-test.mp4"
if err := os.WriteFile(fname, wrapInMoov(buildMvhdBox(0, 1000, 5000)), 0644); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
if _, err := f.Write(wrapInMoov(buildMvhdBox(0, 1000, 5000))); err != nil {
t.Fatalf("Write() error = %v", err)
}
if got := parseMediaDuration(f.Name(), "mp4"); got != "5000" {
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("unexpected")
}))
if got := parseMediaDuration(rt, fname, "mp4"); got != "5000" {
t.Fatalf("parseMediaDuration(mp4) = %q, want %q", got, "5000")
}
})
t.Run("opus", func(t *testing.T) {
f, err := os.CreateTemp("", "im-duration-*.ogg")
if err != nil {
t.Fatalf("CreateTemp() error = %v", err)
}
defer os.Remove(f.Name())
defer f.Close()
cmdutil.TestChdir(t, t.TempDir())
page := make([]byte, 27)
copy(page[0:4], "OggS")
page[5] = 4
page[6] = 0x00
page[7] = 0x53
page[8] = 0x07
if _, err := f.Write(page); err != nil {
t.Fatalf("Write() error = %v", err)
fname := "im-duration-test.ogg"
if err := os.WriteFile(fname, page, 0644); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
if got := parseMediaDuration(f.Name(), "opus"); got != "10000" {
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("unexpected")
}))
if got := parseMediaDuration(rt, fname, "opus"); got != "10000" {
t.Fatalf("parseMediaDuration(opus) = %q, want %q", got, "10000")
}
})

View File

@@ -13,16 +13,15 @@ import (
"math"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
"github.com/larksuite/cli/shortcuts/common"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/spf13/cobra"
@@ -327,21 +326,16 @@ func resolveURLMedia(ctx context.Context, runtime *common.RuntimeContext, s medi
func resolveLocalMedia(ctx context.Context, runtime *common.RuntimeContext, s mediaSpec) (string, error) {
fmt.Fprintf(runtime.IO().ErrOut, "uploading %s: %s\n", s.mediaType, filepath.Base(s.value))
safePath, err := validate.SafeInputPath(s.value)
if err != nil {
return "", err
}
if s.kind == mediaKindImage {
return uploadImageToIM(ctx, runtime, safePath, "message")
return uploadImageToIM(ctx, runtime, s.value, "message")
}
ft := detectIMFileType(safePath)
ft := detectIMFileType(s.value)
dur := ""
if s.withDuration {
dur = parseMediaDuration(safePath, ft)
dur = parseMediaDuration(runtime, s.value, ft)
}
return uploadFileToIM(ctx, runtime, safePath, ft, dur)
return uploadFileToIM(ctx, runtime, s.value, ft, dur)
}
// resolveVideoContent handles the video case which needs both a file_key and
@@ -556,18 +550,16 @@ func findMP4Box(data []byte, start, end int, boxType string) (int, int) {
// for audio/video uploads. Only reads the minimal portion of the file needed
// for parsing (tail for OGG, box headers + moov for MP4).
// Returns "" if parsing fails or the file type is not audio/video.
func parseMediaDuration(filePath, fileType string) string {
func parseMediaDuration(runtime *common.RuntimeContext, filePath, fileType string) string {
if fileType != "opus" && fileType != "mp4" {
return ""
}
f, err := vfs.Open(filePath)
if err != nil {
info, err := runtime.FileIO().Stat(filePath)
if err != nil || info.Size() == 0 {
return ""
}
defer f.Close()
info, err := f.Stat()
if err != nil || info.Size() == 0 {
f, err := runtime.FileIO().Open(filePath)
if err != nil {
return ""
}
@@ -698,7 +690,7 @@ func readMp4DurationBytes(data []byte) int64 {
}
// readOggDuration reads the tail of an OGG file (up to 64 KB) and parses duration.
func readOggDuration(f *os.File, fileSize int64) int64 {
func readOggDuration(f fileio.File, fileSize int64) int64 {
const maxTail = 65536
readSize := fileSize
if readSize > maxTail {
@@ -713,7 +705,7 @@ func readOggDuration(f *os.File, fileSize int64) int64 {
// readMp4Duration walks top-level MP4 boxes via file seeks to find moov,
// then reads only the moov content to locate mvhd and extract the duration.
func readMp4Duration(f *os.File, fileSize int64) int64 {
func readMp4Duration(f fileio.File, fileSize int64) int64 {
hdr := make([]byte, 16)
var offset int64
for offset+8 <= fileSize {
@@ -1005,14 +997,11 @@ const maxImageUploadSize = 5 * 1024 * 1024 // 5MB — Lark API limit for images
const maxFileUploadSize = 100 * 1024 * 1024 // 100MB — Lark API limit for files
func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, imageType string) (string, error) {
// filePath is already validated by the caller (resolveLocalMedia).
safePath := filePath
if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxImageUploadSize {
if info, err := runtime.FileIO().Stat(filePath); err == nil && info.Size() > maxImageUploadSize {
return "", fmt.Errorf("image size %s exceeds limit (max 5MB)", common.FormatSize(info.Size()))
}
f, err := vfs.Open(safePath)
f, err := runtime.FileIO().Open(filePath)
if err != nil {
return "", err
}
@@ -1045,14 +1034,11 @@ func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePa
}
func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, fileType, duration string) (string, error) {
// filePath is already validated by the caller (resolveLocalMedia).
safePath := filePath
if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxFileUploadSize {
if info, err := runtime.FileIO().Stat(filePath); err == nil && info.Size() > maxFileUploadSize {
return "", fmt.Errorf("file size %s exceeds limit (max 100MB)", common.FormatSize(info.Size()))
}
f, err := vfs.Open(safePath)
f, err := runtime.FileIO().Open(filePath)
if err != nil {
return "", err
}
@@ -1060,7 +1046,7 @@ func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePat
fd := larkcore.NewFormdata()
fd.AddField("file_type", fileType)
fd.AddField("file_name", filepath.Base(safePath))
fd.AddField("file_name", filepath.Base(filePath))
if duration != "" {
fd.AddField("duration", duration)
}

View File

@@ -20,6 +20,7 @@ import (
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/cmdutil"
"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/credential"
@@ -52,9 +53,10 @@ func shortcutRawResponse(status int, body []byte, headers http.Header) *http.Res
headers = make(http.Header)
}
return &http.Response{
StatusCode: status,
Header: headers,
Body: io.NopCloser(bytes.NewReader(body)),
StatusCode: status,
Header: headers,
Body: io.NopCloser(bytes.NewReader(body)),
ContentLength: int64(len(body)),
}
}
@@ -88,10 +90,11 @@ func newBotShortcutRuntime(t *testing.T, rt http.RoundTripper) *common.RuntimeCo
runtime := &common.RuntimeContext{
Config: cfg,
Factory: &cmdutil.Factory{
Config: func() (*core.CliConfig, error) { return cfg, nil },
HttpClient: func() (*http.Client, error) { return httpClient, nil },
LarkClient: func() (*lark.Client, error) { return sdk, nil },
Credential: testCred,
Config: func() (*core.CliConfig, error) { return cfg, nil },
HttpClient: func() (*http.Client, error) { return httpClient, nil },
LarkClient: func() (*lark.Client, error) { return sdk, nil },
Credential: testCred,
FileIOProvider: fileio.GetProvider(),
IOStreams: &cmdutil.IOStreams{
Out: &bytes.Buffer{},
ErrOut: &bytes.Buffer{},
@@ -241,7 +244,9 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) {
}
}))
target := filepath.Join(t.TempDir(), "nested", "resource.bin")
cmdutil.TestChdir(t, t.TempDir())
target := filepath.Join("nested", "resource.bin")
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "file_123", "file", target)
if err != nil {
t.Fatalf("downloadIMResourceToPath() error = %v", err)
@@ -280,7 +285,9 @@ func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) {
}
}))
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", filepath.Join(t.TempDir(), "out.bin"))
cmdutil.TestChdir(t, t.TempDir())
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", "out.bin")
if err == nil || !strings.Contains(err.Error(), "HTTP 403: denied") {
t.Fatalf("downloadIMResourceToPath() error = %v", err)
}
@@ -305,28 +312,14 @@ func TestUploadImageToIMSuccess(t *testing.T) {
}
}))
wd, err := os.Getwd()
if err != nil {
t.Fatalf("Getwd() error = %v", err)
}
tmpDir := t.TempDir()
if err := os.Chdir(tmpDir); err != nil {
t.Fatalf("Chdir() error = %v", err)
}
t.Cleanup(func() {
_ = os.Chdir(wd)
})
cmdutil.TestChdir(t, t.TempDir())
path := "demo.png"
if err := os.WriteFile(path, []byte("png"), 0600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
absPath, err := filepath.Abs(path)
if err != nil {
t.Fatalf("Abs() error = %v", err)
}
got, err := uploadImageToIM(context.Background(), runtime, absPath, "message")
got, err := uploadImageToIM(context.Background(), runtime, path, "message")
if err != nil {
t.Fatalf("uploadImageToIM() error = %v", err)
}
@@ -357,28 +350,14 @@ func TestUploadFileToIMSuccess(t *testing.T) {
}
}))
wd, err := os.Getwd()
if err != nil {
t.Fatalf("Getwd() error = %v", err)
}
tmpDir := t.TempDir()
if err := os.Chdir(tmpDir); err != nil {
t.Fatalf("Chdir() error = %v", err)
}
t.Cleanup(func() {
_ = os.Chdir(wd)
})
cmdutil.TestChdir(t, t.TempDir())
path := "demo.txt"
if err := os.WriteFile(path, []byte("demo"), 0600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
absPath, err := filepath.Abs(path)
if err != nil {
t.Fatalf("Abs() error = %v", err)
}
got, err := uploadFileToIM(context.Background(), runtime, absPath, "stream", "1200")
got, err := uploadFileToIM(context.Background(), runtime, path, "stream", "1200")
if err != nil {
t.Fatalf("uploadFileToIM() error = %v", err)
}
@@ -394,7 +373,8 @@ func TestUploadFileToIMSuccess(t *testing.T) {
}
func TestUploadImageToIMSizeLimit(t *testing.T) {
path := filepath.Join(t.TempDir(), "too-large.png")
cmdutil.TestChdir(t, t.TempDir())
path := "too-large.png"
f, err := os.Create(path)
if err != nil {
t.Fatalf("Create() error = %v", err)
@@ -404,14 +384,18 @@ func TestUploadImageToIMSizeLimit(t *testing.T) {
}
f.Close()
_, err = uploadImageToIM(context.Background(), nil, path, "message")
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("unexpected")
}))
_, err = uploadImageToIM(context.Background(), rt, path, "message")
if err == nil || !strings.Contains(err.Error(), "exceeds limit") {
t.Fatalf("uploadImageToIM() error = %v", err)
}
}
func TestUploadFileToIMSizeLimit(t *testing.T) {
path := filepath.Join(t.TempDir(), "too-large.bin")
cmdutil.TestChdir(t, t.TempDir())
path := "too-large.bin"
f, err := os.Create(path)
if err != nil {
t.Fatalf("Create() error = %v", err)
@@ -421,7 +405,10 @@ func TestUploadFileToIMSizeLimit(t *testing.T) {
}
f.Close()
_, err = uploadFileToIM(context.Background(), nil, path, "stream", "")
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("unexpected")
}))
_, err = uploadFileToIM(context.Background(), rt, path, "stream", "")
if err == nil || !strings.Contains(err.Error(), "exceeds limit") {
t.Fatalf("uploadFileToIM() error = %v", err)
}
@@ -430,6 +417,7 @@ func TestUploadFileToIMSizeLimit(t *testing.T) {
func TestResolveMediaContentWrapsUploadError(t *testing.T) {
runtime := &common.RuntimeContext{
Factory: &cmdutil.Factory{
FileIOProvider: fileio.GetProvider(),
IOStreams: &cmdutil.IOStreams{
Out: &bytes.Buffer{},
ErrOut: &bytes.Buffer{},
@@ -437,7 +425,9 @@ func TestResolveMediaContentWrapsUploadError(t *testing.T) {
},
}
missing := filepath.Join(t.TempDir(), "missing.png")
cmdutil.TestChdir(t, t.TempDir())
missing := "missing.png"
_, _, err := resolveMediaContent(context.Background(), runtime, "", missing, "", "", "", "")
if err == nil || !strings.Contains(err.Error(), "image upload failed") {
t.Fatalf("resolveMediaContent() error = %v", err)
@@ -457,15 +447,7 @@ func TestResolveLocalMediaImage(t *testing.T) {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}))
wd, err := os.Getwd()
if err != nil {
t.Fatalf("Getwd() error = %v", err)
}
tmpDir := t.TempDir()
if err := os.Chdir(tmpDir); err != nil {
t.Fatalf("Chdir() error = %v", err)
}
t.Cleanup(func() { _ = os.Chdir(wd) })
cmdutil.TestChdir(t, t.TempDir())
if err := os.WriteFile("test.png", []byte("png-data"), 0600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
@@ -496,15 +478,7 @@ func TestResolveLocalMediaFile(t *testing.T) {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}))
wd, err := os.Getwd()
if err != nil {
t.Fatalf("Getwd() error = %v", err)
}
tmpDir := t.TempDir()
if err := os.Chdir(tmpDir); err != nil {
t.Fatalf("Chdir() error = %v", err)
}
t.Cleanup(func() { _ = os.Chdir(wd) })
cmdutil.TestChdir(t, t.TempDir())
if err := os.WriteFile("test.txt", []byte("file-data"), 0600); err != nil {
t.Fatalf("WriteFile() error = %v", err)

View File

@@ -7,6 +7,7 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"net/http"
"reflect"
"strings"
@@ -263,10 +264,13 @@ func TestParseMp4Duration(t *testing.T) {
}
func TestParseMediaDuration(t *testing.T) {
if got := parseMediaDuration("test.pdf", "pdf"); got != "" {
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("unexpected")
}))
if got := parseMediaDuration(rt, "test.pdf", "pdf"); got != "" {
t.Fatalf("parseMediaDuration(pdf) = %q, want empty", got)
}
if got := parseMediaDuration("nonexistent.opus", "opus"); got != "" {
if got := parseMediaDuration(rt, "nonexistent.opus", "opus"); got != "" {
t.Fatalf("parseMediaDuration(missing) = %q, want empty", got)
}
}

View File

@@ -91,29 +91,13 @@ var ImMessagesReply = common.Shortcut{
videoCoverKey := runtime.Str("video-cover")
audioKey := runtime.Str("audio")
if !isMediaKey(imageKey) {
if _, err := validate.SafeLocalFlagPath("--image", imageKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(fileKey) {
if _, err := validate.SafeLocalFlagPath("--file", fileKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoKey) {
if _, err := validate.SafeLocalFlagPath("--video", videoKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoCoverKey) {
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(audioKey) {
if _, err := validate.SafeLocalFlagPath("--audio", audioKey); err != nil {
return output.ErrValidation("%v", err)
fio := runtime.FileIO()
for _, mf := range []struct{ flag, val string }{
{"--image", imageKey}, {"--file", fileKey}, {"--video", videoKey},
{"--video-cover", videoCoverKey}, {"--audio", audioKey},
} {
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
return err
}
}
@@ -149,29 +133,13 @@ var ImMessagesReply = common.Shortcut{
audioVal := runtime.Str("audio")
replyInThread := runtime.Bool("reply-in-thread")
idempotencyKey := runtime.Str("idempotency-key")
if !isMediaKey(imageVal) {
if _, err := validate.SafeLocalFlagPath("--image", imageVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(fileVal) {
if _, err := validate.SafeLocalFlagPath("--file", fileVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoVal) {
if _, err := validate.SafeLocalFlagPath("--video", videoVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoCoverVal) {
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(audioVal) {
if _, err := validate.SafeLocalFlagPath("--audio", audioVal); err != nil {
return output.ErrValidation("%v", err)
fio := runtime.FileIO()
for _, mf := range []struct{ flag, val string }{
{"--image", imageVal}, {"--file", fileVal}, {"--video", videoVal},
{"--video-cover", videoCoverVal}, {"--audio", audioVal},
} {
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
return err
}
}

View File

@@ -6,15 +6,15 @@ package im
import (
"context"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"time"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/client"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/internal/vfs"
"github.com/larksuite/cli/shortcuts/common"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
)
@@ -54,7 +54,7 @@ var ImMessagesResourcesDownload = common.Shortcut{
if err != nil {
return output.ErrValidation("%s", err)
}
if _, err := validate.SafeOutputPath(relPath); err != nil {
if _, err := runtime.ResolveSavePath(relPath); err != nil {
return output.ErrValidation("unsafe output path: %s", err)
}
return nil
@@ -67,12 +67,8 @@ var ImMessagesResourcesDownload = common.Shortcut{
if err != nil {
return output.ErrValidation("invalid output path: %s", err)
}
safePath, err := validate.SafeOutputPath(relPath)
if err != nil {
return output.ErrValidation("unsafe output path: %s", err)
}
finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, safePath)
finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, relPath)
if err != nil {
return err
}
@@ -109,33 +105,33 @@ func normalizeDownloadOutputPath(fileKey, outputPath string) (string, error) {
const defaultIMResourceDownloadTimeout = 120 * time.Second
var imMimeToExt = map[string]string{
"image/png": ".png",
"image/jpeg": ".jpg",
"image/gif": ".gif",
"image/webp": ".webp",
"image/svg+xml": ".svg",
"application/pdf": ".pdf",
"video/mp4": ".mp4",
"video/3gpp": ".3gp",
"video/x-msvideo": ".avi",
"audio/mpeg": ".mp3",
"audio/ogg": ".ogg",
"audio/wav": ".wav",
"text/plain": ".txt",
"text/html": ".html",
"text/css": ".css",
"text/csv": ".csv",
"application/zip": ".zip",
"image/png": ".png",
"image/jpeg": ".jpg",
"image/gif": ".gif",
"image/webp": ".webp",
"image/svg+xml": ".svg",
"application/pdf": ".pdf",
"video/mp4": ".mp4",
"video/3gpp": ".3gp",
"video/x-msvideo": ".avi",
"audio/mpeg": ".mp3",
"audio/ogg": ".ogg",
"audio/wav": ".wav",
"text/plain": ".txt",
"text/html": ".html",
"text/css": ".css",
"text/csv": ".csv",
"application/zip": ".zip",
"application/x-zip-compressed": ".zip",
"application/x-rar-compressed": ".rar",
"application/json": ".json",
"application/xml": ".xml",
"application/octet-stream": ".bin",
"application/msword": ".doc",
"application/json": ".json",
"application/xml": ".xml",
"application/octet-stream": ".bin",
"application/msword": ".doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.ms-powerpoint": ".ppt",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.ms-powerpoint": ".ppt",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
}
@@ -156,8 +152,12 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex
}
defer downloadResp.Body.Close()
if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err)
if downloadResp.StatusCode >= 400 {
body, _ := io.ReadAll(io.LimitReader(downloadResp.Body, 4096))
if len(body) > 0 {
return "", 0, output.ErrNetwork("download failed: HTTP %d: %s", downloadResp.StatusCode, strings.TrimSpace(string(body)))
}
return "", 0, output.ErrNetwork("download failed: HTTP %d", downloadResp.StatusCode)
}
// Auto-detect extension from Content-Type if missing
@@ -171,9 +171,19 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex
}
}
sizeBytes, err := validate.AtomicWriteFromReader(finalPath, downloadResp.Body, 0600)
result, err := runtime.FileIO().Save(finalPath, fileio.SaveOptions{
ContentType: downloadResp.Header.Get("Content-Type"),
ContentLength: downloadResp.ContentLength,
}, downloadResp.Body)
if err != nil {
return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create file: %s", err)
return "", 0, output.Errorf(output.ExitInternal, "api_error", "%s",
common.WrapSaveError(err, "unsafe output path", "cannot create parent directory", "cannot create file"))
}
return finalPath, sizeBytes, nil
savedPath, resolveErr := runtime.ResolveSavePath(finalPath)
if resolveErr != nil {
// Save succeeded — file is on disk. Fall back to the relative path
// rather than returning an error for a successfully written file.
savedPath = finalPath
}
return savedPath, result.Size(), nil
}

View File

@@ -7,10 +7,11 @@ import (
"context"
"encoding/json"
"net/http"
"os"
"strings"
"github.com/larksuite/cli/extension/fileio"
"github.com/larksuite/cli/internal/output"
"github.com/larksuite/cli/internal/validate"
"github.com/larksuite/cli/shortcuts/common"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
)
@@ -98,29 +99,13 @@ var ImMessagesSend = common.Shortcut{
videoCoverKey := runtime.Str("video-cover")
audioKey := runtime.Str("audio")
if !isMediaKey(imageKey) {
if _, err := validate.SafeLocalFlagPath("--image", imageKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(fileKey) {
if _, err := validate.SafeLocalFlagPath("--file", fileKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoKey) {
if _, err := validate.SafeLocalFlagPath("--video", videoKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoCoverKey) {
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverKey); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(audioKey) {
if _, err := validate.SafeLocalFlagPath("--audio", audioKey); err != nil {
return output.ErrValidation("%v", err)
fio := runtime.FileIO()
for _, mf := range []struct{ flag, val string }{
{"--image", imageKey}, {"--file", fileKey}, {"--video", videoKey},
{"--video-cover", videoCoverKey}, {"--audio", audioKey},
} {
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
return err
}
}
@@ -165,29 +150,13 @@ var ImMessagesSend = common.Shortcut{
videoVal := runtime.Str("video")
videoCoverVal := runtime.Str("video-cover")
audioVal := runtime.Str("audio")
if !isMediaKey(imageVal) {
if _, err := validate.SafeLocalFlagPath("--image", imageVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(fileVal) {
if _, err := validate.SafeLocalFlagPath("--file", fileVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoVal) {
if _, err := validate.SafeLocalFlagPath("--video", videoVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(videoCoverVal) {
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverVal); err != nil {
return output.ErrValidation("%v", err)
}
}
if !isMediaKey(audioVal) {
if _, err := validate.SafeLocalFlagPath("--audio", audioVal); err != nil {
return output.ErrValidation("%v", err)
fio := runtime.FileIO()
for _, mf := range []struct{ flag, val string }{
{"--image", imageVal}, {"--file", fileVal}, {"--video", videoVal},
{"--video-cover", videoCoverVal}, {"--audio", audioVal},
} {
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
return err
}
}
// Resolve content type
@@ -239,3 +208,15 @@ var ImMessagesSend = common.Shortcut{
func isMediaKey(value string) bool {
return strings.HasPrefix(value, "img_") || strings.HasPrefix(value, "file_")
}
// validateMediaFlagPath validates a media flag value as a local file path via FileIO.
// Empty values, URLs, and media keys are skipped (not local files).
func validateMediaFlagPath(fio fileio.FileIO, flagName, value string) error {
if value == "" || strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://") || isMediaKey(value) {
return nil
}
if _, err := fio.Stat(value); err != nil && !os.IsNotExist(err) {
return output.ErrValidation("%s: %v", flagName, err)
}
return nil
}

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package im
import (
"os"
"path/filepath"
"testing"
"github.com/larksuite/cli/internal/vfs/localfileio"
)
func TestValidateMediaFlagPath(t *testing.T) {
dir := t.TempDir()
orig, _ := os.Getwd()
defer os.Chdir(orig)
os.Chdir(dir)
os.WriteFile(filepath.Join(dir, "photo.jpg"), []byte("img"), 0644)
fio := &localfileio.LocalFileIO{}
tests := []struct {
name string
flag string
value string
wantErr bool
}{
{"empty value skipped", "--image", "", false},
{"http URL skipped", "--image", "http://example.com/a.jpg", false},
{"https URL skipped", "--file", "https://example.com/b.mp4", false},
{"media key skipped", "--image", "img_abc123", false},
{"file key skipped", "--file", "file_abc123", false},
{"valid local file", "--image", "photo.jpg", false},
{"nonexistent file allowed", "--file", "missing.txt", false},
{"path traversal rejected", "--image", "../../etc/passwd", true},
{"absolute path rejected", "--file", "/etc/passwd", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateMediaFlagPath(fio, tt.flag, tt.value)
if tt.wantErr && err == nil {
t.Fatalf("expected error for %s=%q, got nil", tt.flag, tt.value)
}
if !tt.wantErr && err != nil {
t.Fatalf("unexpected error for %s=%q: %v", tt.flag, tt.value, err)
}
})
}
}