fix(auth): enforce remote manual auth state (#187)
* fix(gmail): fallback to send-as list for display name * refactor(gmail): remove dead code in primarySendAsDisplayNameFromList The condition `primary == nil && sa.IsPrimary` inside the email-matching block can never be true because `primary` is already unconditionally set to `sa` when `sa.IsPrimary` is true earlier in the same loop iteration. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * test(gmail): add --from display name fallback to list test Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(auth): persist manual oauth state * feat(cli): add remote manual auth flow * fix(auth): enforce remote manual auth state * fix(auth): satisfy lint for manual auth flow * fix(auth): harden remote manual auth state cache * chore: update changelog for remote manual auth (#187) (thanks @salmonumbrella) --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Peter Steinberger <steipete@gmail.com>
This commit is contained in:
parent
e3cb940780
commit
2df8ece2f6
@ -6,6 +6,11 @@
|
||||
|
||||
- Gmail: add `--exclude-labels` to `watch serve` (defaults: `SPAM,TRASH`). (#194) — thanks @salmonumbrella.
|
||||
- Drive: share files with an entire Workspace domain via `drive share --to domain`. (#192) — thanks @Danielkweber.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Auth: improve remote/server-friendly manual OAuth flow (`auth add --remote`). (#187) — thanks @salmonumbrella.
|
||||
|
||||
## 0.9.0 - 2026-01-22
|
||||
|
||||
### Highlights
|
||||
|
||||
15
README.md
15
README.md
@ -106,6 +106,21 @@ gog auth add you@gmail.com
|
||||
|
||||
This will open a browser window for OAuth authorization. The refresh token is stored securely in your system keychain.
|
||||
|
||||
Headless / remote server flow (no browser on the server):
|
||||
|
||||
```bash
|
||||
# Step 1: print auth URL (open it locally in a browser)
|
||||
gog auth add you@gmail.com --services user --remote --step 1
|
||||
|
||||
# Step 2: paste the full redirect URL from your browser address bar
|
||||
gog auth add you@gmail.com --services user --remote --step 2 --auth-url 'http://localhost:1/?code=...&state=...'
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- The `state` is cached on disk for a short time (about 10 minutes). If it expires, rerun step 1.
|
||||
- Remote step 2 requires a redirect URL that includes `state` (state check mandatory).
|
||||
|
||||
### 4. Test Authentication
|
||||
|
||||
```bash
|
||||
|
||||
@ -99,6 +99,9 @@ Implementation: `internal/secrets/store.go`.
|
||||
|
||||
- Desktop OAuth 2.0 flow using local HTTP redirect on an ephemeral port.
|
||||
- Supports a browserless/manual flow (paste redirect URL) for headless environments.
|
||||
- Supports a remote/server-friendly 2-step manual flow:
|
||||
- Step 1 prints an auth URL (`gog auth add ... --remote --step 1`)
|
||||
- Step 2 exchanges the pasted redirect URL and requires `state` validation (`--remote --step 2 --auth-url ...`)
|
||||
- Refresh token issuance:
|
||||
- requests `access_type=offline`
|
||||
- supports `--force-consent` to force the consent prompt when Google doesn't return a refresh token
|
||||
@ -119,6 +122,7 @@ Scope selection note:
|
||||
- `credentials-<client>.json` (OAuth client id/secret; named clients)
|
||||
- State:
|
||||
- `state/gmail-watch/<account>.json` (Gmail watch state)
|
||||
- `oauth-manual-state-<state>.json` (temporary manual OAuth state cache; expires quickly; no tokens)
|
||||
- Secrets:
|
||||
- refresh tokens in keyring
|
||||
|
||||
@ -148,7 +152,7 @@ Flag aliases:
|
||||
- `gog auth credentials <credentials.json|->`
|
||||
- `gog auth credentials list`
|
||||
- `gog --client <name> auth credentials <credentials.json|->`
|
||||
- `gog auth add <email> [--services user|all|gmail,calendar,classroom,drive,docs,contacts,tasks,sheets,people,groups] [--readonly] [--drive-scope full|readonly|file] [--manual] [--force-consent]`
|
||||
- `gog auth add <email> [--services user|all|gmail,calendar,classroom,drive,docs,contacts,tasks,sheets,people,groups] [--readonly] [--drive-scope full|readonly|file] [--manual] [--remote] [--step 1|2] [--auth-url URL] [--timeout DURATION] [--force-consent]`
|
||||
- `gog auth services [--markdown]`
|
||||
- `gog auth keep <email> --key <service-account.json>` (Google Keep; Workspace only)
|
||||
- `gog auth list`
|
||||
|
||||
@ -26,6 +26,7 @@ var (
|
||||
checkRefreshToken = googleauth.CheckRefreshToken
|
||||
ensureKeychainAccess = secrets.EnsureKeychainAccess
|
||||
fetchAuthorizedEmail = googleauth.EmailForRefreshToken
|
||||
manualAuthURL = googleauth.ManualAuthURL
|
||||
)
|
||||
|
||||
func ensureKeychainAccessIfNeeded() error {
|
||||
@ -479,12 +480,17 @@ func (c *AuthTokensImportCmd) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
type AuthAddCmd struct {
|
||||
Email string `arg:"" name:"email" help:"Email"`
|
||||
Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"`
|
||||
ForceConsent bool `name:"force-consent" help:"Force consent screen to obtain a refresh token"`
|
||||
ServicesCSV string `name:"services" help:"Services to authorize: user|all or comma-separated ${auth_services} (Keep uses service account: gog auth service-account set)" default:"user"`
|
||||
Readonly bool `name:"readonly" help:"Use read-only scopes where available (still includes OIDC identity scopes)"`
|
||||
DriveScope string `name:"drive-scope" help:"Drive scope mode: full|readonly|file" enum:"full,readonly,file" default:"full"`
|
||||
Email string `arg:"" name:"email" help:"Email"`
|
||||
Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"`
|
||||
Remote bool `name:"remote" help:"Remote/server-friendly manual flow (print URL, then exchange code)"`
|
||||
Step int `name:"step" help:"Remote auth step: 1=print URL, 2=exchange code"`
|
||||
AuthURL string `name:"auth-url" help:"Redirect URL from browser (manual flow; required for --remote --step 2)"`
|
||||
AuthCode string `name:"auth-code" hidden:"" help:"UNSAFE: Authorization code from browser (manual flow; skips state check; not valid with --remote)"`
|
||||
Timeout time.Duration `name:"timeout" help:"Authorization timeout (manual flows default to 5m)"`
|
||||
ForceConsent bool `name:"force-consent" help:"Force consent screen to obtain a refresh token"`
|
||||
ServicesCSV string `name:"services" help:"Services to authorize: user|all or comma-separated ${auth_services} (Keep uses service account: gog auth service-account set)" default:"user"`
|
||||
Readonly bool `name:"readonly" help:"Use read-only scopes where available (still includes OIDC identity scopes)"`
|
||||
DriveScope string `name:"drive-scope" help:"Drive scope mode: full|readonly|file" enum:"full,readonly,file" default:"full"`
|
||||
}
|
||||
|
||||
func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
@ -515,6 +521,69 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
authURL := strings.TrimSpace(c.AuthURL)
|
||||
authCode := strings.TrimSpace(c.AuthCode)
|
||||
if authURL != "" && authCode != "" {
|
||||
return usage("cannot combine --auth-url with --auth-code")
|
||||
}
|
||||
if c.Step != 0 && c.Step != 1 && c.Step != 2 {
|
||||
return usage("step must be 1 or 2")
|
||||
}
|
||||
if c.Step != 0 && !c.Remote {
|
||||
return usage("--step requires --remote")
|
||||
}
|
||||
|
||||
manual := c.Manual || c.Remote || authURL != "" || authCode != ""
|
||||
|
||||
if c.Remote {
|
||||
step := c.Step
|
||||
if step == 0 {
|
||||
if authURL != "" || authCode != "" {
|
||||
step = 2
|
||||
} else {
|
||||
step = 1
|
||||
}
|
||||
}
|
||||
switch step {
|
||||
case 1:
|
||||
if authURL != "" || authCode != "" {
|
||||
return usage("remote step 1 does not accept --auth-url or --auth-code")
|
||||
}
|
||||
result, manualErr := manualAuthURL(ctx, googleauth.AuthorizeOptions{
|
||||
Services: services,
|
||||
Scopes: scopes,
|
||||
Manual: true,
|
||||
ForceConsent: c.ForceConsent,
|
||||
Client: client,
|
||||
})
|
||||
if manualErr != nil {
|
||||
return manualErr
|
||||
}
|
||||
if outfmt.IsJSON(ctx) {
|
||||
return outfmt.WriteJSON(os.Stdout, map[string]any{
|
||||
"auth_url": result.URL,
|
||||
"state_reused": result.StateReused,
|
||||
})
|
||||
}
|
||||
u.Out().Printf("auth_url\t%s", result.URL)
|
||||
u.Out().Printf("state_reused\t%t", result.StateReused)
|
||||
u.Err().Println("Run again with --remote --step 2 --auth-url <redirect-url>")
|
||||
return nil
|
||||
case 2:
|
||||
if authCode != "" {
|
||||
return usage("--auth-code is not valid with --remote (state check is mandatory)")
|
||||
}
|
||||
if authURL == "" {
|
||||
return usage("remote step 2 requires --auth-url")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timeout := c.Timeout
|
||||
if timeout == 0 && manual {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
// Pre-flight: ensure keychain is accessible before starting OAuth
|
||||
if keychainErr := ensureKeychainAccessIfNeeded(); keychainErr != nil {
|
||||
return fmt.Errorf("keychain access: %w", keychainErr)
|
||||
@ -523,9 +592,13 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
refreshToken, err := authorizeGoogle(ctx, googleauth.AuthorizeOptions{
|
||||
Services: services,
|
||||
Scopes: scopes,
|
||||
Manual: c.Manual,
|
||||
Manual: manual,
|
||||
ForceConsent: c.ForceConsent,
|
||||
Timeout: timeout,
|
||||
Client: client,
|
||||
AuthURL: authURL,
|
||||
AuthCode: authCode,
|
||||
RequireState: c.Remote,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@ -473,6 +473,180 @@ func TestAuthAddCmd_SheetsDriveScopeFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAddCmd_RemoteStep1_PrintsAuthURL(t *testing.T) {
|
||||
origManualURL := manualAuthURL
|
||||
origAuth := authorizeGoogle
|
||||
origKeychain := ensureKeychainAccess
|
||||
t.Cleanup(func() {
|
||||
manualAuthURL = origManualURL
|
||||
authorizeGoogle = origAuth
|
||||
ensureKeychainAccess = origKeychain
|
||||
})
|
||||
|
||||
manualCalled := false
|
||||
manualAuthURL = func(context.Context, googleauth.AuthorizeOptions) (googleauth.ManualAuthURLResult, error) {
|
||||
manualCalled = true
|
||||
return googleauth.ManualAuthURLResult{
|
||||
URL: "https://example.com/auth",
|
||||
StateReused: true,
|
||||
}, nil
|
||||
}
|
||||
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) {
|
||||
t.Fatal("authorizeGoogle should not be called in remote step 1")
|
||||
return "", nil
|
||||
}
|
||||
ensureKeychainAccess = func() error {
|
||||
t.Fatal("keychain access should not be checked in remote step 1")
|
||||
return nil
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
if err := Execute([]string{
|
||||
"auth",
|
||||
"add",
|
||||
"user@example.com",
|
||||
"--services",
|
||||
"gmail",
|
||||
"--remote",
|
||||
"--step",
|
||||
"1",
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
if !manualCalled {
|
||||
t.Fatalf("expected manualAuthURL to be called")
|
||||
}
|
||||
if !strings.Contains(out, "auth_url\thttps://example.com/auth") {
|
||||
t.Fatalf("unexpected output: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "state_reused\ttrue") {
|
||||
t.Fatalf("expected state_reused output, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAddCmd_RemoteStep2_RejectsAuthCode(t *testing.T) {
|
||||
err := Execute([]string{
|
||||
"auth",
|
||||
"add",
|
||||
"user@example.com",
|
||||
"--services",
|
||||
"gmail",
|
||||
"--remote",
|
||||
"--step",
|
||||
"2",
|
||||
"--auth-code",
|
||||
"abc123",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var ee *ExitError
|
||||
if !errors.As(err, &ee) || ee.Code != 2 {
|
||||
t.Fatalf("expected exit code 2, got %T %#v", err, err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--auth-code is not valid with --remote") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAddCmd_RemoteStep2_PassesAuthURL(t *testing.T) {
|
||||
origAuth := authorizeGoogle
|
||||
origOpen := openSecretsStore
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
authorizeGoogle = origAuth
|
||||
openSecretsStore = origOpen
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil }
|
||||
|
||||
var gotOpts googleauth.AuthorizeOptions
|
||||
authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) {
|
||||
gotOpts = opts
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) {
|
||||
return "user@example.com", nil
|
||||
}
|
||||
|
||||
if err := Execute([]string{
|
||||
"auth",
|
||||
"add",
|
||||
"user@example.com",
|
||||
"--services",
|
||||
"gmail",
|
||||
"--remote",
|
||||
"--step",
|
||||
"2",
|
||||
"--auth-url",
|
||||
"http://localhost:1/?code=abc&state=state123",
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
if !gotOpts.Manual {
|
||||
t.Fatalf("expected manual auth in remote step 2")
|
||||
}
|
||||
if !gotOpts.RequireState {
|
||||
t.Fatalf("expected require state in remote step 2")
|
||||
}
|
||||
if gotOpts.AuthURL == "" {
|
||||
t.Fatalf("expected auth URL to be passed through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAddCmd_AuthCode_PassesThrough(t *testing.T) {
|
||||
origAuth := authorizeGoogle
|
||||
origOpen := openSecretsStore
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
authorizeGoogle = origAuth
|
||||
openSecretsStore = origOpen
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil }
|
||||
|
||||
var gotOpts googleauth.AuthorizeOptions
|
||||
authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) {
|
||||
gotOpts = opts
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) {
|
||||
return "user@example.com", nil
|
||||
}
|
||||
|
||||
if err := Execute([]string{
|
||||
"auth",
|
||||
"add",
|
||||
"user@example.com",
|
||||
"--services",
|
||||
"gmail",
|
||||
"--auth-code",
|
||||
"abc123",
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute: %v", err)
|
||||
}
|
||||
|
||||
if !gotOpts.Manual {
|
||||
t.Fatalf("expected manual auth when auth-code is provided")
|
||||
}
|
||||
if gotOpts.AuthCode != "abc123" {
|
||||
t.Fatalf("expected auth-code to be passed through, got %q", gotOpts.AuthCode)
|
||||
}
|
||||
}
|
||||
|
||||
func containsStringInSlice(items []string, want string) bool {
|
||||
for _, it := range items {
|
||||
if it == want {
|
||||
|
||||
@ -119,15 +119,21 @@ func (c *GmailSendCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
sendingEmail = c.From
|
||||
fromAddr = c.From
|
||||
// Include display name if set
|
||||
if sa.DisplayName != "" {
|
||||
fromAddr = sa.DisplayName + " <" + c.From + ">"
|
||||
displayName := strings.TrimSpace(sa.DisplayName)
|
||||
if displayName == "" {
|
||||
if fallback, listErr := sendAsDisplayNameFromList(ctx, svc, c.From); listErr == nil {
|
||||
displayName = fallback
|
||||
}
|
||||
}
|
||||
if displayName != "" {
|
||||
fromAddr = displayName + " <" + c.From + ">"
|
||||
}
|
||||
} else {
|
||||
// No --from specified: look up the primary account's send-as settings
|
||||
// to get the display name
|
||||
sa, saErr := svc.Users.Settings.SendAs.Get("me", account).Context(ctx).Do()
|
||||
if saErr == nil && sa.DisplayName != "" {
|
||||
fromAddr = sa.DisplayName + " <" + account + ">"
|
||||
displayName := primarySendAsDisplayName(ctx, svc, account)
|
||||
if displayName != "" {
|
||||
fromAddr = displayName + " <" + account + ">"
|
||||
}
|
||||
// If lookup fails, we just use the plain email address (no error)
|
||||
}
|
||||
@ -217,6 +223,78 @@ func (c *GmailSendCmd) resolveTrackingConfig(account string, toRecipients, ccRec
|
||||
return trackingCfg, nil
|
||||
}
|
||||
|
||||
func primarySendAsDisplayName(ctx context.Context, svc *gmail.Service, account string) string {
|
||||
account = strings.TrimSpace(account)
|
||||
if account == "" || svc == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sa, err := svc.Users.Settings.SendAs.Get("me", account).Context(ctx).Do()
|
||||
if err == nil {
|
||||
if displayName := strings.TrimSpace(sa.DisplayName); displayName != "" {
|
||||
return displayName
|
||||
}
|
||||
}
|
||||
|
||||
displayName, err := primarySendAsDisplayNameFromList(ctx, svc, account)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return displayName
|
||||
}
|
||||
|
||||
func sendAsDisplayNameFromList(ctx context.Context, svc *gmail.Service, email string) (string, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || svc == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
resp, err := svc.Users.Settings.SendAs.List("me").Context(ctx).Do()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
needle := strings.ToLower(email)
|
||||
for _, sa := range resp.SendAs {
|
||||
if strings.ToLower(strings.TrimSpace(sa.SendAsEmail)) == needle {
|
||||
return strings.TrimSpace(sa.DisplayName), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func primarySendAsDisplayNameFromList(ctx context.Context, svc *gmail.Service, account string) (string, error) {
|
||||
account = strings.TrimSpace(account)
|
||||
if account == "" || svc == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
resp, err := svc.Users.Settings.SendAs.List("me").Context(ctx).Do()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
needle := strings.ToLower(account)
|
||||
var primary *gmail.SendAs
|
||||
for _, sa := range resp.SendAs {
|
||||
if sa.IsPrimary {
|
||||
primary = sa
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(sa.SendAsEmail)) == needle {
|
||||
if displayName := strings.TrimSpace(sa.DisplayName); displayName != "" {
|
||||
return displayName, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if primary != nil {
|
||||
return strings.TrimSpace(primary.DisplayName), nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func buildSendBatches(toRecipients, ccRecipients, bccRecipients []string, track, trackSplit bool) []sendBatch {
|
||||
totalRecipients := len(toRecipients) + len(ccRecipients) + len(bccRecipients)
|
||||
if track && trackSplit && totalRecipients > 1 {
|
||||
|
||||
@ -316,6 +316,82 @@ func TestGmailSendCmd_RunJSON_WithFrom(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailSendCmd_RunJSON_WithFromDisplayNameFallbackToList(t *testing.T) {
|
||||
origNew := newGmailService
|
||||
t.Cleanup(func() { newGmailService = origNew })
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/gmail/v1")
|
||||
switch {
|
||||
case r.Method == http.MethodGet && path == "/users/me/settings/sendAs/alias@example.com":
|
||||
// Return send-as settings with empty display name but valid verification.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"sendAsEmail": "alias@example.com",
|
||||
"displayName": "",
|
||||
"verificationStatus": "accepted",
|
||||
})
|
||||
return
|
||||
case r.Method == http.MethodGet && path == "/users/me/settings/sendAs":
|
||||
// Fallback list endpoint returns the alias with a populated display name.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"sendAs": []map[string]any{
|
||||
{
|
||||
"sendAsEmail": "alias@example.com",
|
||||
"displayName": "Alias From List",
|
||||
"verificationStatus": "accepted",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case r.Method == http.MethodPost && path == "/users/me/messages/send":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "m2b",
|
||||
"threadId": "t2b",
|
||||
})
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
svc, err := gmail.NewService(context.Background(),
|
||||
option.WithoutAuthentication(),
|
||||
option.WithHTTPClient(srv.Client()),
|
||||
option.WithEndpoint(srv.URL+"/"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewService: %v", err)
|
||||
}
|
||||
newGmailService = func(context.Context, string) (*gmail.Service, error) { return svc, nil }
|
||||
|
||||
u, err := ui.New(ui.Options{Stdout: os.Stdout, Stderr: os.Stderr, Color: "never"})
|
||||
if err != nil {
|
||||
t.Fatalf("ui.New: %v", err)
|
||||
}
|
||||
ctx := outfmt.WithMode(ui.WithUI(context.Background(), u), outfmt.Mode{JSON: true})
|
||||
|
||||
cmd := &GmailSendCmd{
|
||||
To: "a@example.com",
|
||||
From: "alias@example.com",
|
||||
Subject: "Hello",
|
||||
Body: "Body",
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
if err := cmd.Run(ctx, &RootFlags{Account: "a@b.com"}); err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
})
|
||||
if !strings.Contains(out, "\"from\"") || !strings.Contains(out, "Alias From List <alias@example.com>") {
|
||||
t.Fatalf("expected from with display name from list fallback, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailSendCmd_RunJSON_PrimaryAccountDisplayName(t *testing.T) {
|
||||
origNew := newGmailService
|
||||
t.Cleanup(func() { newGmailService = origNew })
|
||||
@ -380,6 +456,81 @@ func TestGmailSendCmd_RunJSON_PrimaryAccountDisplayName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailSendCmd_RunJSON_PrimaryAccountDisplayNameFallbackToList(t *testing.T) {
|
||||
origNew := newGmailService
|
||||
t.Cleanup(func() { newGmailService = origNew })
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/gmail/v1")
|
||||
switch {
|
||||
case r.Method == http.MethodGet && path == "/users/me/settings/sendAs/a@b.com":
|
||||
// Simulate missing display name in get response.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"sendAsEmail": "a@b.com",
|
||||
"displayName": "",
|
||||
"verificationStatus": "accepted",
|
||||
})
|
||||
return
|
||||
case r.Method == http.MethodGet && path == "/users/me/settings/sendAs":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"sendAs": []map[string]any{
|
||||
{
|
||||
"sendAsEmail": "a@b.com",
|
||||
"displayName": "Primary User",
|
||||
"verificationStatus": "accepted",
|
||||
"isPrimary": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
case r.Method == http.MethodPost && path == "/users/me/messages/send":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "m3b",
|
||||
"threadId": "t3b",
|
||||
})
|
||||
return
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
svc, err := gmail.NewService(context.Background(),
|
||||
option.WithoutAuthentication(),
|
||||
option.WithHTTPClient(srv.Client()),
|
||||
option.WithEndpoint(srv.URL+"/"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewService: %v", err)
|
||||
}
|
||||
newGmailService = func(context.Context, string) (*gmail.Service, error) { return svc, nil }
|
||||
|
||||
u, err := ui.New(ui.Options{Stdout: os.Stdout, Stderr: os.Stderr, Color: "never"})
|
||||
if err != nil {
|
||||
t.Fatalf("ui.New: %v", err)
|
||||
}
|
||||
ctx := outfmt.WithMode(ui.WithUI(context.Background(), u), outfmt.Mode{JSON: true})
|
||||
|
||||
cmd := &GmailSendCmd{
|
||||
To: "recipient@example.com",
|
||||
Subject: "Hello",
|
||||
Body: "Body",
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
if err := cmd.Run(ctx, &RootFlags{Account: "a@b.com"}); err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
})
|
||||
if !strings.Contains(out, "\"from\"") || !strings.Contains(out, "Primary User <a@b.com>") {
|
||||
t.Fatalf("expected from with display name, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailSendCmd_RunJSON_PrimaryAccountNoDisplayName(t *testing.T) {
|
||||
origNew := newGmailService
|
||||
t.Cleanup(func() { newGmailService = origNew })
|
||||
|
||||
234
internal/googleauth/manual_state.go
Normal file
234
internal/googleauth/manual_state.go
Normal file
@ -0,0 +1,234 @@
|
||||
package googleauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/gogcli/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
manualStateFilePrefix = "oauth-manual-state-"
|
||||
manualStateFileSuffix = ".json"
|
||||
)
|
||||
|
||||
var errEmptyManualAuthState = errors.New("empty manual auth state")
|
||||
|
||||
// manualStateTTL controls how long a stored manual auth state is considered valid.
|
||||
// This should be shorter than typical OAuth code expiration windows.
|
||||
const manualStateTTL = 10 * time.Minute
|
||||
|
||||
type manualState struct {
|
||||
State string `json:"state"`
|
||||
Client string `json:"client"`
|
||||
Scopes []string `json:"scopes"`
|
||||
ForceConsent bool `json:"force_consent,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
var (
|
||||
manualStateDirFn = manualStateDir
|
||||
manualStateNowFn = time.Now
|
||||
)
|
||||
|
||||
func manualStateDir() (string, error) {
|
||||
dir, err := config.EnsureDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ensure config dir: %w", err)
|
||||
}
|
||||
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func manualStatePathFor(state string) (string, error) {
|
||||
dir, err := manualStateDirFn()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return "", errEmptyManualAuthState
|
||||
}
|
||||
|
||||
return filepath.Join(dir, manualStateFilePrefix+state+manualStateFileSuffix), nil
|
||||
}
|
||||
|
||||
func isManualStateFilename(name string) (state string, ok bool) {
|
||||
if !strings.HasPrefix(name, manualStateFilePrefix) || !strings.HasSuffix(name, manualStateFileSuffix) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
state = strings.TrimSuffix(strings.TrimPrefix(name, manualStateFilePrefix), manualStateFileSuffix)
|
||||
if strings.TrimSpace(state) == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return state, true
|
||||
}
|
||||
|
||||
func loadManualState(client string, scopes []string, forceConsent bool) (string, bool, error) {
|
||||
dir, err := manualStateDirFn()
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("read manual auth state dir: %w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
bestState string
|
||||
bestCreated time.Time
|
||||
)
|
||||
|
||||
for _, ent := range entries {
|
||||
if ent.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
state, ok := isManualStateFilename(ent.Name())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, ent.Name())
|
||||
|
||||
st, valid, loadErr := loadManualStateByPath(path)
|
||||
if loadErr != nil {
|
||||
return "", false, loadErr
|
||||
}
|
||||
|
||||
if !valid {
|
||||
continue
|
||||
}
|
||||
|
||||
if st.Client != client || st.ForceConsent != forceConsent || !scopesEqual(st.Scopes, scopes) {
|
||||
continue
|
||||
}
|
||||
|
||||
if bestState == "" || st.CreatedAt.After(bestCreated) {
|
||||
bestState = state
|
||||
bestCreated = st.CreatedAt
|
||||
}
|
||||
}
|
||||
|
||||
if bestState == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
return bestState, true, nil
|
||||
}
|
||||
|
||||
func loadManualStateByPath(path string) (manualState, bool, error) {
|
||||
data, err := os.ReadFile(path) //nolint:gosec // config path
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return manualState{}, false, nil
|
||||
}
|
||||
|
||||
return manualState{}, false, fmt.Errorf("read manual auth state: %w", err)
|
||||
}
|
||||
|
||||
var st manualState
|
||||
if err := json.Unmarshal(data, &st); err != nil {
|
||||
_ = os.Remove(path)
|
||||
return manualState{}, false, nil //nolint:nilerr // invalid state should be treated as a cache miss
|
||||
}
|
||||
|
||||
if st.State == "" {
|
||||
_ = os.Remove(path)
|
||||
return manualState{}, false, nil
|
||||
}
|
||||
|
||||
if manualStateNowFn().Sub(st.CreatedAt) > manualStateTTL {
|
||||
_ = os.Remove(path)
|
||||
return manualState{}, false, nil
|
||||
}
|
||||
|
||||
return st, true, nil
|
||||
}
|
||||
|
||||
func saveManualState(client string, scopes []string, forceConsent bool, state string) error {
|
||||
path, err := manualStatePathFor(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
st := manualState{
|
||||
State: state,
|
||||
Client: client,
|
||||
Scopes: normalizeScopes(scopes),
|
||||
ForceConsent: forceConsent,
|
||||
CreatedAt: manualStateNowFn().UTC(),
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(st, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode manual auth state: %w", err)
|
||||
}
|
||||
|
||||
data = append(data, '\n')
|
||||
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o600); err != nil {
|
||||
return fmt.Errorf("write manual auth state: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
return fmt.Errorf("commit manual auth state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func clearManualState(state string) error {
|
||||
path, err := manualStatePathFor(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("remove manual auth state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeScopes(scopes []string) []string {
|
||||
if len(scopes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := append([]string(nil), scopes...)
|
||||
sort.Strings(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func scopesEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
na := normalizeScopes(a)
|
||||
nb := normalizeScopes(b)
|
||||
|
||||
for i := range na {
|
||||
if na[i] != nb[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
79
internal/googleauth/manual_state_test.go
Normal file
79
internal/googleauth/manual_state_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package googleauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/steipete/gogcli/internal/config"
|
||||
)
|
||||
|
||||
func TestManualAuthURL_ReusesState(t *testing.T) {
|
||||
origRead := readClientCredentials
|
||||
origEndpoint := oauthEndpoint
|
||||
origState := randomStateFn
|
||||
|
||||
t.Cleanup(func() {
|
||||
readClientCredentials = origRead
|
||||
oauthEndpoint = origEndpoint
|
||||
randomStateFn = origState
|
||||
})
|
||||
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
}
|
||||
oauthEndpoint = oauth2EndpointForTest("http://example.com")
|
||||
stateCalls := 0
|
||||
randomStateFn = func() (string, error) {
|
||||
stateCalls++
|
||||
if stateCalls == 1 {
|
||||
return "state1", nil
|
||||
}
|
||||
|
||||
return "state2", nil
|
||||
}
|
||||
|
||||
res1, err := ManualAuthURL(context.Background(), AuthorizeOptions{
|
||||
Scopes: []string{"s1"},
|
||||
Manual: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ManualAuthURL: %v", err)
|
||||
}
|
||||
|
||||
res2, err := ManualAuthURL(context.Background(), AuthorizeOptions{
|
||||
Scopes: []string{"s1"},
|
||||
Manual: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ManualAuthURL second: %v", err)
|
||||
}
|
||||
|
||||
state1 := authURLState(t, res1.URL)
|
||||
|
||||
state2 := authURLState(t, res2.URL)
|
||||
if state1 != "state1" || state2 != "state1" {
|
||||
t.Fatalf("expected reused state, got state1=%q state2=%q", state1, state2)
|
||||
}
|
||||
|
||||
if !res2.StateReused {
|
||||
t.Fatalf("expected state_reused true on second call")
|
||||
}
|
||||
|
||||
if stateCalls != 1 {
|
||||
t.Fatalf("expected randomStateFn called once, got %d", stateCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func authURLState(t *testing.T, rawURL string) string {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
return parsed.Query().Get("state")
|
||||
}
|
||||
@ -29,6 +29,14 @@ type AuthorizeOptions struct {
|
||||
ForceConsent bool
|
||||
Timeout time.Duration
|
||||
Client string
|
||||
AuthCode string
|
||||
AuthURL string
|
||||
RequireState bool
|
||||
}
|
||||
|
||||
type ManualAuthURLResult struct {
|
||||
URL string
|
||||
StateReused bool
|
||||
}
|
||||
|
||||
// postSuccessDisplaySeconds is the number of seconds the success page remains
|
||||
@ -51,12 +59,18 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
errAuthorization = errors.New("authorization error")
|
||||
errMissingCode = errors.New("missing code")
|
||||
errMissingScopes = errors.New("missing scopes")
|
||||
errNoCodeInURL = errors.New("no code found in URL")
|
||||
errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent")
|
||||
errStateMismatch = errors.New("state mismatch")
|
||||
errAuthorization = errors.New("authorization error")
|
||||
errMissingCode = errors.New("missing code")
|
||||
errMissingState = errors.New("missing state in redirect URL")
|
||||
errMissingScopes = errors.New("missing scopes")
|
||||
errNoCodeInURL = errors.New("no code found in URL")
|
||||
errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent")
|
||||
errManualStateMissing = errors.New("manual auth state missing; run remote step 1 again")
|
||||
errManualStateMismatch = errors.New("manual auth state mismatch; run remote step 1 again")
|
||||
errStateMismatch = errors.New("state mismatch")
|
||||
|
||||
errInvalidAuthorizeOptionsAuthURLAndCode = errors.New("cannot combine auth-url with auth-code")
|
||||
errInvalidAuthorizeOptionsAuthCodeWithState = errors.New("auth-code is not valid when state is required; provide auth-url")
|
||||
)
|
||||
|
||||
func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
|
||||
@ -64,19 +78,19 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
|
||||
opts.Timeout = 2 * time.Minute
|
||||
}
|
||||
|
||||
if strings.TrimSpace(opts.AuthURL) != "" && strings.TrimSpace(opts.AuthCode) != "" {
|
||||
return "", errInvalidAuthorizeOptionsAuthURLAndCode
|
||||
}
|
||||
|
||||
if opts.RequireState && strings.TrimSpace(opts.AuthCode) != "" {
|
||||
return "", errInvalidAuthorizeOptionsAuthCodeWithState
|
||||
}
|
||||
|
||||
if len(opts.Scopes) == 0 {
|
||||
return "", errMissingScopes
|
||||
}
|
||||
|
||||
var creds config.ClientCredentials
|
||||
|
||||
if c, err := readClientCredentials(opts.Client); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
creds = c
|
||||
}
|
||||
|
||||
state, err := randomStateFn()
|
||||
creds, err := readClientCredentials(opts.Client)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -85,55 +99,173 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
|
||||
defer cancel()
|
||||
|
||||
if opts.Manual {
|
||||
redirectURI := "http://localhost:1"
|
||||
cfg := oauth2.Config{
|
||||
ClientID: creds.ClientID,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
Endpoint: oauthEndpoint,
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
authURL := cfg.AuthCodeURL(state, authURLParams(opts.ForceConsent)...)
|
||||
return authorizeManual(ctx, opts, creds)
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "Visit this URL to authorize:")
|
||||
fmt.Fprintln(os.Stderr, authURL)
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "After authorizing, you'll be redirected to a localhost URL that won't load.")
|
||||
fmt.Fprintln(os.Stderr, "Copy the URL from your browser's address bar and paste it here.")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
return authorizeServer(ctx, opts, creds)
|
||||
}
|
||||
|
||||
line, readErr := input.PromptLine(ctx, "Paste redirect URL (Enter or Ctrl-D): ")
|
||||
if readErr != nil && !errors.Is(readErr, os.ErrClosed) {
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
return "", fmt.Errorf("authorization canceled: %w", context.Canceled)
|
||||
}
|
||||
func authorizeManual(ctx context.Context, opts AuthorizeOptions, creds config.ClientCredentials) (string, error) {
|
||||
authURLInput := strings.TrimSpace(opts.AuthURL)
|
||||
authCodeInput := strings.TrimSpace(opts.AuthCode)
|
||||
|
||||
return "", fmt.Errorf("read redirect url: %w", readErr)
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
redirectURI := "http://localhost:1"
|
||||
cfg := oauth2.Config{
|
||||
ClientID: creds.ClientID,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
Endpoint: oauthEndpoint,
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
|
||||
code, gotState, parseErr := extractCodeAndState(line)
|
||||
if authURLInput != "" || authCodeInput != "" {
|
||||
return authorizeManualWithCode(ctx, opts, cfg, authURLInput, authCodeInput)
|
||||
}
|
||||
|
||||
return authorizeManualInteractive(ctx, opts, cfg)
|
||||
}
|
||||
|
||||
func authorizeManualWithCode(
|
||||
ctx context.Context,
|
||||
opts AuthorizeOptions,
|
||||
cfg oauth2.Config,
|
||||
authURLInput string,
|
||||
authCodeInput string,
|
||||
) (string, error) {
|
||||
code := strings.TrimSpace(authCodeInput)
|
||||
gotState := ""
|
||||
|
||||
if authURLInput != "" {
|
||||
parsedCode, parsedState, parseErr := extractCodeAndState(authURLInput)
|
||||
if parseErr != nil {
|
||||
return "", parseErr
|
||||
}
|
||||
|
||||
if gotState != "" && gotState != state {
|
||||
return "", errStateMismatch
|
||||
code = parsedCode
|
||||
gotState = parsedState
|
||||
|
||||
if opts.RequireState && gotState == "" {
|
||||
return "", errMissingState
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(code) == "" {
|
||||
return "", errMissingCode
|
||||
}
|
||||
|
||||
if gotState != "" {
|
||||
if err := validateManualState(opts, gotState); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
tok, exchangeErr := cfg.Exchange(ctx, code)
|
||||
if exchangeErr != nil {
|
||||
return "", fmt.Errorf("exchange code: %w", exchangeErr)
|
||||
}
|
||||
|
||||
if tok.RefreshToken == "" {
|
||||
return "", errNoRefreshToken
|
||||
}
|
||||
|
||||
if gotState != "" {
|
||||
_ = clearManualState(gotState)
|
||||
}
|
||||
|
||||
return tok.RefreshToken, nil
|
||||
}
|
||||
|
||||
func authorizeManualInteractive(ctx context.Context, opts AuthorizeOptions, cfg oauth2.Config) (string, error) {
|
||||
setup, err := manualAuthSetup(opts, cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "Visit this URL to authorize:")
|
||||
fmt.Fprintln(os.Stderr, setup.authURL)
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "After authorizing, you'll be redirected to a localhost URL that won't load.")
|
||||
fmt.Fprintln(os.Stderr, "Copy the URL from your browser's address bar and paste it here.")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
|
||||
line, readErr := input.PromptLine(ctx, "Paste redirect URL (Enter or Ctrl-D): ")
|
||||
if readErr != nil && !errors.Is(readErr, os.ErrClosed) {
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
return "", fmt.Errorf("authorization canceled: %w", context.Canceled)
|
||||
}
|
||||
|
||||
var tok *oauth2.Token
|
||||
return "", fmt.Errorf("read redirect url: %w", readErr)
|
||||
}
|
||||
|
||||
if t, exchangeErr := cfg.Exchange(ctx, code); exchangeErr != nil {
|
||||
return "", fmt.Errorf("exchange code: %w", exchangeErr)
|
||||
} else {
|
||||
tok = t
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
code, gotState, parseErr := extractCodeAndState(line)
|
||||
if parseErr != nil {
|
||||
return "", parseErr
|
||||
}
|
||||
|
||||
if gotState != "" && gotState != setup.state {
|
||||
return "", errStateMismatch
|
||||
}
|
||||
|
||||
tok, exchangeErr := cfg.Exchange(ctx, code)
|
||||
if exchangeErr != nil {
|
||||
return "", fmt.Errorf("exchange code: %w", exchangeErr)
|
||||
}
|
||||
|
||||
if tok.RefreshToken == "" {
|
||||
return "", errNoRefreshToken
|
||||
}
|
||||
|
||||
_ = clearManualState(setup.state)
|
||||
|
||||
return tok.RefreshToken, nil
|
||||
}
|
||||
|
||||
func validateManualState(opts AuthorizeOptions, gotState string) error {
|
||||
if opts.RequireState {
|
||||
if gotState == "" {
|
||||
return errMissingState
|
||||
}
|
||||
}
|
||||
|
||||
if gotState == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
path, err := manualStatePathFor(gotState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
st, ok, err := loadManualStateByPath(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if opts.RequireState {
|
||||
return errManualStateMissing
|
||||
}
|
||||
|
||||
if tok.RefreshToken == "" {
|
||||
return "", errNoRefreshToken
|
||||
return nil
|
||||
}
|
||||
|
||||
if st.Client != opts.Client || st.ForceConsent != opts.ForceConsent || !scopesEqual(st.Scopes, opts.Scopes) {
|
||||
if opts.RequireState {
|
||||
return errManualStateMismatch
|
||||
}
|
||||
|
||||
return tok.RefreshToken, nil
|
||||
return errStateMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func authorizeServer(ctx context.Context, opts AuthorizeOptions, creds config.ClientCredentials) (string, error) {
|
||||
state, err := randomStateFn()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
|
||||
@ -270,6 +402,70 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func ManualAuthURL(ctx context.Context, opts AuthorizeOptions) (ManualAuthURLResult, error) {
|
||||
if opts.Timeout <= 0 {
|
||||
opts.Timeout = 2 * time.Minute
|
||||
}
|
||||
|
||||
if len(opts.Scopes) == 0 {
|
||||
return ManualAuthURLResult{}, errMissingScopes
|
||||
}
|
||||
|
||||
creds, err := readClientCredentials(opts.Client)
|
||||
if err != nil {
|
||||
return ManualAuthURLResult{}, err
|
||||
}
|
||||
|
||||
redirectURI := "http://localhost:1"
|
||||
cfg := oauth2.Config{
|
||||
ClientID: creds.ClientID,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
Endpoint: oauthEndpoint,
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
|
||||
setup, err := manualAuthSetup(opts, cfg)
|
||||
if err != nil {
|
||||
return ManualAuthURLResult{}, err
|
||||
}
|
||||
|
||||
return ManualAuthURLResult{
|
||||
URL: setup.authURL,
|
||||
StateReused: setup.reused,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type manualAuthSetupResult struct {
|
||||
authURL string
|
||||
state string
|
||||
reused bool
|
||||
}
|
||||
|
||||
func manualAuthSetup(opts AuthorizeOptions, cfg oauth2.Config) (manualAuthSetupResult, error) {
|
||||
state, reused, err := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent)
|
||||
if err != nil {
|
||||
return manualAuthSetupResult{}, err
|
||||
}
|
||||
|
||||
if !reused {
|
||||
state, err = randomStateFn()
|
||||
if err != nil {
|
||||
return manualAuthSetupResult{}, err
|
||||
}
|
||||
|
||||
if err := saveManualState(opts.Client, opts.Scopes, opts.ForceConsent, state); err != nil {
|
||||
return manualAuthSetupResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return manualAuthSetupResult{
|
||||
authURL: cfg.AuthCodeURL(state, authURLParams(opts.ForceConsent)...),
|
||||
state: state,
|
||||
reused: reused,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authURLParams(forceConsent bool) []oauth2.AuthCodeOption {
|
||||
opts := []oauth2.AuthCodeOption{
|
||||
oauth2.AccessTypeOffline,
|
||||
|
||||
@ -54,6 +54,18 @@ func newTokenServer(t *testing.T) *httptest.Server {
|
||||
}))
|
||||
}
|
||||
|
||||
func useTempManualStatePath(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
origDir := manualStateDirFn
|
||||
dir := t.TempDir()
|
||||
manualStateDirFn = func() (string, error) { return dir, nil }
|
||||
|
||||
t.Cleanup(func() {
|
||||
manualStateDirFn = origDir
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthorize_MissingScopes(t *testing.T) {
|
||||
_, err := Authorize(context.Background(), AuthorizeOptions{})
|
||||
if err == nil || !strings.Contains(err.Error(), "missing scopes") {
|
||||
@ -72,6 +84,8 @@ func TestAuthorize_Manual_Success(t *testing.T) {
|
||||
randomStateFn = origState
|
||||
})
|
||||
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
}
|
||||
@ -122,6 +136,7 @@ func TestAuthorize_Manual_Success_NoNewline(t *testing.T) {
|
||||
oauthEndpoint = origEndpoint
|
||||
randomStateFn = origState
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
@ -173,6 +188,7 @@ func TestAuthorize_Manual_CancelEOF(t *testing.T) {
|
||||
oauthEndpoint = origEndpoint
|
||||
randomStateFn = origState
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
@ -218,6 +234,7 @@ func TestAuthorize_Manual_StateMismatch(t *testing.T) {
|
||||
oauthEndpoint = origEndpoint
|
||||
randomStateFn = origState
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
@ -254,6 +271,150 @@ func TestAuthorize_Manual_StateMismatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize_Manual_AuthCode(t *testing.T) {
|
||||
origRead := readClientCredentials
|
||||
origEndpoint := oauthEndpoint
|
||||
origState := randomStateFn
|
||||
|
||||
t.Cleanup(func() {
|
||||
readClientCredentials = origRead
|
||||
oauthEndpoint = origEndpoint
|
||||
randomStateFn = origState
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
}
|
||||
stateCalled := false
|
||||
randomStateFn = func() (string, error) {
|
||||
stateCalled = true
|
||||
return "state123", nil
|
||||
}
|
||||
|
||||
tokenSrv := newTokenServer(t)
|
||||
defer tokenSrv.Close()
|
||||
oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL)
|
||||
|
||||
rt, err := Authorize(context.Background(), AuthorizeOptions{
|
||||
Scopes: []string{"s1"},
|
||||
Manual: true,
|
||||
AuthCode: "abc",
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Authorize: %v", err)
|
||||
}
|
||||
|
||||
if rt != "rt" {
|
||||
t.Fatalf("unexpected refresh token: %q", rt)
|
||||
}
|
||||
|
||||
if stateCalled {
|
||||
t.Fatalf("unexpected state generation in auth-code flow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize_Manual_AuthURL_RequireStateMissing(t *testing.T) {
|
||||
origRead := readClientCredentials
|
||||
origEndpoint := oauthEndpoint
|
||||
|
||||
t.Cleanup(func() {
|
||||
readClientCredentials = origRead
|
||||
oauthEndpoint = origEndpoint
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
}
|
||||
oauthEndpoint = oauth2EndpointForTest("http://example.com")
|
||||
|
||||
_, err := Authorize(context.Background(), AuthorizeOptions{
|
||||
Scopes: []string{"s1"},
|
||||
Manual: true,
|
||||
AuthURL: "http://localhost:1/?code=abc",
|
||||
RequireState: true,
|
||||
Client: "default",
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
|
||||
if !errors.Is(err, errMissingState) {
|
||||
t.Fatalf("expected missing state error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize_Manual_AuthURL_RequireStateMissingCache(t *testing.T) {
|
||||
origRead := readClientCredentials
|
||||
origEndpoint := oauthEndpoint
|
||||
|
||||
t.Cleanup(func() {
|
||||
readClientCredentials = origRead
|
||||
oauthEndpoint = origEndpoint
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
}
|
||||
oauthEndpoint = oauth2EndpointForTest("http://example.com")
|
||||
|
||||
_, err := Authorize(context.Background(), AuthorizeOptions{
|
||||
Scopes: []string{"s1"},
|
||||
Manual: true,
|
||||
AuthURL: "http://localhost:1/?code=abc&state=state123",
|
||||
RequireState: true,
|
||||
Client: "default",
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
|
||||
if !errors.Is(err, errManualStateMissing) {
|
||||
t.Fatalf("expected manual state missing error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize_Manual_AuthURL_RequireStateMissingForDifferentState(t *testing.T) {
|
||||
origRead := readClientCredentials
|
||||
origEndpoint := oauthEndpoint
|
||||
|
||||
t.Cleanup(func() {
|
||||
readClientCredentials = origRead
|
||||
oauthEndpoint = origEndpoint
|
||||
})
|
||||
useTempManualStatePath(t)
|
||||
|
||||
readClientCredentials = func(string) (config.ClientCredentials, error) {
|
||||
return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil
|
||||
}
|
||||
oauthEndpoint = oauth2EndpointForTest("http://example.com")
|
||||
|
||||
if err := saveManualState("default", []string{"s1"}, false, "state123"); err != nil {
|
||||
t.Fatalf("save manual state: %v", err)
|
||||
}
|
||||
|
||||
_, err := Authorize(context.Background(), AuthorizeOptions{
|
||||
Scopes: []string{"s1"},
|
||||
Manual: true,
|
||||
AuthURL: "http://localhost:1/?code=abc&state=DIFFERENT",
|
||||
RequireState: true,
|
||||
Client: "default",
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
|
||||
if !errors.Is(err, errManualStateMissing) {
|
||||
t.Fatalf("expected manual state missing error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize_ServerFlow_Success(t *testing.T) {
|
||||
origRead := readClientCredentials
|
||||
origEndpoint := oauthEndpoint
|
||||
|
||||
Loading…
Reference in New Issue
Block a user