mirror of
https://github.com/larksuite/cli.git
synced 2026-07-03 14:02:43 +08:00
feat: add darwin file master key fallback for keychain writes (#285)
* feat: (MacOS) add fallback file-based master key storage * refactor(keychain): improve master key file handling and corruption checks - Replace temporary file approach with direct file creation - Add explicit corruption checks for existing keys - Ensure atomic operations and proper cleanup on failure * docs(keychain): add comments to clarify constants and variables Add descriptive comments to explain the purpose of timeout, crypto parameters, and test variables in the macOS keychain implementation. * fix(keychain): use atomic write for master key initialization * fix(keychain): add retry logic for reading master key file Add retry mechanism when reading existing master key file to handle potential race conditions. Return early if read error occurs instead of waiting for all retries. * refactor(keychain): simplify master key validation logic Restructure the key validation flow to reduce redundant checks and improve readability. The corrupted key check is moved after the error handling block for better logical flow. * refactor(keychain): replace os package with vfs for file operations Use vfs package instead of os for file operations to improve testability and abstract filesystem access. This change makes it easier to mock filesystem operations in tests and provides a consistent interface for file handling.
This commit is contained in:
@@ -22,11 +22,27 @@ import (
|
||||
"github.com/zalando/go-keyring"
|
||||
)
|
||||
|
||||
// keychainTimeout bounds system keychain access to avoid hanging on blocked prompts.
|
||||
const keychainTimeout = 5 * time.Second
|
||||
|
||||
// masterKeyBytes is the AES-256 key size used to encrypt stored secrets.
|
||||
const masterKeyBytes = 32
|
||||
|
||||
// ivBytes is the nonce size used by AES-GCM.
|
||||
const ivBytes = 12
|
||||
|
||||
// tagBytes is the authentication tag size produced by AES-GCM.
|
||||
const tagBytes = 16
|
||||
|
||||
// fileMasterKeyName is the local fallback master key file name.
|
||||
const fileMasterKeyName = "master.key.file"
|
||||
|
||||
// keyringGet is overridden in tests to simulate system keychain reads.
|
||||
var keyringGet = keyring.Get
|
||||
|
||||
// keyringSet is overridden in tests to simulate system keychain writes.
|
||||
var keyringSet = keyring.Set
|
||||
|
||||
// StorageDir returns the storage directory for a given service name on macOS.
|
||||
func StorageDir(service string) string {
|
||||
home, err := vfs.UserHomeDir()
|
||||
@@ -57,7 +73,7 @@ func getMasterKey(service string, allowCreate bool) ([]byte, error) {
|
||||
go func() {
|
||||
defer func() { recover() }()
|
||||
|
||||
encodedKey, err := keyring.Get(service, "master.key")
|
||||
encodedKey, err := keyringGet(service, "master.key")
|
||||
if err == nil {
|
||||
key, decodeErr := base64.StdEncoding.DecodeString(encodedKey)
|
||||
if decodeErr == nil && len(key) == masterKeyBytes {
|
||||
@@ -88,7 +104,7 @@ func getMasterKey(service string, allowCreate bool) ([]byte, error) {
|
||||
}
|
||||
|
||||
encodedKeyStr := base64.StdEncoding.EncodeToString(key)
|
||||
setErr := keyring.Set(service, "master.key", encodedKeyStr)
|
||||
setErr := keyringSet(service, "master.key", encodedKeyStr)
|
||||
if setErr != nil {
|
||||
resCh <- result{key: nil, err: setErr}
|
||||
return
|
||||
@@ -105,6 +121,85 @@ func getMasterKey(service string, allowCreate bool) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// getFileMasterKey retrieves the fallback master key from local storage.
|
||||
// If allowCreate is true, it generates and stores a new fallback master key when missing.
|
||||
func getFileMasterKey(service string, allowCreate bool) ([]byte, error) {
|
||||
dir := StorageDir(service)
|
||||
keyPath := filepath.Join(dir, fileMasterKeyName)
|
||||
|
||||
key, err := vfs.ReadFile(keyPath)
|
||||
if err == nil && len(key) == masterKeyBytes {
|
||||
return key, nil
|
||||
}
|
||||
if err == nil && len(key) != masterKeyBytes {
|
||||
return nil, errors.New("keychain is corrupted")
|
||||
}
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
if !allowCreate {
|
||||
return nil, errNotInitialized
|
||||
}
|
||||
if err := vfs.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key = make([]byte, masterKeyBytes)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := vfs.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrExist) {
|
||||
for i := 0; i < 3; i++ {
|
||||
existingKey, readErr := vfs.ReadFile(keyPath)
|
||||
if readErr == nil && len(existingKey) == masterKeyBytes {
|
||||
return existingKey, nil
|
||||
}
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
if i < 2 {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
return nil, errors.New("keychain is corrupted")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
writeFailed := true
|
||||
defer func() {
|
||||
if writeFailed {
|
||||
_ = vfs.Remove(keyPath)
|
||||
}
|
||||
}()
|
||||
if _, err := file.Write(key); err != nil {
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeFailed = false
|
||||
|
||||
canonicalKey, err := vfs.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
existingKey, readErr := vfs.ReadFile(keyPath)
|
||||
if readErr == nil && len(existingKey) == masterKeyBytes {
|
||||
return existingKey, nil
|
||||
}
|
||||
if readErr == nil && len(existingKey) != masterKeyBytes {
|
||||
return nil, errors.New("keychain is corrupted")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if len(canonicalKey) != masterKeyBytes {
|
||||
return nil, errors.New("keychain is corrupted")
|
||||
}
|
||||
return canonicalKey, nil
|
||||
}
|
||||
|
||||
// encryptData encrypts data using AES-GCM.
|
||||
func encryptData(plaintext string, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
@@ -161,6 +256,11 @@ func platformGet(service, account string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if key, ferr := getFileMasterKey(service, false); ferr == nil {
|
||||
if plaintext, derr := decryptData(data, key); derr == nil {
|
||||
return plaintext, nil
|
||||
}
|
||||
}
|
||||
key, err := getMasterKey(service, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -174,9 +274,15 @@ func platformGet(service, account string) (string, error) {
|
||||
|
||||
// platformSet stores a value in the macOS keychain.
|
||||
func platformSet(service, account, data string) error {
|
||||
key, err := getMasterKey(service, true)
|
||||
key, err := getFileMasterKey(service, false)
|
||||
if err != nil {
|
||||
return err
|
||||
key, err = getMasterKey(service, true)
|
||||
if err != nil {
|
||||
key, err = getFileMasterKey(service, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
dir := StorageDir(service)
|
||||
if err := vfs.MkdirAll(dir, 0700); err != nil {
|
||||
|
||||
160
internal/keychain/keychain_darwin_test.go
Normal file
160
internal/keychain/keychain_darwin_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
//go:build darwin
|
||||
|
||||
package keychain
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/zalando/go-keyring"
|
||||
)
|
||||
|
||||
// TestPlatformSetFallsBackToFileMasterKey verifies writes fall back to a file master key
|
||||
// when the system keychain cannot create the master key.
|
||||
func TestPlatformSetFallsBackToFileMasterKey(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
|
||||
origGet := keyringGet
|
||||
origSet := keyringSet
|
||||
keyringGet = func(service, user string) (string, error) {
|
||||
return "", keyring.ErrNotFound
|
||||
}
|
||||
keyringSet = func(service, user, password string) error {
|
||||
return errors.New("blocked")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
keyringGet = origGet
|
||||
keyringSet = origSet
|
||||
})
|
||||
|
||||
service := "test-service"
|
||||
account := "test-account"
|
||||
secret := "secret-value"
|
||||
|
||||
if err := platformSet(service, account, secret); err != nil {
|
||||
t.Fatalf("platformSet() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(StorageDir(service), fileMasterKeyName)); err != nil {
|
||||
t.Fatalf("file master key not created: %v", err)
|
||||
}
|
||||
|
||||
got, err := platformGet(service, account)
|
||||
if err != nil {
|
||||
t.Fatalf("platformGet() error = %v", err)
|
||||
}
|
||||
if got != secret {
|
||||
t.Fatalf("platformGet() = %q, want %q", got, secret)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlatformGetPrefersFileMasterKey verifies reads prefer the file-based master key
|
||||
// before trying the system keychain master key.
|
||||
func TestPlatformGetPrefersFileMasterKey(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
|
||||
fileKey := make([]byte, masterKeyBytes)
|
||||
for i := range fileKey {
|
||||
fileKey[i] = byte(i + 1)
|
||||
}
|
||||
keychainKey := make([]byte, masterKeyBytes)
|
||||
for i := range keychainKey {
|
||||
keychainKey[i] = byte(i + 33)
|
||||
}
|
||||
|
||||
origGet := keyringGet
|
||||
origSet := keyringSet
|
||||
keyringGet = func(service, user string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString(keychainKey), nil
|
||||
}
|
||||
keyringSet = func(service, user, password string) error {
|
||||
return nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
keyringGet = origGet
|
||||
keyringSet = origSet
|
||||
})
|
||||
|
||||
service := "test-service"
|
||||
account := "test-account"
|
||||
secret := "secret-value"
|
||||
|
||||
dir := StorageDir(service)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, fileMasterKeyName), fileKey, 0600); err != nil {
|
||||
t.Fatalf("WriteFile(master key) error = %v", err)
|
||||
}
|
||||
encrypted, err := encryptData(secret, fileKey)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptData() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, safeFileName(account)), encrypted, 0600); err != nil {
|
||||
t.Fatalf("WriteFile(secret) error = %v", err)
|
||||
}
|
||||
|
||||
got, err := platformGet(service, account)
|
||||
if err != nil {
|
||||
t.Fatalf("platformGet() error = %v", err)
|
||||
}
|
||||
if got != secret {
|
||||
t.Fatalf("platformGet() = %q, want %q", got, secret)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlatformSetPrefersExistingFileMasterKey verifies writes stay on the file-based
|
||||
// master key path once the fallback master key already exists.
|
||||
func TestPlatformSetPrefersExistingFileMasterKey(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
|
||||
origGet := keyringGet
|
||||
origSet := keyringSet
|
||||
keyringGet = func(service, user string) (string, error) {
|
||||
t.Fatalf("keyringGet should not be called when file master key exists")
|
||||
return "", nil
|
||||
}
|
||||
keyringSet = func(service, user, password string) error {
|
||||
t.Fatalf("keyringSet should not be called when file master key exists")
|
||||
return nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
keyringGet = origGet
|
||||
keyringSet = origSet
|
||||
})
|
||||
|
||||
service := "test-service"
|
||||
account := "test-account"
|
||||
secret := "secret-value"
|
||||
|
||||
dir := StorageDir(service)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
fileKey := make([]byte, masterKeyBytes)
|
||||
for i := range fileKey {
|
||||
fileKey[i] = byte(i + 1)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, fileMasterKeyName), fileKey, 0600); err != nil {
|
||||
t.Fatalf("WriteFile(master key) error = %v", err)
|
||||
}
|
||||
|
||||
if err := platformSet(service, account, secret); err != nil {
|
||||
t.Fatalf("platformSet() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := platformGet(service, account)
|
||||
if err != nil {
|
||||
t.Fatalf("platformGet() error = %v", err)
|
||||
}
|
||||
if got != secret {
|
||||
t.Fatalf("platformGet() = %q, want %q", got, secret)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user