mirror of
https://github.com/larksuite/cli.git
synced 2026-07-03 14:02:43 +08:00
refactor: migrate common/client/im to FileIO and add localfileio tests (#322)
* refactor: migrate common/client/im to FileIO and add localfileio tests - runner resolveInputFlags: replace validate.SafeInputPath + vfs.ReadFile with FileIO.Open + io.ReadAll - SaveResponse: delegate to FileIO.Save + ResolvePath - cmd/api, cmd/service: pass FileIO to ResponseOptions - im: replace validate.SafeLocalFlagPath with RuntimeContext.ValidatePath, migrate download/upload to FileIO.Save/Open/Stat - Add path_test.go and atomicwrite_test.go for localfileio - Add validate_media_test.go for im media flag validation - Adapt test mocks to fileio.FileInfo interface
This commit is contained in:
@@ -207,6 +207,7 @@ func apiRun(opts *APIOptions) error {
|
||||
JqExpr: opts.JqExpr,
|
||||
Out: out,
|
||||
ErrOut: f.IOStreams.ErrOut,
|
||||
FileIO: f.ResolveFileIO(opts.Ctx),
|
||||
})
|
||||
// MarkRaw tells root error handler to skip enrichPermissionError,
|
||||
// preserving the original API error detail (log_id, troubleshooter, etc.).
|
||||
|
||||
@@ -250,6 +250,7 @@ func serviceMethodRun(opts *ServiceMethodOptions) error {
|
||||
JqExpr: opts.JqExpr,
|
||||
Out: out,
|
||||
ErrOut: f.IOStreams.ErrOut,
|
||||
FileIO: f.ResolveFileIO(opts.Ctx),
|
||||
CheckError: checkErr,
|
||||
})
|
||||
}
|
||||
|
||||
40
extension/fileio/errors.go
Normal file
40
extension/fileio/errors.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package fileio
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrPathValidation indicates the path failed security validation
|
||||
// (traversal, absolute, control chars, symlink escape, etc.).
|
||||
var ErrPathValidation = errors.New("path validation failed")
|
||||
|
||||
// PathValidationError wraps a path validation error.
|
||||
// errors.Is(err, ErrPathValidation) returns true.
|
||||
// errors.Is(err, <original OS error>) also works via the chain.
|
||||
type PathValidationError struct {
|
||||
Err error // original error
|
||||
}
|
||||
|
||||
func (e *PathValidationError) Error() string { return e.Err.Error() }
|
||||
func (e *PathValidationError) Unwrap() []error {
|
||||
return []error{ErrPathValidation, e.Err}
|
||||
}
|
||||
|
||||
// MkdirError indicates parent directory creation failed.
|
||||
// Use errors.As(err, &fileio.MkdirError{}) to match.
|
||||
type MkdirError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *MkdirError) Error() string { return e.Err.Error() }
|
||||
func (e *MkdirError) Unwrap() error { return e.Err }
|
||||
|
||||
// WriteError indicates file write failed.
|
||||
// Use errors.As(err, &fileio.WriteError{}) to match.
|
||||
type WriteError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *WriteError) Error() string { return e.Err.Error() }
|
||||
func (e *WriteError) Unwrap() error { return e.Err }
|
||||
@@ -6,18 +6,17 @@ package client
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/util"
|
||||
"github.com/larksuite/cli/internal/validate"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
)
|
||||
|
||||
// ── Response routing ──
|
||||
@@ -29,6 +28,7 @@ type ResponseOptions struct {
|
||||
JqExpr string // if set, apply jq filter instead of Format
|
||||
Out io.Writer // stdout
|
||||
ErrOut io.Writer // stderr
|
||||
FileIO fileio.FileIO // file transfer abstraction; required when saving files (--output or binary response)
|
||||
// CheckError is called on parsed JSON results. Nil defaults to CheckLarkResponse.
|
||||
CheckError func(interface{}) error
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error {
|
||||
return apiErr
|
||||
}
|
||||
if opts.OutputPath != "" {
|
||||
return saveAndPrint(resp, opts.OutputPath, opts.Out)
|
||||
return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out)
|
||||
}
|
||||
if opts.JqExpr != "" {
|
||||
return output.JqFilter(opts.Out, result, opts.JqExpr)
|
||||
@@ -75,11 +75,11 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error {
|
||||
return output.ErrValidation("--jq requires a JSON response (got Content-Type: %s)", ct)
|
||||
}
|
||||
if opts.OutputPath != "" {
|
||||
return saveAndPrint(resp, opts.OutputPath, opts.Out)
|
||||
return saveAndPrint(opts.FileIO, resp, opts.OutputPath, opts.Out)
|
||||
}
|
||||
|
||||
// No --output: auto-save with derived filename.
|
||||
meta, err := SaveResponse(resp, ResolveFilename(resp))
|
||||
meta, err := SaveResponse(opts.FileIO, resp, ResolveFilename(resp))
|
||||
if err != nil {
|
||||
return output.Errorf(output.ExitInternal, "file_error", "%s", err)
|
||||
}
|
||||
@@ -88,8 +88,8 @@ func HandleResponse(resp *larkcore.ApiResp, opts ResponseOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func saveAndPrint(resp *larkcore.ApiResp, path string, w io.Writer) error {
|
||||
meta, err := SaveResponse(resp, path)
|
||||
func saveAndPrint(fio fileio.FileIO, resp *larkcore.ApiResp, path string, w io.Writer) error {
|
||||
meta, err := SaveResponse(fio, resp, path)
|
||||
if err != nil {
|
||||
return output.Errorf(output.ExitInternal, "file_error", "%s", err)
|
||||
}
|
||||
@@ -119,23 +119,34 @@ func ParseJSONResponse(resp *larkcore.ApiResp) (interface{}, error) {
|
||||
// ── File saving ──
|
||||
|
||||
// SaveResponse writes an API response body to the given outputPath and returns metadata.
|
||||
func SaveResponse(resp *larkcore.ApiResp, outputPath string) (map[string]interface{}, error) {
|
||||
safePath, err := validate.SafeOutputPath(outputPath)
|
||||
// It delegates to FileIO.Save for path validation and atomic write; fio must not be nil.
|
||||
func SaveResponse(fio fileio.FileIO, resp *larkcore.ApiResp, outputPath string) (map[string]interface{}, error) {
|
||||
result, err := fio.Save(outputPath, fileio.SaveOptions{
|
||||
ContentType: resp.Header.Get("Content-Type"),
|
||||
ContentLength: int64(len(resp.RawBody)),
|
||||
}, bytes.NewReader(resp.RawBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unsafe output path: %s", err)
|
||||
var me *fileio.MkdirError
|
||||
var we *fileio.WriteError
|
||||
switch {
|
||||
case errors.Is(err, fileio.ErrPathValidation):
|
||||
return nil, fmt.Errorf("unsafe output path: %s", err)
|
||||
case errors.As(err, &me):
|
||||
return nil, fmt.Errorf("create directory: %s", err)
|
||||
case errors.As(err, &we):
|
||||
return nil, fmt.Errorf("cannot write file: %s", err)
|
||||
default:
|
||||
return nil, fmt.Errorf("cannot write file: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
|
||||
return nil, fmt.Errorf("create directory: %s", err)
|
||||
resolvedPath, err := fio.ResolvePath(outputPath)
|
||||
if err != nil || resolvedPath == "" {
|
||||
resolvedPath = outputPath
|
||||
}
|
||||
|
||||
if err := validate.AtomicWrite(safePath, resp.RawBody, 0644); err != nil {
|
||||
return nil, fmt.Errorf("cannot write file: %s", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"saved_path": safePath,
|
||||
"size_bytes": len(resp.RawBody),
|
||||
"saved_path": resolvedPath,
|
||||
"size_bytes": result.Size(),
|
||||
"content_type": resp.Header.Get("Content-Type"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/vfs/localfileio"
|
||||
)
|
||||
|
||||
func newApiResp(body []byte, headers map[string]string) *larkcore.ApiResp {
|
||||
@@ -162,11 +163,11 @@ func TestSaveResponse(t *testing.T) {
|
||||
body := []byte("hello binary data")
|
||||
resp := newApiResp(body, map[string]string{"Content-Type": "application/octet-stream"})
|
||||
|
||||
meta, err := SaveResponse(resp, "test_output.bin")
|
||||
meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "test_output.bin")
|
||||
if err != nil {
|
||||
t.Fatalf("SaveResponse failed: %v", err)
|
||||
}
|
||||
if meta["size_bytes"] != len(body) {
|
||||
if meta["size_bytes"] != int64(len(body)) {
|
||||
t.Errorf("expected size_bytes=%d, got %v", len(body), meta["size_bytes"])
|
||||
}
|
||||
|
||||
@@ -188,7 +189,7 @@ func TestSaveResponse_CreatesDir(t *testing.T) {
|
||||
|
||||
resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"})
|
||||
|
||||
meta, err := SaveResponse(resp, filepath.Join("sub", "deep", "out.bin"))
|
||||
meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, filepath.Join("sub", "deep", "out.bin"))
|
||||
if err != nil {
|
||||
t.Fatalf("SaveResponse with nested dir failed: %v", err)
|
||||
}
|
||||
@@ -207,6 +208,7 @@ func TestHandleResponse_JSON(t *testing.T) {
|
||||
err := HandleResponse(resp, ResponseOptions{
|
||||
Out: &out,
|
||||
ErrOut: &errOut,
|
||||
FileIO: &localfileio.LocalFileIO{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleResponse failed: %v", err)
|
||||
@@ -225,6 +227,7 @@ func TestHandleResponse_JSONWithError(t *testing.T) {
|
||||
err := HandleResponse(resp, ResponseOptions{
|
||||
Out: &out,
|
||||
ErrOut: &errOut,
|
||||
FileIO: &localfileio.LocalFileIO{},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for non-zero code")
|
||||
@@ -275,6 +278,7 @@ func TestHandleResponse_BinaryAutoSave(t *testing.T) {
|
||||
err := HandleResponse(resp, ResponseOptions{
|
||||
Out: &out,
|
||||
ErrOut: &errOut,
|
||||
FileIO: &localfileio.LocalFileIO{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleResponse binary failed: %v", err)
|
||||
@@ -298,6 +302,7 @@ func TestHandleResponse_BinaryWithOutput(t *testing.T) {
|
||||
OutputPath: "out.png",
|
||||
Out: &out,
|
||||
ErrOut: &errOut,
|
||||
FileIO: &localfileio.LocalFileIO{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleResponse with output path failed: %v", err)
|
||||
@@ -312,7 +317,7 @@ func TestHandleResponse_NonJSONError_404(t *testing.T) {
|
||||
resp := newApiRespWithStatus(404, []byte("404 page not found"), map[string]string{"Content-Type": "text/plain"})
|
||||
|
||||
var out, errOut bytes.Buffer
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404 text/plain")
|
||||
}
|
||||
@@ -330,7 +335,7 @@ func TestHandleResponse_NonJSONError_502(t *testing.T) {
|
||||
resp := newApiRespWithStatus(502, []byte("<html>Bad Gateway</html>"), map[string]string{"Content-Type": "text/html"})
|
||||
|
||||
var out, errOut bytes.Buffer
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 502 text/html")
|
||||
}
|
||||
@@ -353,7 +358,7 @@ func TestHandleResponse_200TextPlain_SavesFile(t *testing.T) {
|
||||
resp := newApiRespWithStatus(200, []byte("plain text file content"), map[string]string{"Content-Type": "text/plain"})
|
||||
|
||||
var out, errOut bytes.Buffer
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for 200 text/plain, got: %v", err)
|
||||
}
|
||||
@@ -379,12 +384,53 @@ func TestHandleResponse_BinaryWithJq_RejectsNonJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveResponse_RejectsPathTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origWd, _ := os.Getwd()
|
||||
os.Chdir(dir)
|
||||
defer os.Chdir(origWd)
|
||||
|
||||
resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"})
|
||||
_, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "../../evil.txt")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for path traversal")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsafe output path") {
|
||||
t.Errorf("expected 'unsafe output path' wrapper, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveResponse_RejectsAbsolutePath(t *testing.T) {
|
||||
resp := newApiResp([]byte("data"), map[string]string{"Content-Type": "application/octet-stream"})
|
||||
_, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "/tmp/evil.txt")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for absolute path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveResponse_MetadataContainsAbsolutePath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origWd, _ := os.Getwd()
|
||||
os.Chdir(dir)
|
||||
defer os.Chdir(origWd)
|
||||
|
||||
resp := newApiResp([]byte("x"), map[string]string{"Content-Type": "text/plain"})
|
||||
meta, err := SaveResponse(&localfileio.LocalFileIO{}, resp, "rel.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("SaveResponse failed: %v", err)
|
||||
}
|
||||
savedPath, _ := meta["saved_path"].(string)
|
||||
if !filepath.IsAbs(savedPath) {
|
||||
t.Errorf("saved_path should be absolute, got %q", savedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponse_403JSON_CheckLarkResponse(t *testing.T) {
|
||||
body := []byte(`{"code":99991400,"msg":"invalid token"}`)
|
||||
resp := newApiRespWithStatus(403, body, map[string]string{"Content-Type": "application/json"})
|
||||
|
||||
var out, errOut bytes.Buffer
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut})
|
||||
err := HandleResponse(resp, ResponseOptions{Out: &out, ErrOut: &errOut, FileIO: &localfileio.LocalFileIO{}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 JSON with non-zero code")
|
||||
}
|
||||
|
||||
@@ -11,13 +11,26 @@ import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/larksuite/cli/extension/credential/env"
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
exttransport "github.com/larksuite/cli/extension/transport"
|
||||
internalauth "github.com/larksuite/cli/internal/auth"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/credential"
|
||||
"github.com/larksuite/cli/internal/envvars"
|
||||
"github.com/larksuite/cli/internal/vfs/localfileio"
|
||||
)
|
||||
|
||||
type countingFileIOProvider struct {
|
||||
resolveCalls int
|
||||
}
|
||||
|
||||
func (p *countingFileIOProvider) Name() string { return "counting" }
|
||||
|
||||
func (p *countingFileIOProvider) ResolveFileIO(context.Context) fileio.FileIO {
|
||||
p.resolveCalls++
|
||||
return &localfileio.LocalFileIO{}
|
||||
}
|
||||
|
||||
func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) {
|
||||
t.Setenv(envvars.CliAppID, "")
|
||||
t.Setenv(envvars.CliAppSecret, "")
|
||||
@@ -198,6 +211,28 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing.T) {
|
||||
prev := fileio.GetProvider()
|
||||
provider := &countingFileIOProvider{}
|
||||
fileio.Register(provider)
|
||||
t.Cleanup(func() { fileio.Register(prev) })
|
||||
|
||||
f := NewDefault(InvocationContext{})
|
||||
if f.FileIOProvider != provider {
|
||||
t.Fatalf("NewDefault() provider = %T, want %T", f.FileIOProvider, provider)
|
||||
}
|
||||
if provider.resolveCalls != 0 {
|
||||
t.Fatalf("ResolveFileIO() calls after NewDefault() = %d, want 0", provider.resolveCalls)
|
||||
}
|
||||
|
||||
if got := f.ResolveFileIO(context.Background()); got == nil {
|
||||
t.Fatal("ResolveFileIO() = nil, want non-nil")
|
||||
}
|
||||
if provider.resolveCalls != 1 {
|
||||
t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls)
|
||||
}
|
||||
}
|
||||
|
||||
type stubTransportProvider struct {
|
||||
interceptor exttransport.Interceptor
|
||||
}
|
||||
|
||||
146
internal/vfs/localfileio/atomicwrite_test.go
Normal file
146
internal/vfs/localfileio/atomicwrite_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package localfileio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAtomicWrite_WritesContentAndPermissionCorrectly(t *testing.T) {
|
||||
// GIVEN: a target path in a temp directory
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
data := []byte(`{"key":"value"}`)
|
||||
|
||||
// WHEN: AtomicWrite writes data with 0644 permission
|
||||
if err := AtomicWrite(path, data, 0644); err != nil {
|
||||
t.Fatalf("AtomicWrite failed: %v", err)
|
||||
}
|
||||
|
||||
// THEN: file content matches exactly
|
||||
got, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content = %q, want %q", got, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicWrite_SetsRestrictivePermission(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("permission test not reliable on Windows")
|
||||
}
|
||||
|
||||
// GIVEN: a target path
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "secret.json")
|
||||
|
||||
// WHEN: AtomicWrite writes with 0600 permission
|
||||
if err := AtomicWrite(path, []byte("secret"), 0600); err != nil {
|
||||
t.Fatalf("AtomicWrite failed: %v", err)
|
||||
}
|
||||
|
||||
// THEN: file permission is exactly 0600 (owner read-write only)
|
||||
info, _ := os.Stat(path)
|
||||
if perm := info.Mode().Perm(); perm != 0600 {
|
||||
t.Errorf("permission = %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicWrite_OverwritesExistingFile(t *testing.T) {
|
||||
// GIVEN: an existing file with old content
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
AtomicWrite(path, []byte("old"), 0644)
|
||||
|
||||
// WHEN: AtomicWrite overwrites with new content
|
||||
if err := AtomicWrite(path, []byte("new"), 0644); err != nil {
|
||||
t.Fatalf("second write failed: %v", err)
|
||||
}
|
||||
|
||||
// THEN: file contains new content
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "new" {
|
||||
t.Errorf("content = %q, want %q", got, "new")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicWrite_LeavesNoResidualTempFileOnError(t *testing.T) {
|
||||
// GIVEN: a target path in a non-existent nested directory
|
||||
path := filepath.Join(t.TempDir(), "nonexistent", "subdir", "file.txt")
|
||||
|
||||
// WHEN: AtomicWrite fails (parent directory doesn't exist)
|
||||
err := AtomicWrite(path, []byte("data"), 0644)
|
||||
|
||||
// THEN: the write fails
|
||||
if err == nil {
|
||||
t.Fatal("expected error writing to nonexistent dir")
|
||||
}
|
||||
|
||||
// THEN: no .tmp files are left behind
|
||||
parentDir := filepath.Dir(filepath.Dir(path))
|
||||
entries, _ := os.ReadDir(parentDir)
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".tmp" {
|
||||
t.Errorf("residual temp file found: %s", e.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicWrite_PreservesOriginalFileOnFailure(t *testing.T) {
|
||||
// GIVEN: an existing file with known content
|
||||
dir := t.TempDir()
|
||||
original := []byte("original content")
|
||||
path := filepath.Join(dir, "file.json")
|
||||
if err := AtomicWrite(path, original, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// WHEN: AtomicWrite targets a non-existent directory (guaranteed to fail even as root)
|
||||
badPath := filepath.Join(dir, "no", "such", "dir", "file.json")
|
||||
err := AtomicWrite(badPath, []byte("new"), 0644)
|
||||
|
||||
// THEN: write fails
|
||||
if err == nil {
|
||||
t.Fatal("expected error writing to non-existent dir")
|
||||
}
|
||||
|
||||
// THEN: the original file at the valid path is untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != string(original) {
|
||||
t.Errorf("original file corrupted: got %q, want %q", got, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicWrite_HandlesCorrectlyUnderConcurrentWrites(t *testing.T) {
|
||||
// GIVEN: a target file that will be written by 20 concurrent goroutines
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "concurrent.json")
|
||||
|
||||
// WHEN: 20 goroutines write simultaneously
|
||||
var wg sync.WaitGroup
|
||||
for i := range 20 {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
data := []byte(`{"n":` + string(rune('0'+n%10)) + `}`)
|
||||
AtomicWrite(path, data, 0644)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// THEN: file exists and is valid (not corrupted by interleaved writes)
|
||||
got, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
if len(got) == 0 {
|
||||
t.Error("file is empty after concurrent writes")
|
||||
}
|
||||
}
|
||||
@@ -34,7 +34,7 @@ type LocalFileIO struct{}
|
||||
func (l *LocalFileIO) Open(name string) (fileio.File, error) {
|
||||
safePath, err := SafeInputPath(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &fileio.PathValidationError{Err: err}
|
||||
}
|
||||
return vfs.Open(safePath)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (l *LocalFileIO) Open(name string) (fileio.File, error) {
|
||||
func (l *LocalFileIO) Stat(name string) (fileio.FileInfo, error) {
|
||||
safePath, err := SafeInputPath(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &fileio.PathValidationError{Err: err}
|
||||
}
|
||||
return vfs.Stat(safePath)
|
||||
}
|
||||
@@ -55,7 +55,11 @@ func (r *saveResult) Size() int64 { return r.size }
|
||||
|
||||
// ResolvePath returns the validated absolute path for the given output path.
|
||||
func (l *LocalFileIO) ResolvePath(path string) (string, error) {
|
||||
return SafeOutputPath(path)
|
||||
resolved, err := SafeOutputPath(path)
|
||||
if err != nil {
|
||||
return "", &fileio.PathValidationError{Err: err}
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// Save writes body to path atomically after validating the output path.
|
||||
@@ -64,14 +68,14 @@ func (l *LocalFileIO) ResolvePath(path string) (string, error) {
|
||||
func (l *LocalFileIO) Save(path string, _ fileio.SaveOptions, body io.Reader) (fileio.SaveResult, error) {
|
||||
safePath, err := SafeOutputPath(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &fileio.PathValidationError{Err: err}
|
||||
}
|
||||
if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
|
||||
return nil, err
|
||||
return nil, &fileio.MkdirError{Err: err}
|
||||
}
|
||||
n, err := AtomicWriteFromReader(safePath, body, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &fileio.WriteError{Err: err}
|
||||
}
|
||||
return &saveResult{size: n}, nil
|
||||
}
|
||||
|
||||
306
internal/vfs/localfileio/localfileio_test.go
Normal file
306
internal/vfs/localfileio/localfileio_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package localfileio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
)
|
||||
|
||||
// testChdir temporarily changes the working directory for a test.
|
||||
// Not compatible with t.Parallel().
|
||||
func testChdir(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
orig, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Chdir(dir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { os.Chdir(orig) })
|
||||
}
|
||||
|
||||
// ── Provider ──
|
||||
|
||||
func TestProvider_Name(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if got := p.Name(); got != "local" {
|
||||
t.Errorf("Provider.Name() = %q, want %q", got, "local")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ResolveFileIO(t *testing.T) {
|
||||
p := &Provider{}
|
||||
fio := p.ResolveFileIO(nil)
|
||||
if fio == nil {
|
||||
t.Fatal("Provider.ResolveFileIO returned nil")
|
||||
}
|
||||
if _, ok := fio.(*LocalFileIO); !ok {
|
||||
t.Errorf("expected *LocalFileIO, got %T", fio)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Open ──
|
||||
|
||||
func TestLocalFileIO_Open_ValidFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
content := []byte("hello world")
|
||||
os.WriteFile("test.txt", content, 0644)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
f, err := fio.Open("test.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Open failed: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
got, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll failed: %v", err)
|
||||
}
|
||||
if string(got) != string(content) {
|
||||
t.Errorf("content = %q, want %q", got, content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Open_RejectsTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Open("../../etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for path traversal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Open_RejectsAbsolutePath(t *testing.T) {
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Open("/etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for absolute path")
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "relative path") {
|
||||
t.Errorf("error should mention relative path, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Open_NonexistentFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Open("nonexistent.txt")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
// ── Stat ──
|
||||
|
||||
func TestLocalFileIO_Stat_ValidFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
os.WriteFile("stat.txt", []byte("12345"), 0644)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
info, err := fio.Stat("stat.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
if info.Size() != 5 {
|
||||
t.Errorf("Size() = %d, want 5", info.Size())
|
||||
}
|
||||
if info.IsDir() {
|
||||
t.Error("expected IsDir() = false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Stat_RejectsTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Stat("../../etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for path traversal")
|
||||
}
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
t.Error("traversal should not be os.IsNotExist, should be a validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Stat_NonexistentReturnsIsNotExist(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Stat("nope.txt")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Errorf("expected os.IsNotExist, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Save ──
|
||||
|
||||
func TestLocalFileIO_Save_WritesContent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
body := strings.NewReader("saved content")
|
||||
result, err := fio.Save("output.bin", fileio.SaveOptions{}, body)
|
||||
if err != nil {
|
||||
t.Fatalf("Save failed: %v", err)
|
||||
}
|
||||
if result.Size() != int64(len("saved content")) {
|
||||
t.Errorf("Size() = %d, want %d", result.Size(), len("saved content"))
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(filepath.Join(dir, "output.bin"))
|
||||
if string(got) != "saved content" {
|
||||
t.Errorf("file content = %q, want %q", got, "saved content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Save_CreatesParentDirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
body := strings.NewReader("nested")
|
||||
_, err := fio.Save(filepath.Join("a", "b", "c.txt"), fileio.SaveOptions{}, body)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with nested dir failed: %v", err)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(filepath.Join(dir, "a", "b", "c.txt"))
|
||||
if string(got) != "nested" {
|
||||
t.Errorf("file content = %q, want %q", got, "nested")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Save_RejectsTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Save("../../evil.txt", fileio.SaveOptions{}, strings.NewReader("bad"))
|
||||
if err == nil {
|
||||
t.Error("expected error for path traversal in Save")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_Save_RejectsAbsolutePath(t *testing.T) {
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.Save("/tmp/evil.txt", fileio.SaveOptions{}, strings.NewReader("bad"))
|
||||
if err == nil {
|
||||
t.Error("expected error for absolute path in Save")
|
||||
}
|
||||
}
|
||||
|
||||
// ── ResolvePath ──
|
||||
|
||||
func TestLocalFileIO_ResolvePath_ReturnsAbsolute(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
resolved, err := fio.ResolvePath("file.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolvePath failed: %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(resolved) {
|
||||
t.Errorf("expected absolute path, got %q", resolved)
|
||||
}
|
||||
if filepath.Base(resolved) != "file.txt" {
|
||||
t.Errorf("expected base name file.txt, got %q", filepath.Base(resolved))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_ResolvePath_RejectsTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.ResolvePath("../../etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for path traversal in ResolvePath")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalFileIO_ResolvePath_RejectsAbsolute(t *testing.T) {
|
||||
fio := &LocalFileIO{}
|
||||
_, err := fio.ResolvePath("/etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for absolute path in ResolvePath")
|
||||
}
|
||||
}
|
||||
|
||||
// ── Error message consistency ──
|
||||
|
||||
func TestLocalFileIO_ErrorMessages_ContainCorrectFlagName(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
|
||||
// Open/Stat use SafeInputPath → errors should mention "--file"
|
||||
_, err := fio.Open("/absolute/path")
|
||||
if err == nil || !strings.Contains(err.Error(), "--file") {
|
||||
t.Errorf("Open absolute path error should mention --file, got: %v", err)
|
||||
}
|
||||
|
||||
_, err = fio.Stat("/absolute/path")
|
||||
if err == nil || !strings.Contains(err.Error(), "--file") {
|
||||
t.Errorf("Stat absolute path error should mention --file, got: %v", err)
|
||||
}
|
||||
|
||||
// Save/ResolvePath use SafeOutputPath → errors should mention "--output"
|
||||
_, err = fio.Save("/absolute/path", fileio.SaveOptions{}, strings.NewReader(""))
|
||||
if err == nil || !strings.Contains(err.Error(), "--output") {
|
||||
t.Errorf("Save absolute path error should mention --output, got: %v", err)
|
||||
}
|
||||
|
||||
_, err = fio.ResolvePath("/absolute/path")
|
||||
if err == nil || !strings.Contains(err.Error(), "--output") {
|
||||
t.Errorf("ResolvePath absolute path error should mention --output, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Control character / Unicode rejection ──
|
||||
|
||||
func TestLocalFileIO_RejectsControlCharsInPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testChdir(t, dir)
|
||||
|
||||
fio := &LocalFileIO{}
|
||||
paths := []string{
|
||||
"file\x00name.txt", // null byte
|
||||
"file\x1fname.txt", // control char
|
||||
"file\u200Bname.txt", // zero-width space
|
||||
"file\u202Ename.txt", // bidi override
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
if _, err := fio.Open(p); err == nil {
|
||||
t.Errorf("Open(%q) should reject control/dangerous chars", p)
|
||||
}
|
||||
if _, err := fio.Save(p, fileio.SaveOptions{}, strings.NewReader("")); err == nil {
|
||||
t.Errorf("Save(%q) should reject control/dangerous chars", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
245
internal/vfs/localfileio/path_test.go
Normal file
245
internal/vfs/localfileio/path_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package localfileio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSafeOutputPath_RejectsPathTraversalAndDangerousInput(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
// ── GIVEN: normal relative paths → THEN: allowed ──
|
||||
{"normal file", "report.xlsx", false},
|
||||
{"subdir file", "output/report.xlsx", false},
|
||||
{"current dir explicit", "./file.txt", false},
|
||||
{"nested subdir", "a/b/c/file.txt", false},
|
||||
{"dot in name", "my.report.v2.xlsx", false},
|
||||
{"space in name", "my file.txt", false},
|
||||
{"unicode normal", "报告.xlsx", false},
|
||||
{"dot-dot resolves to cwd", "subdir/..", false},
|
||||
|
||||
// ── GIVEN: path traversal via .. → THEN: rejected ──
|
||||
{"dot-dot escape", "../../.ssh/authorized_keys", true},
|
||||
{"dot-dot mid path", "subdir/../../etc/passwd", true},
|
||||
{"triple dot-dot", "../../../etc/shadow", true},
|
||||
|
||||
// ── GIVEN: absolute paths → THEN: rejected ──
|
||||
{"absolute path unix", "/etc/passwd", true},
|
||||
{"absolute path root", "/tmp/evil", true},
|
||||
|
||||
// ── GIVEN: control characters in path → THEN: rejected ──
|
||||
{"null byte", "file\x00.txt", true},
|
||||
{"carriage return", "file\r.txt", true},
|
||||
{"bell char", "file\x07.txt", true},
|
||||
|
||||
// ── GIVEN: dangerous Unicode in path → THEN: rejected ──
|
||||
{"bidi RLO", "file\u202Ename.txt", true},
|
||||
{"zero width space", "file\u200Bname.txt", true},
|
||||
{"BOM char", "file\uFEFFname.txt", true},
|
||||
{"line separator", "file\u2028name.txt", true},
|
||||
{"bidi LRI", "file\u2066name.txt", true},
|
||||
|
||||
// ── GIVEN: looks dangerous but is actually safe → THEN: allowed ──
|
||||
{"literal percent 2e", "%2e%2e/etc/passwd", false},
|
||||
{"tilde path", "~/file.txt", false},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// WHEN: SafeOutputPath validates the path
|
||||
_, err := SafeOutputPath(tt.input)
|
||||
|
||||
// THEN: error matches expectation
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SafeOutputPath(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeOutputPath_ReturnsCanonicalAbsolutePath(t *testing.T) {
|
||||
// GIVEN: a clean temp directory as CWD
|
||||
dir := t.TempDir()
|
||||
dir, _ = filepath.EvalSymlinks(dir)
|
||||
origDir, _ := os.Getwd()
|
||||
defer os.Chdir(origDir)
|
||||
os.Chdir(dir)
|
||||
|
||||
// WHEN: SafeOutputPath validates a relative path
|
||||
got, err := SafeOutputPath("output/file.txt")
|
||||
|
||||
// THEN: returns the canonical absolute path for subsequent I/O
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := filepath.Join(dir, "output", "file.txt")
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeOutputPath_RejectsSymlinkEscapingCWD(t *testing.T) {
|
||||
// GIVEN: a symlink in CWD pointing to /etc (outside CWD)
|
||||
dir := t.TempDir()
|
||||
dir, _ = filepath.EvalSymlinks(dir)
|
||||
origDir, _ := os.Getwd()
|
||||
defer os.Chdir(origDir)
|
||||
os.Chdir(dir)
|
||||
os.Symlink("/etc", filepath.Join(dir, "link-to-etc"))
|
||||
|
||||
// WHEN: SafeOutputPath validates a path through the symlink
|
||||
_, err := SafeOutputPath("link-to-etc/passwd")
|
||||
|
||||
// THEN: rejected because the resolved path is outside CWD
|
||||
if err == nil {
|
||||
t.Error("expected error for symlink escaping CWD, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeOutputPath_AllowsSymlinkWithinCWD(t *testing.T) {
|
||||
// GIVEN: a symlink in CWD pointing to a subdirectory within CWD
|
||||
dir := t.TempDir()
|
||||
dir, _ = filepath.EvalSymlinks(dir)
|
||||
origDir, _ := os.Getwd()
|
||||
defer os.Chdir(origDir)
|
||||
os.Chdir(dir)
|
||||
os.MkdirAll(filepath.Join(dir, "real"), 0755)
|
||||
os.Symlink(filepath.Join(dir, "real"), filepath.Join(dir, "link"))
|
||||
|
||||
// WHEN: SafeOutputPath validates a path through the internal symlink
|
||||
got, err := SafeOutputPath("link/file.txt")
|
||||
|
||||
// THEN: allowed, resolved to the real path within CWD
|
||||
if err != nil {
|
||||
t.Fatalf("symlink within CWD should be allowed: %v", err)
|
||||
}
|
||||
want := filepath.Join(dir, "real", "file.txt")
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeOutputPath_ResolvesAncestorSymlinkWhenParentMissing(t *testing.T) {
|
||||
// GIVEN: CWD contains a symlink "escape" → /etc, and the target path
|
||||
// goes through "escape/sub/file.txt" where "sub" does not exist.
|
||||
// The old code failed to resolve the symlink because the immediate
|
||||
// parent ("escape/sub") didn't exist, leaving resolved un-anchored.
|
||||
dir := t.TempDir()
|
||||
dir, _ = filepath.EvalSymlinks(dir)
|
||||
origDir, _ := os.Getwd()
|
||||
defer os.Chdir(origDir)
|
||||
os.Chdir(dir)
|
||||
os.Symlink("/etc", filepath.Join(dir, "escape"))
|
||||
|
||||
// WHEN: SafeOutputPath validates a path through the symlink with missing intermediate dirs
|
||||
_, err := SafeOutputPath("escape/nonexistent/file.txt")
|
||||
|
||||
// THEN: rejected — the resolved path is under /etc, outside CWD
|
||||
if err == nil {
|
||||
t.Error("expected error for symlink escaping CWD via non-existent parent, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeOutputPath_DeepNonExistentPathStaysInCWD(t *testing.T) {
|
||||
// GIVEN: a deeply nested non-existent path with no symlinks
|
||||
dir := t.TempDir()
|
||||
dir, _ = filepath.EvalSymlinks(dir)
|
||||
origDir, _ := os.Getwd()
|
||||
defer os.Chdir(origDir)
|
||||
os.Chdir(dir)
|
||||
|
||||
// WHEN: SafeOutputPath validates "a/b/c/d/file.txt" (none of a/b/c/d exist)
|
||||
got, err := SafeOutputPath("a/b/c/d/file.txt")
|
||||
|
||||
// THEN: allowed, resolved to canonical path under CWD
|
||||
if err != nil {
|
||||
t.Fatalf("deep non-existent path within CWD should be allowed: %v", err)
|
||||
}
|
||||
want := filepath.Join(dir, "a", "b", "c", "d", "file.txt")
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeUploadPath_AllowsTempFileAbsolutePath(t *testing.T) {
|
||||
// GIVEN: a real temp file (absolute path under os.TempDir())
|
||||
f, err := os.CreateTemp("", "upload-test-*.bin")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp: %v", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
f.Close()
|
||||
t.Cleanup(func() { os.Remove(tmpPath) })
|
||||
|
||||
// WHEN: SafeUploadPath validates the absolute temp path
|
||||
_, err = SafeInputPath(tmpPath)
|
||||
|
||||
// THEN: absolute paths are rejected even in temp dir
|
||||
if err == nil {
|
||||
t.Fatal("expected error for absolute temp path, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeUploadPath_RejectsNonTempAbsolutePath(t *testing.T) {
|
||||
// GIVEN: an absolute path outside the temp directory
|
||||
// WHEN / THEN: SafeUploadPath rejects it
|
||||
_, err := SafeInputPath("/etc/passwd")
|
||||
if err == nil {
|
||||
t.Error("expected error for absolute non-temp path, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeUploadPath_AcceptsRelativePath(t *testing.T) {
|
||||
// GIVEN: a clean temp CWD with a real file
|
||||
dir := t.TempDir()
|
||||
dir, _ = filepath.EvalSymlinks(dir)
|
||||
orig, _ := os.Getwd()
|
||||
defer os.Chdir(orig)
|
||||
os.Chdir(dir)
|
||||
|
||||
os.WriteFile(filepath.Join(dir, "upload.bin"), []byte("data"), 0600)
|
||||
|
||||
// WHEN: SafeUploadPath validates a relative path to an existing file
|
||||
got, err := SafeInputPath("upload.bin")
|
||||
|
||||
// THEN: accepted and returned as absolute canonical path
|
||||
if err != nil {
|
||||
t.Fatalf("SafeUploadPath(relative) error = %v", err)
|
||||
}
|
||||
want := filepath.Join(dir, "upload.bin")
|
||||
if got != want {
|
||||
t.Errorf("SafeUploadPath(relative) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SafeInputPath_ErrorMessageContainsCorrectFlagName(t *testing.T) {
|
||||
// GIVEN: an absolute path
|
||||
|
||||
// WHEN: SafeInputPath rejects it
|
||||
_, err := SafeInputPath("/etc/passwd")
|
||||
|
||||
// THEN: error message mentions --file (not --output)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for absolute path")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--file") {
|
||||
t.Errorf("error should mention --file, got: %s", err.Error())
|
||||
}
|
||||
|
||||
// WHEN: SafeOutputPath rejects it
|
||||
_, err = SafeOutputPath("/etc/passwd")
|
||||
|
||||
// THEN: error message mentions --output (not --file)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for absolute path")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--output") {
|
||||
t.Errorf("error should mention --output, got: %s", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -25,8 +25,6 @@ import (
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/credential"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/validate"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -335,6 +333,39 @@ func (ctx *RuntimeContext) ResolveSavePath(path string) (string, error) {
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// WrapSaveError matches a FileIO.Save error against known categories and wraps
|
||||
// it with the caller-provided message prefix, preserving backward-compatible
|
||||
// error text per shortcut.
|
||||
func WrapSaveError(err error, pathMsg, mkdirMsg, writeMsg string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var me *fileio.MkdirError
|
||||
var we *fileio.WriteError
|
||||
switch {
|
||||
case errors.Is(err, fileio.ErrPathValidation):
|
||||
return fmt.Errorf("%s: %w", pathMsg, err)
|
||||
case errors.As(err, &me):
|
||||
return fmt.Errorf("%s: %w", mkdirMsg, err)
|
||||
case errors.As(err, &we):
|
||||
return fmt.Errorf("%s: %w", writeMsg, err)
|
||||
default:
|
||||
return fmt.Errorf("%s: %w", writeMsg, err)
|
||||
}
|
||||
}
|
||||
|
||||
// WrapOpenError matches a FileIO.Open/Stat error and wraps it with the
|
||||
// caller-provided message prefix.
|
||||
func WrapOpenError(err error, pathMsg, readMsg string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, fileio.ErrPathValidation) {
|
||||
return fmt.Errorf("%s: %w", pathMsg, err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w", readMsg, err)
|
||||
}
|
||||
|
||||
// ValidatePath checks that path is a valid relative input path within the
|
||||
// working directory by delegating to FileIO.Stat. Returns nil if the path is
|
||||
// valid or does not exist yet; returns an error only for illegal paths
|
||||
@@ -634,11 +665,15 @@ func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error {
|
||||
if path == "" {
|
||||
return FlagErrorf("--%s: file path cannot be empty after @", fl.Name)
|
||||
}
|
||||
safePath, err := validate.SafeInputPath(path)
|
||||
f, err := rctx.FileIO().Open(path)
|
||||
if err != nil {
|
||||
return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err)
|
||||
if errors.Is(err, fileio.ErrPathValidation) {
|
||||
return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err)
|
||||
}
|
||||
return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err)
|
||||
}
|
||||
data, err := vfs.ReadFile(safePath)
|
||||
data, err := io.ReadAll(f)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
_ "github.com/larksuite/cli/internal/vfs/localfileio"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
@@ -102,6 +103,48 @@ func TestRuntimeContext_Out_WithJq_InvalidExpr_WritesStderr(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type testResolvedFileIO struct{}
|
||||
|
||||
func (testResolvedFileIO) Open(string) (fileio.File, error) { return nil, nil }
|
||||
func (testResolvedFileIO) Stat(string) (fileio.FileInfo, error) { return nil, nil }
|
||||
func (testResolvedFileIO) ResolvePath(path string) (string, error) { return path, nil }
|
||||
func (testResolvedFileIO) Save(string, fileio.SaveOptions, io.Reader) (fileio.SaveResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type capturingFileIOProvider struct {
|
||||
gotCtx context.Context
|
||||
fileIO fileio.FileIO
|
||||
}
|
||||
|
||||
func (p *capturingFileIOProvider) Name() string { return "capture" }
|
||||
|
||||
func (p *capturingFileIOProvider) ResolveFileIO(ctx context.Context) fileio.FileIO {
|
||||
p.gotCtx = ctx
|
||||
return p.fileIO
|
||||
}
|
||||
|
||||
func TestRuntimeContext_FileIO_UsesExecutionContext(t *testing.T) {
|
||||
execCtx := context.WithValue(context.Background(), "key", "value")
|
||||
resolved := testResolvedFileIO{}
|
||||
provider := &capturingFileIOProvider{fileIO: resolved}
|
||||
|
||||
rctx := &RuntimeContext{
|
||||
ctx: execCtx,
|
||||
Factory: &cmdutil.Factory{
|
||||
FileIOProvider: provider,
|
||||
},
|
||||
}
|
||||
|
||||
got := rctx.FileIO()
|
||||
if got != resolved {
|
||||
t.Fatalf("FileIO() returned %T, want %T", got, resolved)
|
||||
}
|
||||
if provider.gotCtx != execCtx {
|
||||
t.Fatal("ResolveFileIO() did not receive the runtime execution context")
|
||||
}
|
||||
}
|
||||
|
||||
func newTestShortcutCmd(s *Shortcut) *cobra.Command {
|
||||
cmd := &cobra.Command{Use: "test-shortcut"}
|
||||
cmd.SetContext(context.Background())
|
||||
@@ -119,7 +162,8 @@ func newTestFactory() *cmdutil.Factory {
|
||||
LarkClient: func() (*lark.Client, error) {
|
||||
return lark.NewClient("test", "test"), nil
|
||||
},
|
||||
IOStreams: &cmdutil.IOStreams{Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}},
|
||||
IOStreams: &cmdutil.IOStreams{Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}},
|
||||
FileIOProvider: fileio.GetProvider(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
)
|
||||
|
||||
func TestSanitizeURLForDisplay(t *testing.T) {
|
||||
@@ -404,39 +406,36 @@ func TestBuildSearchChatBodyAdditionalBranches(t *testing.T) {
|
||||
|
||||
func TestParseMediaDurationSuccess(t *testing.T) {
|
||||
t.Run("mp4", func(t *testing.T) {
|
||||
f, err := os.CreateTemp("", "im-duration-*.mp4")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp() error = %v", err)
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
fname := "im-duration-test.mp4"
|
||||
if err := os.WriteFile(fname, wrapInMoov(buildMvhdBox(0, 1000, 5000)), 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.Write(wrapInMoov(buildMvhdBox(0, 1000, 5000))); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
if got := parseMediaDuration(f.Name(), "mp4"); got != "5000" {
|
||||
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected")
|
||||
}))
|
||||
if got := parseMediaDuration(rt, fname, "mp4"); got != "5000" {
|
||||
t.Fatalf("parseMediaDuration(mp4) = %q, want %q", got, "5000")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("opus", func(t *testing.T) {
|
||||
f, err := os.CreateTemp("", "im-duration-*.ogg")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp() error = %v", err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
page := make([]byte, 27)
|
||||
copy(page[0:4], "OggS")
|
||||
page[5] = 4
|
||||
page[6] = 0x00
|
||||
page[7] = 0x53
|
||||
page[8] = 0x07
|
||||
if _, err := f.Write(page); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
|
||||
fname := "im-duration-test.ogg"
|
||||
if err := os.WriteFile(fname, page, 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
if got := parseMediaDuration(f.Name(), "opus"); got != "10000" {
|
||||
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected")
|
||||
}))
|
||||
if got := parseMediaDuration(rt, fname, "opus"); got != "10000" {
|
||||
t.Fatalf("parseMediaDuration(opus) = %q, want %q", got, "10000")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -13,16 +13,15 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/validate"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -327,21 +326,16 @@ func resolveURLMedia(ctx context.Context, runtime *common.RuntimeContext, s medi
|
||||
func resolveLocalMedia(ctx context.Context, runtime *common.RuntimeContext, s mediaSpec) (string, error) {
|
||||
fmt.Fprintf(runtime.IO().ErrOut, "uploading %s: %s\n", s.mediaType, filepath.Base(s.value))
|
||||
|
||||
safePath, err := validate.SafeInputPath(s.value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if s.kind == mediaKindImage {
|
||||
return uploadImageToIM(ctx, runtime, safePath, "message")
|
||||
return uploadImageToIM(ctx, runtime, s.value, "message")
|
||||
}
|
||||
|
||||
ft := detectIMFileType(safePath)
|
||||
ft := detectIMFileType(s.value)
|
||||
dur := ""
|
||||
if s.withDuration {
|
||||
dur = parseMediaDuration(safePath, ft)
|
||||
dur = parseMediaDuration(runtime, s.value, ft)
|
||||
}
|
||||
return uploadFileToIM(ctx, runtime, safePath, ft, dur)
|
||||
return uploadFileToIM(ctx, runtime, s.value, ft, dur)
|
||||
}
|
||||
|
||||
// resolveVideoContent handles the video case which needs both a file_key and
|
||||
@@ -556,18 +550,16 @@ func findMP4Box(data []byte, start, end int, boxType string) (int, int) {
|
||||
// for audio/video uploads. Only reads the minimal portion of the file needed
|
||||
// for parsing (tail for OGG, box headers + moov for MP4).
|
||||
// Returns "" if parsing fails or the file type is not audio/video.
|
||||
func parseMediaDuration(filePath, fileType string) string {
|
||||
func parseMediaDuration(runtime *common.RuntimeContext, filePath, fileType string) string {
|
||||
if fileType != "opus" && fileType != "mp4" {
|
||||
return ""
|
||||
}
|
||||
f, err := vfs.Open(filePath)
|
||||
if err != nil {
|
||||
info, err := runtime.FileIO().Stat(filePath)
|
||||
if err != nil || info.Size() == 0 {
|
||||
return ""
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil || info.Size() == 0 {
|
||||
f, err := runtime.FileIO().Open(filePath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -698,7 +690,7 @@ func readMp4DurationBytes(data []byte) int64 {
|
||||
}
|
||||
|
||||
// readOggDuration reads the tail of an OGG file (up to 64 KB) and parses duration.
|
||||
func readOggDuration(f *os.File, fileSize int64) int64 {
|
||||
func readOggDuration(f fileio.File, fileSize int64) int64 {
|
||||
const maxTail = 65536
|
||||
readSize := fileSize
|
||||
if readSize > maxTail {
|
||||
@@ -713,7 +705,7 @@ func readOggDuration(f *os.File, fileSize int64) int64 {
|
||||
|
||||
// readMp4Duration walks top-level MP4 boxes via file seeks to find moov,
|
||||
// then reads only the moov content to locate mvhd and extract the duration.
|
||||
func readMp4Duration(f *os.File, fileSize int64) int64 {
|
||||
func readMp4Duration(f fileio.File, fileSize int64) int64 {
|
||||
hdr := make([]byte, 16)
|
||||
var offset int64
|
||||
for offset+8 <= fileSize {
|
||||
@@ -1005,14 +997,11 @@ const maxImageUploadSize = 5 * 1024 * 1024 // 5MB — Lark API limit for images
|
||||
const maxFileUploadSize = 100 * 1024 * 1024 // 100MB — Lark API limit for files
|
||||
|
||||
func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, imageType string) (string, error) {
|
||||
// filePath is already validated by the caller (resolveLocalMedia).
|
||||
safePath := filePath
|
||||
|
||||
if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxImageUploadSize {
|
||||
if info, err := runtime.FileIO().Stat(filePath); err == nil && info.Size() > maxImageUploadSize {
|
||||
return "", fmt.Errorf("image size %s exceeds limit (max 5MB)", common.FormatSize(info.Size()))
|
||||
}
|
||||
|
||||
f, err := vfs.Open(safePath)
|
||||
f, err := runtime.FileIO().Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1045,14 +1034,11 @@ func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePa
|
||||
}
|
||||
|
||||
func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, fileType, duration string) (string, error) {
|
||||
// filePath is already validated by the caller (resolveLocalMedia).
|
||||
safePath := filePath
|
||||
|
||||
if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxFileUploadSize {
|
||||
if info, err := runtime.FileIO().Stat(filePath); err == nil && info.Size() > maxFileUploadSize {
|
||||
return "", fmt.Errorf("file size %s exceeds limit (max 100MB)", common.FormatSize(info.Size()))
|
||||
}
|
||||
|
||||
f, err := vfs.Open(safePath)
|
||||
f, err := runtime.FileIO().Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1060,7 +1046,7 @@ func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePat
|
||||
|
||||
fd := larkcore.NewFormdata()
|
||||
fd.AddField("file_type", fileType)
|
||||
fd.AddField("file_name", filepath.Base(safePath))
|
||||
fd.AddField("file_name", filepath.Base(filePath))
|
||||
if duration != "" {
|
||||
fd.AddField("duration", duration)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/cmdutil"
|
||||
"github.com/larksuite/cli/internal/core"
|
||||
"github.com/larksuite/cli/internal/credential"
|
||||
@@ -52,9 +53,10 @@ func shortcutRawResponse(status int, body []byte, headers http.Header) *http.Res
|
||||
headers = make(http.Header)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: headers,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
StatusCode: status,
|
||||
Header: headers,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,10 +90,11 @@ func newBotShortcutRuntime(t *testing.T, rt http.RoundTripper) *common.RuntimeCo
|
||||
runtime := &common.RuntimeContext{
|
||||
Config: cfg,
|
||||
Factory: &cmdutil.Factory{
|
||||
Config: func() (*core.CliConfig, error) { return cfg, nil },
|
||||
HttpClient: func() (*http.Client, error) { return httpClient, nil },
|
||||
LarkClient: func() (*lark.Client, error) { return sdk, nil },
|
||||
Credential: testCred,
|
||||
Config: func() (*core.CliConfig, error) { return cfg, nil },
|
||||
HttpClient: func() (*http.Client, error) { return httpClient, nil },
|
||||
LarkClient: func() (*lark.Client, error) { return sdk, nil },
|
||||
Credential: testCred,
|
||||
FileIOProvider: fileio.GetProvider(),
|
||||
IOStreams: &cmdutil.IOStreams{
|
||||
Out: &bytes.Buffer{},
|
||||
ErrOut: &bytes.Buffer{},
|
||||
@@ -241,7 +244,9 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
|
||||
target := filepath.Join(t.TempDir(), "nested", "resource.bin")
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
target := filepath.Join("nested", "resource.bin")
|
||||
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "file_123", "file", target)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
@@ -280,7 +285,9 @@ func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", filepath.Join(t.TempDir(), "out.bin"))
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_403", "file_403", "file", "out.bin")
|
||||
if err == nil || !strings.Contains(err.Error(), "HTTP 403: denied") {
|
||||
t.Fatalf("downloadIMResourceToPath() error = %v", err)
|
||||
}
|
||||
@@ -305,28 +312,14 @@ func TestUploadImageToIMSuccess(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Getwd() error = %v", err)
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
if err := os.Chdir(tmpDir); err != nil {
|
||||
t.Fatalf("Chdir() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chdir(wd)
|
||||
})
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
path := "demo.png"
|
||||
if err := os.WriteFile(path, []byte("png"), 0600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Abs() error = %v", err)
|
||||
}
|
||||
got, err := uploadImageToIM(context.Background(), runtime, absPath, "message")
|
||||
got, err := uploadImageToIM(context.Background(), runtime, path, "message")
|
||||
if err != nil {
|
||||
t.Fatalf("uploadImageToIM() error = %v", err)
|
||||
}
|
||||
@@ -357,28 +350,14 @@ func TestUploadFileToIMSuccess(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Getwd() error = %v", err)
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
if err := os.Chdir(tmpDir); err != nil {
|
||||
t.Fatalf("Chdir() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chdir(wd)
|
||||
})
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
path := "demo.txt"
|
||||
if err := os.WriteFile(path, []byte("demo"), 0600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Abs() error = %v", err)
|
||||
}
|
||||
got, err := uploadFileToIM(context.Background(), runtime, absPath, "stream", "1200")
|
||||
got, err := uploadFileToIM(context.Background(), runtime, path, "stream", "1200")
|
||||
if err != nil {
|
||||
t.Fatalf("uploadFileToIM() error = %v", err)
|
||||
}
|
||||
@@ -394,7 +373,8 @@ func TestUploadFileToIMSuccess(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUploadImageToIMSizeLimit(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "too-large.png")
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
path := "too-large.png"
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
@@ -404,14 +384,18 @@ func TestUploadImageToIMSizeLimit(t *testing.T) {
|
||||
}
|
||||
f.Close()
|
||||
|
||||
_, err = uploadImageToIM(context.Background(), nil, path, "message")
|
||||
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected")
|
||||
}))
|
||||
_, err = uploadImageToIM(context.Background(), rt, path, "message")
|
||||
if err == nil || !strings.Contains(err.Error(), "exceeds limit") {
|
||||
t.Fatalf("uploadImageToIM() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadFileToIMSizeLimit(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "too-large.bin")
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
path := "too-large.bin"
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
@@ -421,7 +405,10 @@ func TestUploadFileToIMSizeLimit(t *testing.T) {
|
||||
}
|
||||
f.Close()
|
||||
|
||||
_, err = uploadFileToIM(context.Background(), nil, path, "stream", "")
|
||||
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected")
|
||||
}))
|
||||
_, err = uploadFileToIM(context.Background(), rt, path, "stream", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "exceeds limit") {
|
||||
t.Fatalf("uploadFileToIM() error = %v", err)
|
||||
}
|
||||
@@ -430,6 +417,7 @@ func TestUploadFileToIMSizeLimit(t *testing.T) {
|
||||
func TestResolveMediaContentWrapsUploadError(t *testing.T) {
|
||||
runtime := &common.RuntimeContext{
|
||||
Factory: &cmdutil.Factory{
|
||||
FileIOProvider: fileio.GetProvider(),
|
||||
IOStreams: &cmdutil.IOStreams{
|
||||
Out: &bytes.Buffer{},
|
||||
ErrOut: &bytes.Buffer{},
|
||||
@@ -437,7 +425,9 @@ func TestResolveMediaContentWrapsUploadError(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
missing := filepath.Join(t.TempDir(), "missing.png")
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
missing := "missing.png"
|
||||
_, _, err := resolveMediaContent(context.Background(), runtime, "", missing, "", "", "", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "image upload failed") {
|
||||
t.Fatalf("resolveMediaContent() error = %v", err)
|
||||
@@ -457,15 +447,7 @@ func TestResolveLocalMediaImage(t *testing.T) {
|
||||
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
|
||||
}))
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Getwd() error = %v", err)
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
if err := os.Chdir(tmpDir); err != nil {
|
||||
t.Fatalf("Chdir() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Chdir(wd) })
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
if err := os.WriteFile("test.png", []byte("png-data"), 0600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
@@ -496,15 +478,7 @@ func TestResolveLocalMediaFile(t *testing.T) {
|
||||
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
|
||||
}))
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Getwd() error = %v", err)
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
if err := os.Chdir(tmpDir); err != nil {
|
||||
t.Fatalf("Chdir() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Chdir(wd) })
|
||||
cmdutil.TestChdir(t, t.TempDir())
|
||||
|
||||
if err := os.WriteFile("test.txt", []byte("file-data"), 0600); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -263,10 +264,13 @@ func TestParseMp4Duration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseMediaDuration(t *testing.T) {
|
||||
if got := parseMediaDuration("test.pdf", "pdf"); got != "" {
|
||||
rt := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected")
|
||||
}))
|
||||
if got := parseMediaDuration(rt, "test.pdf", "pdf"); got != "" {
|
||||
t.Fatalf("parseMediaDuration(pdf) = %q, want empty", got)
|
||||
}
|
||||
if got := parseMediaDuration("nonexistent.opus", "opus"); got != "" {
|
||||
if got := parseMediaDuration(rt, "nonexistent.opus", "opus"); got != "" {
|
||||
t.Fatalf("parseMediaDuration(missing) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,29 +91,13 @@ var ImMessagesReply = common.Shortcut{
|
||||
videoCoverKey := runtime.Str("video-cover")
|
||||
audioKey := runtime.Str("audio")
|
||||
|
||||
if !isMediaKey(imageKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--image", imageKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(fileKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--file", fileKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video", videoKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoCoverKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(audioKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--audio", audioKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
fio := runtime.FileIO()
|
||||
for _, mf := range []struct{ flag, val string }{
|
||||
{"--image", imageKey}, {"--file", fileKey}, {"--video", videoKey},
|
||||
{"--video-cover", videoCoverKey}, {"--audio", audioKey},
|
||||
} {
|
||||
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,29 +133,13 @@ var ImMessagesReply = common.Shortcut{
|
||||
audioVal := runtime.Str("audio")
|
||||
replyInThread := runtime.Bool("reply-in-thread")
|
||||
idempotencyKey := runtime.Str("idempotency-key")
|
||||
if !isMediaKey(imageVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--image", imageVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(fileVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--file", fileVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video", videoVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoCoverVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(audioVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--audio", audioVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
fio := runtime.FileIO()
|
||||
for _, mf := range []struct{ flag, val string }{
|
||||
{"--image", imageVal}, {"--file", fileVal}, {"--video", videoVal},
|
||||
{"--video-cover", videoCoverVal}, {"--audio", audioVal},
|
||||
} {
|
||||
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,15 +6,15 @@ package im
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/client"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/validate"
|
||||
"github.com/larksuite/cli/internal/vfs"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
)
|
||||
@@ -54,7 +54,7 @@ var ImMessagesResourcesDownload = common.Shortcut{
|
||||
if err != nil {
|
||||
return output.ErrValidation("%s", err)
|
||||
}
|
||||
if _, err := validate.SafeOutputPath(relPath); err != nil {
|
||||
if _, err := runtime.ResolveSavePath(relPath); err != nil {
|
||||
return output.ErrValidation("unsafe output path: %s", err)
|
||||
}
|
||||
return nil
|
||||
@@ -67,12 +67,8 @@ var ImMessagesResourcesDownload = common.Shortcut{
|
||||
if err != nil {
|
||||
return output.ErrValidation("invalid output path: %s", err)
|
||||
}
|
||||
safePath, err := validate.SafeOutputPath(relPath)
|
||||
if err != nil {
|
||||
return output.ErrValidation("unsafe output path: %s", err)
|
||||
}
|
||||
|
||||
finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, safePath)
|
||||
finalPath, sizeBytes, err := downloadIMResourceToPath(ctx, runtime, messageId, fileKey, fileType, relPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -109,33 +105,33 @@ func normalizeDownloadOutputPath(fileKey, outputPath string) (string, error) {
|
||||
const defaultIMResourceDownloadTimeout = 120 * time.Second
|
||||
|
||||
var imMimeToExt = map[string]string{
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/svg+xml": ".svg",
|
||||
"application/pdf": ".pdf",
|
||||
"video/mp4": ".mp4",
|
||||
"video/3gpp": ".3gp",
|
||||
"video/x-msvideo": ".avi",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/wav": ".wav",
|
||||
"text/plain": ".txt",
|
||||
"text/html": ".html",
|
||||
"text/css": ".css",
|
||||
"text/csv": ".csv",
|
||||
"application/zip": ".zip",
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/svg+xml": ".svg",
|
||||
"application/pdf": ".pdf",
|
||||
"video/mp4": ".mp4",
|
||||
"video/3gpp": ".3gp",
|
||||
"video/x-msvideo": ".avi",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/wav": ".wav",
|
||||
"text/plain": ".txt",
|
||||
"text/html": ".html",
|
||||
"text/css": ".css",
|
||||
"text/csv": ".csv",
|
||||
"application/zip": ".zip",
|
||||
"application/x-zip-compressed": ".zip",
|
||||
"application/x-rar-compressed": ".rar",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"application/octet-stream": ".bin",
|
||||
"application/msword": ".doc",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"application/octet-stream": ".bin",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
}
|
||||
|
||||
@@ -156,8 +152,12 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex
|
||||
}
|
||||
defer downloadResp.Body.Close()
|
||||
|
||||
if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil {
|
||||
return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err)
|
||||
if downloadResp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(io.LimitReader(downloadResp.Body, 4096))
|
||||
if len(body) > 0 {
|
||||
return "", 0, output.ErrNetwork("download failed: HTTP %d: %s", downloadResp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
return "", 0, output.ErrNetwork("download failed: HTTP %d", downloadResp.StatusCode)
|
||||
}
|
||||
|
||||
// Auto-detect extension from Content-Type if missing
|
||||
@@ -171,9 +171,19 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex
|
||||
}
|
||||
}
|
||||
|
||||
sizeBytes, err := validate.AtomicWriteFromReader(finalPath, downloadResp.Body, 0600)
|
||||
result, err := runtime.FileIO().Save(finalPath, fileio.SaveOptions{
|
||||
ContentType: downloadResp.Header.Get("Content-Type"),
|
||||
ContentLength: downloadResp.ContentLength,
|
||||
}, downloadResp.Body)
|
||||
if err != nil {
|
||||
return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create file: %s", err)
|
||||
return "", 0, output.Errorf(output.ExitInternal, "api_error", "%s",
|
||||
common.WrapSaveError(err, "unsafe output path", "cannot create parent directory", "cannot create file"))
|
||||
}
|
||||
return finalPath, sizeBytes, nil
|
||||
savedPath, resolveErr := runtime.ResolveSavePath(finalPath)
|
||||
if resolveErr != nil {
|
||||
// Save succeeded — file is on disk. Fall back to the relative path
|
||||
// rather than returning an error for a successfully written file.
|
||||
savedPath = finalPath
|
||||
}
|
||||
return savedPath, result.Size(), nil
|
||||
}
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/larksuite/cli/extension/fileio"
|
||||
"github.com/larksuite/cli/internal/output"
|
||||
"github.com/larksuite/cli/internal/validate"
|
||||
"github.com/larksuite/cli/shortcuts/common"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
)
|
||||
@@ -98,29 +99,13 @@ var ImMessagesSend = common.Shortcut{
|
||||
videoCoverKey := runtime.Str("video-cover")
|
||||
audioKey := runtime.Str("audio")
|
||||
|
||||
if !isMediaKey(imageKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--image", imageKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(fileKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--file", fileKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video", videoKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoCoverKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(audioKey) {
|
||||
if _, err := validate.SafeLocalFlagPath("--audio", audioKey); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
fio := runtime.FileIO()
|
||||
for _, mf := range []struct{ flag, val string }{
|
||||
{"--image", imageKey}, {"--file", fileKey}, {"--video", videoKey},
|
||||
{"--video-cover", videoCoverKey}, {"--audio", audioKey},
|
||||
} {
|
||||
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,29 +150,13 @@ var ImMessagesSend = common.Shortcut{
|
||||
videoVal := runtime.Str("video")
|
||||
videoCoverVal := runtime.Str("video-cover")
|
||||
audioVal := runtime.Str("audio")
|
||||
if !isMediaKey(imageVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--image", imageVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(fileVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--file", fileVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video", videoVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(videoCoverVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--video-cover", videoCoverVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
}
|
||||
}
|
||||
if !isMediaKey(audioVal) {
|
||||
if _, err := validate.SafeLocalFlagPath("--audio", audioVal); err != nil {
|
||||
return output.ErrValidation("%v", err)
|
||||
fio := runtime.FileIO()
|
||||
for _, mf := range []struct{ flag, val string }{
|
||||
{"--image", imageVal}, {"--file", fileVal}, {"--video", videoVal},
|
||||
{"--video-cover", videoCoverVal}, {"--audio", audioVal},
|
||||
} {
|
||||
if err := validateMediaFlagPath(fio, mf.flag, mf.val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Resolve content type
|
||||
@@ -239,3 +208,15 @@ var ImMessagesSend = common.Shortcut{
|
||||
func isMediaKey(value string) bool {
|
||||
return strings.HasPrefix(value, "img_") || strings.HasPrefix(value, "file_")
|
||||
}
|
||||
|
||||
// validateMediaFlagPath validates a media flag value as a local file path via FileIO.
|
||||
// Empty values, URLs, and media keys are skipped (not local files).
|
||||
func validateMediaFlagPath(fio fileio.FileIO, flagName, value string) error {
|
||||
if value == "" || strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://") || isMediaKey(value) {
|
||||
return nil
|
||||
}
|
||||
if _, err := fio.Stat(value); err != nil && !os.IsNotExist(err) {
|
||||
return output.ErrValidation("%s: %v", flagName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
51
shortcuts/im/validate_media_test.go
Normal file
51
shortcuts/im/validate_media_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package im
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/larksuite/cli/internal/vfs/localfileio"
|
||||
)
|
||||
|
||||
func TestValidateMediaFlagPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
orig, _ := os.Getwd()
|
||||
defer os.Chdir(orig)
|
||||
os.Chdir(dir)
|
||||
os.WriteFile(filepath.Join(dir, "photo.jpg"), []byte("img"), 0644)
|
||||
|
||||
fio := &localfileio.LocalFileIO{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
flag string
|
||||
value string
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty value skipped", "--image", "", false},
|
||||
{"http URL skipped", "--image", "http://example.com/a.jpg", false},
|
||||
{"https URL skipped", "--file", "https://example.com/b.mp4", false},
|
||||
{"media key skipped", "--image", "img_abc123", false},
|
||||
{"file key skipped", "--file", "file_abc123", false},
|
||||
{"valid local file", "--image", "photo.jpg", false},
|
||||
{"nonexistent file allowed", "--file", "missing.txt", false},
|
||||
{"path traversal rejected", "--image", "../../etc/passwd", true},
|
||||
{"absolute path rejected", "--file", "/etc/passwd", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateMediaFlagPath(fio, tt.flag, tt.value)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Fatalf("expected error for %s=%q, got nil", tt.flag, tt.value)
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Fatalf("unexpected error for %s=%q: %v", tt.flag, tt.value, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user