Files
larksuite-cli/internal/auth/device_flow.go
2026-06-12 17:49:33 +08:00

299 lines
8.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/larksuite/cli/internal/core"
)
// DeviceAuthResponse is the response from the device authorization endpoint.
//
// RequestDeviceAuthorization fills it from the endpoint's JSON body; the json
// tags mirror the keys of that response.
type DeviceAuthResponse struct {
	// DeviceCode is the code later sent as "device_code" when polling the token endpoint.
	DeviceCode string `json:"device_code"`
	// UserCode is the server-issued user_code (the user-facing half of the pair).
	UserCode string `json:"user_code"`
	// VerificationUri is the verification page URL returned by the server.
	VerificationUri string `json:"verification_uri"`
	// VerificationUriComplete is the server's verification_uri_complete;
	// RequestDeviceAuthorization falls back to VerificationUri when it is absent.
	VerificationUriComplete string `json:"verification_uri_complete"`
	// ExpiresIn is the device-code lifetime in seconds (PollDeviceToken scales it
	// by time.Second); RequestDeviceAuthorization defaults it to 240 when omitted.
	ExpiresIn int `json:"expires_in"`
	// Interval is the polling interval in seconds (PollDeviceToken scales it by
	// time.Second); RequestDeviceAuthorization defaults it to 5 when omitted.
	Interval int `json:"interval"`
}
// DeviceFlowTokenData contains the token data from a successful device flow.
//
// PollDeviceToken populates it from the token endpoint's JSON response.
type DeviceFlowTokenData struct {
	// AccessToken is the response's "access_token" (always non-empty when set by PollDeviceToken).
	AccessToken string
	// RefreshToken is the response's "refresh_token"; it may be empty when the
	// server did not issue one.
	RefreshToken string
	// ExpiresIn is the access-token lifetime from "expires_in" (7200 when absent);
	// presumably seconds — TODO confirm against the API docs.
	ExpiresIn int
	// RefreshExpiresIn comes from "refresh_token_expires_in" (604800 when absent).
	// When no refresh token was returned it is set equal to ExpiresIn.
	RefreshExpiresIn int
	// Scope is the response's "scope" value (empty when absent).
	Scope string
}
// DeviceFlowResult is the result of polling the token endpoint.
type DeviceFlowResult struct {
	// OK reports whether a token was obtained.
	OK bool
	// Token holds the issued tokens; PollDeviceToken sets it only when OK is true.
	Token *DeviceFlowTokenData
	// Error is a short failure code. PollDeviceToken uses "access_denied" when
	// the server reports a denial and "expired_token" for every other failure
	// (expiry, timeout, cancellation, unrecognised server errors).
	Error string
	// Message is a human-readable description of the failure.
	Message string
}
// OAuthEndpoints contains the OAuth endpoint URLs.
//
// Each field is a full URL assembled by ResolveOAuthEndpoints.
type OAuthEndpoints struct {
	// DeviceAuthorization is where RequestDeviceAuthorization requests a device code.
	DeviceAuthorization string
	// Revoke is the token-revocation URL (not referenced by the device-flow code here).
	Revoke string
	// Token is the URL PollDeviceToken polls with the device_code grant.
	Token string
}
// ResolveOAuthEndpoints resolves OAuth endpoint URLs based on brand.
//
// The device-authorization and revocation paths are joined onto the Accounts
// base URL from core.ResolveEndpoints, while the token path is joined onto its
// Open base URL.
func ResolveOAuthEndpoints(brand core.LarkBrand) OAuthEndpoints {
	hosts := core.ResolveEndpoints(brand)
	accountsBase, openBase := hosts.Accounts, hosts.Open

	var out OAuthEndpoints
	out.DeviceAuthorization = accountsBase + PathDeviceAuthorization
	out.Revoke = accountsBase + PathOAuthRevoke
	out.Token = openBase + PathOAuthTokenV2
	return out
}
// RequestDeviceAuthorization requests a device authorization code.
//
// It POSTs a form-encoded body (client_id, scope) to the brand's
// device-authorization endpoint, authenticating with HTTP Basic auth built from
// appId:appSecret. "offline_access" is appended to scope unless it is already
// one of scope's space-separated tokens.
//
// On success, ExpiresIn defaults to 240 and Interval to 5 when the server omits
// them, and VerificationUriComplete falls back to VerificationUri. An error is
// returned for transport failures, an unreadable or non-JSON body, HTTP status
// >= 400, a body carrying an "error" field, or a body without a device_code.
//
// errOut is accepted for symmetry with PollDeviceToken (nil means io.Discard);
// this function does not currently write to it.
func RequestDeviceAuthorization(httpClient *http.Client, appId, appSecret string, brand core.LarkBrand, scope string, errOut io.Writer) (*DeviceAuthResponse, error) {
	if errOut == nil {
		errOut = io.Discard
	}
	endpoints := ResolveOAuthEndpoints(brand)

	// Match offline_access as a whole scope token: a substring test would be
	// fooled by a scope that merely contains the text (e.g. "x:offline_access_y").
	hasOfflineAccess := false
	for _, s := range strings.Fields(scope) {
		if s == "offline_access" {
			hasOfflineAccess = true
			break
		}
	}
	if !hasOfflineAccess {
		if strings.TrimSpace(scope) == "" {
			scope = "offline_access"
		} else {
			scope += " offline_access"
		}
	}

	basicAuth := base64.StdEncoding.EncodeToString([]byte(appId + ":" + appSecret))
	form := url.Values{}
	form.Set("client_id", appId)
	form.Set("scope", scope)
	req, err := http.NewRequest(http.MethodPost, endpoints.DeviceAuthorization, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Authorization", "Basic "+basicAuth)

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	logHTTPResponse(resp)

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// %w keeps the underlying I/O error inspectable via errors.Is/As.
		return nil, fmt.Errorf("Device authorization failed: read body: %w", err)
	}
	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		return nil, fmt.Errorf("Device authorization failed: HTTP %d response not JSON", resp.StatusCode)
	}

	// Either an error status or an "error" field means failure; prefer the
	// server's description, then its error code.
	if _, hasError := data["error"]; resp.StatusCode >= 400 || hasError {
		msg := getStr(data, "error_description")
		if msg == "" {
			msg = getStr(data, "error")
		}
		if msg == "" {
			msg = "Unknown error"
		}
		return nil, fmt.Errorf("Device authorization failed: %s", msg)
	}

	// Without a device code the token polling that follows cannot succeed, so
	// report the malformed response here instead of failing obscurely later.
	deviceCode := getStr(data, "device_code")
	if deviceCode == "" {
		return nil, fmt.Errorf("Device authorization failed: response missing device_code")
	}

	verificationUri := getStr(data, "verification_uri")
	verificationUriComplete := getStr(data, "verification_uri_complete")
	if verificationUriComplete == "" {
		verificationUriComplete = verificationUri
	}
	return &DeviceAuthResponse{
		DeviceCode:              deviceCode,
		UserCode:                getStr(data, "user_code"),
		VerificationUri:         verificationUri,
		VerificationUriComplete: verificationUriComplete,
		ExpiresIn:               getInt(data, "expires_in", 240),
		Interval:                getInt(data, "interval", 5),
	}, nil
}
// PollDeviceToken polls the token endpoint until authorization completes or times out.
//
// It waits interval seconds before every poll (interval < 1 is treated as 5)
// and stops after expiresIn seconds (expiresIn < 1 is treated as 240, the same
// default RequestDeviceAuthorization applies) or 600 attempts, whichever comes
// first. Each poll POSTs the device_code grant with client_id/client_secret as
// form fields. The request is bound to ctx, so cancelling ctx also aborts an
// in-flight HTTP call.
//
// Network, read, and JSON-parse failures lengthen the interval by 1s, and
// "slow_down" lengthens it by 5s, both capped at 60s. "authorization_pending"
// keeps polling unchanged.
//
// The result is never nil. On success OK is true and Token is set. Otherwise
// Error is "access_denied" (denied by the user) or "expired_token" (every other
// failure: cancellation, expiry, timeout, an unbuildable request, or an
// unrecognised server error), and Message describes it. Progress and warnings
// go to errOut; nil means io.Discard.
func PollDeviceToken(ctx context.Context, httpClient *http.Client, appId, appSecret string, brand core.LarkBrand, deviceCode string, interval, expiresIn int, errOut io.Writer) *DeviceFlowResult {
	if errOut == nil {
		errOut = io.Discard
	}
	if interval < 1 {
		interval = 5
	}
	// A non-positive lifetime would make the deadline already past and end the
	// flow without a single poll.
	if expiresIn < 1 {
		expiresIn = 240
	}
	const (
		maxPollInterval = 60
		maxPollAttempts = 600
	)
	cancelled := func() *DeviceFlowResult {
		return &DeviceFlowResult{OK: false, Error: "expired_token", Message: "Polling was cancelled"}
	}

	endpoints := ResolveOAuthEndpoints(brand)
	deadline := time.Now().Add(time.Duration(expiresIn) * time.Second)
	currentInterval := interval
	// backoff lengthens the wait after a transient (network/read/parse) failure.
	backoff := func() { currentInterval = minInt(currentInterval+1, maxPollInterval) }

	attempts := 0
	for time.Now().Before(deadline) && attempts < maxPollAttempts {
		attempts++

		// Wait before every poll, including the first, while honouring cancellation.
		timer := time.NewTimer(time.Duration(currentInterval) * time.Second)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return cancelled()
		}

		form := url.Values{}
		form.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
		form.Set("device_code", deviceCode)
		form.Set("client_id", appId)
		form.Set("client_secret", appSecret)
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoints.Token, strings.NewReader(form.Encode()))
		if err != nil {
			// Building the request only fails for a malformed method/URL or a nil
			// ctx, which retrying cannot fix, so fail fast instead of spinning
			// silently until the deadline.
			fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: build poll request failed: %v\n", err)
			return &DeviceFlowResult{OK: false, Error: "expired_token", Message: fmt.Sprintf("Failed to build token request: %v", err)}
		}
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

		resp, err := httpClient.Do(req)
		if err != nil {
			if ctx.Err() != nil {
				// The error came from cancellation, not the network.
				return cancelled()
			}
			fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: poll network error: %v\n", err)
			backoff()
			continue
		}
		logHTTPResponse(resp)
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: poll read error: %v\n", err)
			backoff()
			continue
		}
		var data map[string]interface{}
		if err := json.Unmarshal(body, &data); err != nil {
			fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: poll parse error: %v\n", err)
			backoff()
			continue
		}

		errStr := getStr(data, "error")
		if errStr == "" && getStr(data, "access_token") != "" {
			fmt.Fprintf(errOut, "[lark-cli] device-flow: token response received\n")
			refreshToken := getStr(data, "refresh_token")
			tokenExpiresIn := getInt(data, "expires_in", 7200)
			refreshExpiresIn := getInt(data, "refresh_token_expires_in", 604800)
			if refreshToken == "" {
				// Without a refresh token the session cannot outlive the access token.
				fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: no refresh_token in response\n")
				refreshExpiresIn = tokenExpiresIn
			}
			return &DeviceFlowResult{
				OK: true,
				Token: &DeviceFlowTokenData{
					AccessToken:      getStr(data, "access_token"),
					RefreshToken:     refreshToken,
					ExpiresIn:        tokenExpiresIn,
					RefreshExpiresIn: refreshExpiresIn,
					Scope:            getStr(data, "scope"),
				},
			}
		}

		switch errStr {
		case "authorization_pending":
			continue
		case "slow_down":
			currentInterval = minInt(currentInterval+5, maxPollInterval)
			fmt.Fprintf(errOut, "[lark-cli] device-flow: slow_down, interval increased to %ds\n", currentInterval)
			continue
		case "access_denied":
			msg := getStr(data, "error_description")
			if msg == "" {
				msg = "Authorization denied by user"
			}
			return &DeviceFlowResult{OK: false, Error: "access_denied", Message: msg}
		case "expired_token", "invalid_grant":
			msg := getStr(data, "error_description")
			if msg == "" {
				msg = "Device code expired, please try again"
			}
			return &DeviceFlowResult{OK: false, Error: "expired_token", Message: msg}
		}

		// Any other outcome (unknown error code, or no error and no token) is terminal.
		desc := getStr(data, "error_description")
		if desc == "" {
			desc = errStr
		}
		if desc == "" {
			desc = "Unknown error"
		}
		fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: unexpected error: error=%s, desc=%s\n", errStr, desc)
		return &DeviceFlowResult{OK: false, Error: "expired_token", Message: desc}
	}

	if attempts >= maxPollAttempts {
		fmt.Fprintf(errOut, "[lark-cli] [WARN] device-flow: max poll attempts (%d) reached\n", maxPollAttempts)
	}
	return &DeviceFlowResult{OK: false, Error: "expired_token", Message: "Authorization timed out, please try again"}
}
// helpers
// minInt reports the lesser of a and b; when they are equal either value is returned.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// getStr returns m[key] when it holds a string, and "" otherwise: a missing key,
// a nil map, a nil value, or a value of any other type all yield the empty string.
func getStr(m map[string]interface{}, key string) string {
	// Indexing a nil map or a missing key yields a nil interface, which the
	// comma-ok assertion turns into the zero string without panicking.
	str, _ := m[key].(string)
	return str
}
// getInt retrieves an integer value from a map, returning fallback when the key
// is missing or its value is not a recognised number.
//
// Accepted value types are float64 (what encoding/json produces when decoding
// into interface{}), json.Number (produced when the decoder has UseNumber
// enabled), int, and int64. Fractional values are truncated toward zero. A
// json.Number that parses as neither an integer nor a float yields fallback.
func getInt(m map[string]interface{}, key string, fallback int) int {
	switch n := m[key].(type) {
	case float64:
		return int(n)
	case int:
		return n
	case int64:
		return int(n)
	case json.Number:
		if i, err := n.Int64(); err == nil {
			return int(i)
		}
		// Values like "3.5" or "1e3" are not valid Int64 input but are numbers.
		if f, err := n.Float64(); err == nil {
			return int(f)
		}
	}
	return fallback
}