diff --git a/CHANGELOG.md b/CHANGELOG.md index 6228a38..3aa21f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ - Gmail: add `--exclude-labels` to `watch serve` (defaults: `SPAM,TRASH`). (#194) — thanks @salmonumbrella. - Drive: share files with an entire Workspace domain via `drive share --to domain`. (#192) — thanks @Danielkweber. + +### Fixed + +- Auth: improve remote/server-friendly manual OAuth flow (`auth add --remote`). (#187) — thanks @salmonumbrella. + ## 0.9.0 - 2026-01-22 ### Highlights diff --git a/README.md b/README.md index 40d01d1..d06683b 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,21 @@ gog auth add you@gmail.com This will open a browser window for OAuth authorization. The refresh token is stored securely in your system keychain. +Headless / remote server flow (no browser on the server): + +```bash +# Step 1: print auth URL (open it locally in a browser) +gog auth add you@gmail.com --services user --remote --step 1 + +# Step 2: paste the full redirect URL from your browser address bar +gog auth add you@gmail.com --services user --remote --step 2 --auth-url 'http://localhost:1/?code=...&state=...' +``` + +Notes: + +- The `state` is cached on disk for a short time (about 10 minutes). If it expires, rerun step 1. +- Remote step 2 requires a redirect URL that includes `state` (state check mandatory). + ### 4. Test Authentication ```bash diff --git a/docs/spec.md b/docs/spec.md index 20bf058..fa2e7a8 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -99,6 +99,9 @@ Implementation: `internal/secrets/store.go`. - Desktop OAuth 2.0 flow using local HTTP redirect on an ephemeral port. - Supports a browserless/manual flow (paste redirect URL) for headless environments. +- Supports a remote/server-friendly 2-step manual flow: + - Step 1 prints an auth URL (`gog auth add ... --remote --step 1`) + - Step 2 exchanges the pasted redirect URL and requires `state` validation (`--remote --step 2 --auth-url ...`) - Refresh token issuance: - requests `access_type=offline` - supports `--force-consent` to force the consent prompt when Google doesn't return a refresh token @@ -119,6 +122,7 @@ Scope selection note: - `credentials-.json` (OAuth client id/secret; named clients) - State: - `state/gmail-watch/.json` (Gmail watch state) + - `oauth-manual-state-.json` (temporary manual OAuth state cache; expires quickly; no tokens) - Secrets: - refresh tokens in keyring @@ -148,7 +152,7 @@ Flag aliases: - `gog auth credentials ` - `gog auth credentials list` - `gog --client auth credentials ` -- `gog auth add [--services user|all|gmail,calendar,classroom,drive,docs,contacts,tasks,sheets,people,groups] [--readonly] [--drive-scope full|readonly|file] [--manual] [--force-consent]` +- `gog auth add [--services user|all|gmail,calendar,classroom,drive,docs,contacts,tasks,sheets,people,groups] [--readonly] [--drive-scope full|readonly|file] [--manual] [--remote] [--step 1|2] [--auth-url URL] [--timeout DURATION] [--force-consent]` - `gog auth services [--markdown]` - `gog auth keep --key ` (Google Keep; Workspace only) - `gog auth list` diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index a615c20..07dbc98 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -26,6 +26,7 @@ var ( checkRefreshToken = googleauth.CheckRefreshToken ensureKeychainAccess = secrets.EnsureKeychainAccess fetchAuthorizedEmail = googleauth.EmailForRefreshToken + manualAuthURL = googleauth.ManualAuthURL ) func ensureKeychainAccessIfNeeded() error { @@ -479,12 +480,17 @@ func (c *AuthTokensImportCmd) Run(ctx context.Context) error { } type AuthAddCmd struct { - Email string `arg:"" name:"email" help:"Email"` - Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"` - ForceConsent bool `name:"force-consent" help:"Force consent screen to obtain a refresh token"` - ServicesCSV string `name:"services" help:"Services to authorize: user|all or comma-separated ${auth_services} (Keep uses service account: gog auth service-account set)" default:"user"` - Readonly bool `name:"readonly" help:"Use read-only scopes where available (still includes OIDC identity scopes)"` - DriveScope string `name:"drive-scope" help:"Drive scope mode: full|readonly|file" enum:"full,readonly,file" default:"full"` + Email string `arg:"" name:"email" help:"Email"` + Manual bool `name:"manual" help:"Browserless auth flow (paste redirect URL)"` + Remote bool `name:"remote" help:"Remote/server-friendly manual flow (print URL, then exchange code)"` + Step int `name:"step" help:"Remote auth step: 1=print URL, 2=exchange code"` + AuthURL string `name:"auth-url" help:"Redirect URL from browser (manual flow; required for --remote --step 2)"` + AuthCode string `name:"auth-code" hidden:"" help:"UNSAFE: Authorization code from browser (manual flow; skips state check; not valid with --remote)"` + Timeout time.Duration `name:"timeout" help:"Authorization timeout (manual flows default to 5m)"` + ForceConsent bool `name:"force-consent" help:"Force consent screen to obtain a refresh token"` + ServicesCSV string `name:"services" help:"Services to authorize: user|all or comma-separated ${auth_services} (Keep uses service account: gog auth service-account set)" default:"user"` + Readonly bool `name:"readonly" help:"Use read-only scopes where available (still includes OIDC identity scopes)"` + DriveScope string `name:"drive-scope" help:"Drive scope mode: full|readonly|file" enum:"full,readonly,file" default:"full"` } func (c *AuthAddCmd) Run(ctx context.Context) error { @@ -515,6 +521,69 @@ func (c *AuthAddCmd) Run(ctx context.Context) error { return err } + authURL := strings.TrimSpace(c.AuthURL) + authCode := strings.TrimSpace(c.AuthCode) + if authURL != "" && authCode != "" { + return usage("cannot combine --auth-url with --auth-code") + } + if c.Step != 0 && c.Step != 1 && c.Step != 2 { + return usage("step must be 1 or 2") + } + if c.Step != 0 && !c.Remote { + return usage("--step requires --remote") + } + + manual := c.Manual || c.Remote || authURL != "" || authCode != "" + + if c.Remote { + step := c.Step + if step == 0 { + if authURL != "" || authCode != "" { + step = 2 + } else { + step = 1 + } + } + switch step { + case 1: + if authURL != "" || authCode != "" { + return usage("remote step 1 does not accept --auth-url or --auth-code") + } + result, manualErr := manualAuthURL(ctx, googleauth.AuthorizeOptions{ + Services: services, + Scopes: scopes, + Manual: true, + ForceConsent: c.ForceConsent, + Client: client, + }) + if manualErr != nil { + return manualErr + } + if outfmt.IsJSON(ctx) { + return outfmt.WriteJSON(os.Stdout, map[string]any{ + "auth_url": result.URL, + "state_reused": result.StateReused, + }) + } + u.Out().Printf("auth_url\t%s", result.URL) + u.Out().Printf("state_reused\t%t", result.StateReused) + u.Err().Println("Run again with --remote --step 2 --auth-url ") + return nil + case 2: + if authCode != "" { + return usage("--auth-code is not valid with --remote (state check is mandatory)") + } + if authURL == "" { + return usage("remote step 2 requires --auth-url") + } + } + } + + timeout := c.Timeout + if timeout == 0 && manual { + timeout = 5 * time.Minute + } + // Pre-flight: ensure keychain is accessible before starting OAuth if keychainErr := ensureKeychainAccessIfNeeded(); keychainErr != nil { return fmt.Errorf("keychain access: %w", keychainErr) @@ -523,9 +592,13 @@ func (c *AuthAddCmd) Run(ctx context.Context) error { refreshToken, err := authorizeGoogle(ctx, googleauth.AuthorizeOptions{ Services: services, Scopes: scopes, - Manual: c.Manual, + Manual: manual, ForceConsent: c.ForceConsent, + Timeout: timeout, Client: client, + AuthURL: authURL, + AuthCode: authCode, + RequireState: c.Remote, }) if err != nil { return err diff --git a/internal/cmd/auth_add_test.go b/internal/cmd/auth_add_test.go index 88fb2aa..4febde5 100644 --- a/internal/cmd/auth_add_test.go +++ b/internal/cmd/auth_add_test.go @@ -473,6 +473,180 @@ func TestAuthAddCmd_SheetsDriveScopeFile(t *testing.T) { } } +func TestAuthAddCmd_RemoteStep1_PrintsAuthURL(t *testing.T) { + origManualURL := manualAuthURL + origAuth := authorizeGoogle + origKeychain := ensureKeychainAccess + t.Cleanup(func() { + manualAuthURL = origManualURL + authorizeGoogle = origAuth + ensureKeychainAccess = origKeychain + }) + + manualCalled := false + manualAuthURL = func(context.Context, googleauth.AuthorizeOptions) (googleauth.ManualAuthURLResult, error) { + manualCalled = true + return googleauth.ManualAuthURLResult{ + URL: "https://example.com/auth", + StateReused: true, + }, nil + } + authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) { + t.Fatal("authorizeGoogle should not be called in remote step 1") + return "", nil + } + ensureKeychainAccess = func() error { + t.Fatal("keychain access should not be checked in remote step 1") + return nil + } + + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if err := Execute([]string{ + "auth", + "add", + "user@example.com", + "--services", + "gmail", + "--remote", + "--step", + "1", + }); err != nil { + t.Fatalf("Execute: %v", err) + } + }) + }) + + if !manualCalled { + t.Fatalf("expected manualAuthURL to be called") + } + if !strings.Contains(out, "auth_url\thttps://example.com/auth") { + t.Fatalf("unexpected output: %q", out) + } + if !strings.Contains(out, "state_reused\ttrue") { + t.Fatalf("expected state_reused output, got: %q", out) + } +} + +func TestAuthAddCmd_RemoteStep2_RejectsAuthCode(t *testing.T) { + err := Execute([]string{ + "auth", + "add", + "user@example.com", + "--services", + "gmail", + "--remote", + "--step", + "2", + "--auth-code", + "abc123", + }) + if err == nil { + t.Fatalf("expected error") + } + var ee *ExitError + if !errors.As(err, &ee) || ee.Code != 2 { + t.Fatalf("expected exit code 2, got %T %#v", err, err) + } + if !strings.Contains(err.Error(), "--auth-code is not valid with --remote") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestAuthAddCmd_RemoteStep2_PassesAuthURL(t *testing.T) { + origAuth := authorizeGoogle + origOpen := openSecretsStore + origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail + t.Cleanup(func() { + authorizeGoogle = origAuth + openSecretsStore = origOpen + ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch + }) + + ensureKeychainAccess = func() error { return nil } + openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil } + + var gotOpts googleauth.AuthorizeOptions + authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) { + gotOpts = opts + return "rt", nil + } + fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) { + return "user@example.com", nil + } + + if err := Execute([]string{ + "auth", + "add", + "user@example.com", + "--services", + "gmail", + "--remote", + "--step", + "2", + "--auth-url", + "http://localhost:1/?code=abc&state=state123", + }); err != nil { + t.Fatalf("Execute: %v", err) + } + + if !gotOpts.Manual { + t.Fatalf("expected manual auth in remote step 2") + } + if !gotOpts.RequireState { + t.Fatalf("expected require state in remote step 2") + } + if gotOpts.AuthURL == "" { + t.Fatalf("expected auth URL to be passed through") + } +} + +func TestAuthAddCmd_AuthCode_PassesThrough(t *testing.T) { + origAuth := authorizeGoogle + origOpen := openSecretsStore + origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail + t.Cleanup(func() { + authorizeGoogle = origAuth + openSecretsStore = origOpen + ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch + }) + + ensureKeychainAccess = func() error { return nil } + openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil } + + var gotOpts googleauth.AuthorizeOptions + authorizeGoogle = func(ctx context.Context, opts googleauth.AuthorizeOptions) (string, error) { + gotOpts = opts + return "rt", nil + } + fetchAuthorizedEmail = func(context.Context, string, string, []string, time.Duration) (string, error) { + return "user@example.com", nil + } + + if err := Execute([]string{ + "auth", + "add", + "user@example.com", + "--services", + "gmail", + "--auth-code", + "abc123", + }); err != nil { + t.Fatalf("Execute: %v", err) + } + + if !gotOpts.Manual { + t.Fatalf("expected manual auth when auth-code is provided") + } + if gotOpts.AuthCode != "abc123" { + t.Fatalf("expected auth-code to be passed through, got %q", gotOpts.AuthCode) + } +} + func containsStringInSlice(items []string, want string) bool { for _, it := range items { if it == want { diff --git a/internal/cmd/gmail_send.go b/internal/cmd/gmail_send.go index 053ba1a..4b39386 100644 --- a/internal/cmd/gmail_send.go +++ b/internal/cmd/gmail_send.go @@ -119,15 +119,21 @@ func (c *GmailSendCmd) Run(ctx context.Context, flags *RootFlags) error { sendingEmail = c.From fromAddr = c.From // Include display name if set - if sa.DisplayName != "" { - fromAddr = sa.DisplayName + " <" + c.From + ">" + displayName := strings.TrimSpace(sa.DisplayName) + if displayName == "" { + if fallback, listErr := sendAsDisplayNameFromList(ctx, svc, c.From); listErr == nil { + displayName = fallback + } + } + if displayName != "" { + fromAddr = displayName + " <" + c.From + ">" } } else { // No --from specified: look up the primary account's send-as settings // to get the display name - sa, saErr := svc.Users.Settings.SendAs.Get("me", account).Context(ctx).Do() - if saErr == nil && sa.DisplayName != "" { - fromAddr = sa.DisplayName + " <" + account + ">" + displayName := primarySendAsDisplayName(ctx, svc, account) + if displayName != "" { + fromAddr = displayName + " <" + account + ">" } // If lookup fails, we just use the plain email address (no error) } @@ -217,6 +223,78 @@ func (c *GmailSendCmd) resolveTrackingConfig(account string, toRecipients, ccRec return trackingCfg, nil } +func primarySendAsDisplayName(ctx context.Context, svc *gmail.Service, account string) string { + account = strings.TrimSpace(account) + if account == "" || svc == nil { + return "" + } + + sa, err := svc.Users.Settings.SendAs.Get("me", account).Context(ctx).Do() + if err == nil { + if displayName := strings.TrimSpace(sa.DisplayName); displayName != "" { + return displayName + } + } + + displayName, err := primarySendAsDisplayNameFromList(ctx, svc, account) + if err != nil { + return "" + } + return displayName +} + +func sendAsDisplayNameFromList(ctx context.Context, svc *gmail.Service, email string) (string, error) { + email = strings.TrimSpace(email) + if email == "" || svc == nil { + return "", nil + } + + resp, err := svc.Users.Settings.SendAs.List("me").Context(ctx).Do() + if err != nil { + return "", err + } + + needle := strings.ToLower(email) + for _, sa := range resp.SendAs { + if strings.ToLower(strings.TrimSpace(sa.SendAsEmail)) == needle { + return strings.TrimSpace(sa.DisplayName), nil + } + } + + return "", nil +} + +func primarySendAsDisplayNameFromList(ctx context.Context, svc *gmail.Service, account string) (string, error) { + account = strings.TrimSpace(account) + if account == "" || svc == nil { + return "", nil + } + + resp, err := svc.Users.Settings.SendAs.List("me").Context(ctx).Do() + if err != nil { + return "", err + } + + needle := strings.ToLower(account) + var primary *gmail.SendAs + for _, sa := range resp.SendAs { + if sa.IsPrimary { + primary = sa + } + if strings.ToLower(strings.TrimSpace(sa.SendAsEmail)) == needle { + if displayName := strings.TrimSpace(sa.DisplayName); displayName != "" { + return displayName, nil + } + } + } + + if primary != nil { + return strings.TrimSpace(primary.DisplayName), nil + } + + return "", nil +} + func buildSendBatches(toRecipients, ccRecipients, bccRecipients []string, track, trackSplit bool) []sendBatch { totalRecipients := len(toRecipients) + len(ccRecipients) + len(bccRecipients) if track && trackSplit && totalRecipients > 1 { diff --git a/internal/cmd/gmail_send_test.go b/internal/cmd/gmail_send_test.go index 2f18ce9..e79620b 100644 --- a/internal/cmd/gmail_send_test.go +++ b/internal/cmd/gmail_send_test.go @@ -316,6 +316,82 @@ func TestGmailSendCmd_RunJSON_WithFrom(t *testing.T) { } } +func TestGmailSendCmd_RunJSON_WithFromDisplayNameFallbackToList(t *testing.T) { + origNew := newGmailService + t.Cleanup(func() { newGmailService = origNew }) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/gmail/v1") + switch { + case r.Method == http.MethodGet && path == "/users/me/settings/sendAs/alias@example.com": + // Return send-as settings with empty display name but valid verification. + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sendAsEmail": "alias@example.com", + "displayName": "", + "verificationStatus": "accepted", + }) + return + case r.Method == http.MethodGet && path == "/users/me/settings/sendAs": + // Fallback list endpoint returns the alias with a populated display name. + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sendAs": []map[string]any{ + { + "sendAsEmail": "alias@example.com", + "displayName": "Alias From List", + "verificationStatus": "accepted", + }, + }, + }) + return + case r.Method == http.MethodPost && path == "/users/me/messages/send": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m2b", + "threadId": "t2b", + }) + return + default: + http.NotFound(w, r) + return + } + })) + defer srv.Close() + + svc, err := gmail.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newGmailService = func(context.Context, string) (*gmail.Service, error) { return svc, nil } + + u, err := ui.New(ui.Options{Stdout: os.Stdout, Stderr: os.Stderr, Color: "never"}) + if err != nil { + t.Fatalf("ui.New: %v", err) + } + ctx := outfmt.WithMode(ui.WithUI(context.Background(), u), outfmt.Mode{JSON: true}) + + cmd := &GmailSendCmd{ + To: "a@example.com", + From: "alias@example.com", + Subject: "Hello", + Body: "Body", + } + + out := captureStdout(t, func() { + if err := cmd.Run(ctx, &RootFlags{Account: "a@b.com"}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + if !strings.Contains(out, "\"from\"") || !strings.Contains(out, "Alias From List ") { + t.Fatalf("expected from with display name from list fallback, got: %q", out) + } +} + func TestGmailSendCmd_RunJSON_PrimaryAccountDisplayName(t *testing.T) { origNew := newGmailService t.Cleanup(func() { newGmailService = origNew }) @@ -380,6 +456,81 @@ func TestGmailSendCmd_RunJSON_PrimaryAccountDisplayName(t *testing.T) { } } +func TestGmailSendCmd_RunJSON_PrimaryAccountDisplayNameFallbackToList(t *testing.T) { + origNew := newGmailService + t.Cleanup(func() { newGmailService = origNew }) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/gmail/v1") + switch { + case r.Method == http.MethodGet && path == "/users/me/settings/sendAs/a@b.com": + // Simulate missing display name in get response. + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sendAsEmail": "a@b.com", + "displayName": "", + "verificationStatus": "accepted", + }) + return + case r.Method == http.MethodGet && path == "/users/me/settings/sendAs": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sendAs": []map[string]any{ + { + "sendAsEmail": "a@b.com", + "displayName": "Primary User", + "verificationStatus": "accepted", + "isPrimary": true, + }, + }, + }) + return + case r.Method == http.MethodPost && path == "/users/me/messages/send": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m3b", + "threadId": "t3b", + }) + return + default: + http.NotFound(w, r) + return + } + })) + defer srv.Close() + + svc, err := gmail.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newGmailService = func(context.Context, string) (*gmail.Service, error) { return svc, nil } + + u, err := ui.New(ui.Options{Stdout: os.Stdout, Stderr: os.Stderr, Color: "never"}) + if err != nil { + t.Fatalf("ui.New: %v", err) + } + ctx := outfmt.WithMode(ui.WithUI(context.Background(), u), outfmt.Mode{JSON: true}) + + cmd := &GmailSendCmd{ + To: "recipient@example.com", + Subject: "Hello", + Body: "Body", + } + + out := captureStdout(t, func() { + if err := cmd.Run(ctx, &RootFlags{Account: "a@b.com"}); err != nil { + t.Fatalf("Run: %v", err) + } + }) + if !strings.Contains(out, "\"from\"") || !strings.Contains(out, "Primary User ") { + t.Fatalf("expected from with display name, got: %q", out) + } +} + func TestGmailSendCmd_RunJSON_PrimaryAccountNoDisplayName(t *testing.T) { origNew := newGmailService t.Cleanup(func() { newGmailService = origNew }) diff --git a/internal/googleauth/manual_state.go b/internal/googleauth/manual_state.go new file mode 100644 index 0000000..6a5102e --- /dev/null +++ b/internal/googleauth/manual_state.go @@ -0,0 +1,234 @@ +package googleauth + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/steipete/gogcli/internal/config" +) + +const ( + manualStateFilePrefix = "oauth-manual-state-" + manualStateFileSuffix = ".json" +) + +var errEmptyManualAuthState = errors.New("empty manual auth state") + +// manualStateTTL controls how long a stored manual auth state is considered valid. +// This should be shorter than typical OAuth code expiration windows. +const manualStateTTL = 10 * time.Minute + +type manualState struct { + State string `json:"state"` + Client string `json:"client"` + Scopes []string `json:"scopes"` + ForceConsent bool `json:"force_consent,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +var ( + manualStateDirFn = manualStateDir + manualStateNowFn = time.Now +) + +func manualStateDir() (string, error) { + dir, err := config.EnsureDir() + if err != nil { + return "", fmt.Errorf("ensure config dir: %w", err) + } + + return dir, nil +} + +func manualStatePathFor(state string) (string, error) { + dir, err := manualStateDirFn() + if err != nil { + return "", err + } + + state = strings.TrimSpace(state) + if state == "" { + return "", errEmptyManualAuthState + } + + return filepath.Join(dir, manualStateFilePrefix+state+manualStateFileSuffix), nil +} + +func isManualStateFilename(name string) (state string, ok bool) { + if !strings.HasPrefix(name, manualStateFilePrefix) || !strings.HasSuffix(name, manualStateFileSuffix) { + return "", false + } + + state = strings.TrimSuffix(strings.TrimPrefix(name, manualStateFilePrefix), manualStateFileSuffix) + if strings.TrimSpace(state) == "" { + return "", false + } + + return state, true +} + +func loadManualState(client string, scopes []string, forceConsent bool) (string, bool, error) { + dir, err := manualStateDirFn() + if err != nil { + return "", false, err + } + + entries, err := os.ReadDir(dir) + if err != nil { + return "", false, fmt.Errorf("read manual auth state dir: %w", err) + } + + var ( + bestState string + bestCreated time.Time + ) + + for _, ent := range entries { + if ent.IsDir() { + continue + } + + state, ok := isManualStateFilename(ent.Name()) + if !ok { + continue + } + + path := filepath.Join(dir, ent.Name()) + + st, valid, loadErr := loadManualStateByPath(path) + if loadErr != nil { + return "", false, loadErr + } + + if !valid { + continue + } + + if st.Client != client || st.ForceConsent != forceConsent || !scopesEqual(st.Scopes, scopes) { + continue + } + + if bestState == "" || st.CreatedAt.After(bestCreated) { + bestState = state + bestCreated = st.CreatedAt + } + } + + if bestState == "" { + return "", false, nil + } + + return bestState, true, nil +} + +func loadManualStateByPath(path string) (manualState, bool, error) { + data, err := os.ReadFile(path) //nolint:gosec // config path + if err != nil { + if os.IsNotExist(err) { + return manualState{}, false, nil + } + + return manualState{}, false, fmt.Errorf("read manual auth state: %w", err) + } + + var st manualState + if err := json.Unmarshal(data, &st); err != nil { + _ = os.Remove(path) + return manualState{}, false, nil //nolint:nilerr // invalid state should be treated as a cache miss + } + + if st.State == "" { + _ = os.Remove(path) + return manualState{}, false, nil + } + + if manualStateNowFn().Sub(st.CreatedAt) > manualStateTTL { + _ = os.Remove(path) + return manualState{}, false, nil + } + + return st, true, nil +} + +func saveManualState(client string, scopes []string, forceConsent bool, state string) error { + path, err := manualStatePathFor(state) + if err != nil { + return err + } + + st := manualState{ + State: state, + Client: client, + Scopes: normalizeScopes(scopes), + ForceConsent: forceConsent, + CreatedAt: manualStateNowFn().UTC(), + } + + data, err := json.MarshalIndent(st, "", " ") + if err != nil { + return fmt.Errorf("encode manual auth state: %w", err) + } + + data = append(data, '\n') + + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o600); err != nil { + return fmt.Errorf("write manual auth state: %w", err) + } + + if err := os.Rename(tmp, path); err != nil { + return fmt.Errorf("commit manual auth state: %w", err) + } + + return nil +} + +func clearManualState(state string) error { + path, err := manualStatePathFor(state) + if err != nil { + return err + } + + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + return nil + } + + return fmt.Errorf("remove manual auth state: %w", err) + } + + return nil +} + +func normalizeScopes(scopes []string) []string { + if len(scopes) == 0 { + return nil + } + + out := append([]string(nil), scopes...) + sort.Strings(out) + + return out +} + +func scopesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + na := normalizeScopes(a) + nb := normalizeScopes(b) + + for i := range na { + if na[i] != nb[i] { + return false + } + } + + return true +} diff --git a/internal/googleauth/manual_state_test.go b/internal/googleauth/manual_state_test.go new file mode 100644 index 0000000..c690de3 --- /dev/null +++ b/internal/googleauth/manual_state_test.go @@ -0,0 +1,79 @@ +package googleauth + +import ( + "context" + "net/url" + "testing" + + "github.com/steipete/gogcli/internal/config" +) + +func TestManualAuthURL_ReusesState(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + origState := randomStateFn + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + randomStateFn = origState + }) + + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + stateCalls := 0 + randomStateFn = func() (string, error) { + stateCalls++ + if stateCalls == 1 { + return "state1", nil + } + + return "state2", nil + } + + res1, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + }) + if err != nil { + t.Fatalf("ManualAuthURL: %v", err) + } + + res2, err := ManualAuthURL(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + }) + if err != nil { + t.Fatalf("ManualAuthURL second: %v", err) + } + + state1 := authURLState(t, res1.URL) + + state2 := authURLState(t, res2.URL) + if state1 != "state1" || state2 != "state1" { + t.Fatalf("expected reused state, got state1=%q state2=%q", state1, state2) + } + + if !res2.StateReused { + t.Fatalf("expected state_reused true on second call") + } + + if stateCalls != 1 { + t.Fatalf("expected randomStateFn called once, got %d", stateCalls) + } +} + +func authURLState(t *testing.T, rawURL string) string { + t.Helper() + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("parse auth URL: %v", err) + } + + return parsed.Query().Get("state") +} diff --git a/internal/googleauth/oauth_flow.go b/internal/googleauth/oauth_flow.go index 0cea0b3..910e366 100644 --- a/internal/googleauth/oauth_flow.go +++ b/internal/googleauth/oauth_flow.go @@ -29,6 +29,14 @@ type AuthorizeOptions struct { ForceConsent bool Timeout time.Duration Client string + AuthCode string + AuthURL string + RequireState bool +} + +type ManualAuthURLResult struct { + URL string + StateReused bool } // postSuccessDisplaySeconds is the number of seconds the success page remains @@ -51,12 +59,18 @@ var ( ) var ( - errAuthorization = errors.New("authorization error") - errMissingCode = errors.New("missing code") - errMissingScopes = errors.New("missing scopes") - errNoCodeInURL = errors.New("no code found in URL") - errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent") - errStateMismatch = errors.New("state mismatch") + errAuthorization = errors.New("authorization error") + errMissingCode = errors.New("missing code") + errMissingState = errors.New("missing state in redirect URL") + errMissingScopes = errors.New("missing scopes") + errNoCodeInURL = errors.New("no code found in URL") + errNoRefreshToken = errors.New("no refresh token received; try again with --force-consent") + errManualStateMissing = errors.New("manual auth state missing; run remote step 1 again") + errManualStateMismatch = errors.New("manual auth state mismatch; run remote step 1 again") + errStateMismatch = errors.New("state mismatch") + + errInvalidAuthorizeOptionsAuthURLAndCode = errors.New("cannot combine auth-url with auth-code") + errInvalidAuthorizeOptionsAuthCodeWithState = errors.New("auth-code is not valid when state is required; provide auth-url") ) func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) { @@ -64,19 +78,19 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) { opts.Timeout = 2 * time.Minute } + if strings.TrimSpace(opts.AuthURL) != "" && strings.TrimSpace(opts.AuthCode) != "" { + return "", errInvalidAuthorizeOptionsAuthURLAndCode + } + + if opts.RequireState && strings.TrimSpace(opts.AuthCode) != "" { + return "", errInvalidAuthorizeOptionsAuthCodeWithState + } + if len(opts.Scopes) == 0 { return "", errMissingScopes } - var creds config.ClientCredentials - - if c, err := readClientCredentials(opts.Client); err != nil { - return "", err - } else { - creds = c - } - - state, err := randomStateFn() + creds, err := readClientCredentials(opts.Client) if err != nil { return "", err } @@ -85,55 +99,173 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) { defer cancel() if opts.Manual { - redirectURI := "http://localhost:1" - cfg := oauth2.Config{ - ClientID: creds.ClientID, - ClientSecret: creds.ClientSecret, - Endpoint: oauthEndpoint, - RedirectURL: redirectURI, - Scopes: opts.Scopes, - } - authURL := cfg.AuthCodeURL(state, authURLParams(opts.ForceConsent)...) + return authorizeManual(ctx, opts, creds) + } - fmt.Fprintln(os.Stderr, "Visit this URL to authorize:") - fmt.Fprintln(os.Stderr, authURL) - fmt.Fprintln(os.Stderr) - fmt.Fprintln(os.Stderr, "After authorizing, you'll be redirected to a localhost URL that won't load.") - fmt.Fprintln(os.Stderr, "Copy the URL from your browser's address bar and paste it here.") - fmt.Fprintln(os.Stderr) + return authorizeServer(ctx, opts, creds) +} - line, readErr := input.PromptLine(ctx, "Paste redirect URL (Enter or Ctrl-D): ") - if readErr != nil && !errors.Is(readErr, os.ErrClosed) { - if errors.Is(readErr, io.EOF) { - return "", fmt.Errorf("authorization canceled: %w", context.Canceled) - } +func authorizeManual(ctx context.Context, opts AuthorizeOptions, creds config.ClientCredentials) (string, error) { + authURLInput := strings.TrimSpace(opts.AuthURL) + authCodeInput := strings.TrimSpace(opts.AuthCode) - return "", fmt.Errorf("read redirect url: %w", readErr) - } - line = strings.TrimSpace(line) + redirectURI := "http://localhost:1" + cfg := oauth2.Config{ + ClientID: creds.ClientID, + ClientSecret: creds.ClientSecret, + Endpoint: oauthEndpoint, + RedirectURL: redirectURI, + Scopes: opts.Scopes, + } - code, gotState, parseErr := extractCodeAndState(line) + if authURLInput != "" || authCodeInput != "" { + return authorizeManualWithCode(ctx, opts, cfg, authURLInput, authCodeInput) + } + + return authorizeManualInteractive(ctx, opts, cfg) +} + +func authorizeManualWithCode( + ctx context.Context, + opts AuthorizeOptions, + cfg oauth2.Config, + authURLInput string, + authCodeInput string, +) (string, error) { + code := strings.TrimSpace(authCodeInput) + gotState := "" + + if authURLInput != "" { + parsedCode, parsedState, parseErr := extractCodeAndState(authURLInput) if parseErr != nil { return "", parseErr } - if gotState != "" && gotState != state { - return "", errStateMismatch + code = parsedCode + gotState = parsedState + + if opts.RequireState && gotState == "" { + return "", errMissingState + } + } + + if strings.TrimSpace(code) == "" { + return "", errMissingCode + } + + if gotState != "" { + if err := validateManualState(opts, gotState); err != nil { + return "", err + } + } + + tok, exchangeErr := cfg.Exchange(ctx, code) + if exchangeErr != nil { + return "", fmt.Errorf("exchange code: %w", exchangeErr) + } + + if tok.RefreshToken == "" { + return "", errNoRefreshToken + } + + if gotState != "" { + _ = clearManualState(gotState) + } + + return tok.RefreshToken, nil +} + +func authorizeManualInteractive(ctx context.Context, opts AuthorizeOptions, cfg oauth2.Config) (string, error) { + setup, err := manualAuthSetup(opts, cfg) + if err != nil { + return "", err + } + + fmt.Fprintln(os.Stderr, "Visit this URL to authorize:") + fmt.Fprintln(os.Stderr, setup.authURL) + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "After authorizing, you'll be redirected to a localhost URL that won't load.") + fmt.Fprintln(os.Stderr, "Copy the URL from your browser's address bar and paste it here.") + fmt.Fprintln(os.Stderr) + + line, readErr := input.PromptLine(ctx, "Paste redirect URL (Enter or Ctrl-D): ") + if readErr != nil && !errors.Is(readErr, os.ErrClosed) { + if errors.Is(readErr, io.EOF) { + return "", fmt.Errorf("authorization canceled: %w", context.Canceled) } - var tok *oauth2.Token + return "", fmt.Errorf("read redirect url: %w", readErr) + } - if t, exchangeErr := cfg.Exchange(ctx, code); exchangeErr != nil { - return "", fmt.Errorf("exchange code: %w", exchangeErr) - } else { - tok = t + line = strings.TrimSpace(line) + + code, gotState, parseErr := extractCodeAndState(line) + if parseErr != nil { + return "", parseErr + } + + if gotState != "" && gotState != setup.state { + return "", errStateMismatch + } + + tok, exchangeErr := cfg.Exchange(ctx, code) + if exchangeErr != nil { + return "", fmt.Errorf("exchange code: %w", exchangeErr) + } + + if tok.RefreshToken == "" { + return "", errNoRefreshToken + } + + _ = clearManualState(setup.state) + + return tok.RefreshToken, nil +} + +func validateManualState(opts AuthorizeOptions, gotState string) error { + if opts.RequireState { + if gotState == "" { + return errMissingState + } + } + + if gotState == "" { + return nil + } + + path, err := manualStatePathFor(gotState) + if err != nil { + return err + } + + st, ok, err := loadManualStateByPath(path) + if err != nil { + return err + } + + if !ok { + if opts.RequireState { + return errManualStateMissing } - if tok.RefreshToken == "" { - return "", errNoRefreshToken + return nil + } + + if st.Client != opts.Client || st.ForceConsent != opts.ForceConsent || !scopesEqual(st.Scopes, opts.Scopes) { + if opts.RequireState { + return errManualStateMismatch } - return tok.RefreshToken, nil + return errStateMismatch + } + + return nil +} + +func authorizeServer(ctx context.Context, opts AuthorizeOptions, creds config.ClientCredentials) (string, error) { + state, err := randomStateFn() + if err != nil { + return "", err } ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") @@ -270,6 +402,70 @@ func Authorize(ctx context.Context, opts AuthorizeOptions) (string, error) { } } +func ManualAuthURL(ctx context.Context, opts AuthorizeOptions) (ManualAuthURLResult, error) { + if opts.Timeout <= 0 { + opts.Timeout = 2 * time.Minute + } + + if len(opts.Scopes) == 0 { + return ManualAuthURLResult{}, errMissingScopes + } + + creds, err := readClientCredentials(opts.Client) + if err != nil { + return ManualAuthURLResult{}, err + } + + redirectURI := "http://localhost:1" + cfg := oauth2.Config{ + ClientID: creds.ClientID, + ClientSecret: creds.ClientSecret, + Endpoint: oauthEndpoint, + RedirectURL: redirectURI, + Scopes: opts.Scopes, + } + + setup, err := manualAuthSetup(opts, cfg) + if err != nil { + return ManualAuthURLResult{}, err + } + + return ManualAuthURLResult{ + URL: setup.authURL, + StateReused: setup.reused, + }, nil +} + +type manualAuthSetupResult struct { + authURL string + state string + reused bool +} + +func manualAuthSetup(opts AuthorizeOptions, cfg oauth2.Config) (manualAuthSetupResult, error) { + state, reused, err := loadManualState(opts.Client, opts.Scopes, opts.ForceConsent) + if err != nil { + return manualAuthSetupResult{}, err + } + + if !reused { + state, err = randomStateFn() + if err != nil { + return manualAuthSetupResult{}, err + } + + if err := saveManualState(opts.Client, opts.Scopes, opts.ForceConsent, state); err != nil { + return manualAuthSetupResult{}, err + } + } + + return manualAuthSetupResult{ + authURL: cfg.AuthCodeURL(state, authURLParams(opts.ForceConsent)...), + state: state, + reused: reused, + }, nil +} + func authURLParams(forceConsent bool) []oauth2.AuthCodeOption { opts := []oauth2.AuthCodeOption{ oauth2.AccessTypeOffline, diff --git a/internal/googleauth/oauth_flow_authorize_test.go b/internal/googleauth/oauth_flow_authorize_test.go index 62a8d1d..12e2905 100644 --- a/internal/googleauth/oauth_flow_authorize_test.go +++ b/internal/googleauth/oauth_flow_authorize_test.go @@ -54,6 +54,18 @@ func newTokenServer(t *testing.T) *httptest.Server { })) } +func useTempManualStatePath(t *testing.T) { + t.Helper() + + origDir := manualStateDirFn + dir := t.TempDir() + manualStateDirFn = func() (string, error) { return dir, nil } + + t.Cleanup(func() { + manualStateDirFn = origDir + }) +} + func TestAuthorize_MissingScopes(t *testing.T) { _, err := Authorize(context.Background(), AuthorizeOptions{}) if err == nil || !strings.Contains(err.Error(), "missing scopes") { @@ -72,6 +84,8 @@ func TestAuthorize_Manual_Success(t *testing.T) { randomStateFn = origState }) + useTempManualStatePath(t) + readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil } @@ -122,6 +136,7 @@ func TestAuthorize_Manual_Success_NoNewline(t *testing.T) { oauthEndpoint = origEndpoint randomStateFn = origState }) + useTempManualStatePath(t) readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil @@ -173,6 +188,7 @@ func TestAuthorize_Manual_CancelEOF(t *testing.T) { oauthEndpoint = origEndpoint randomStateFn = origState }) + useTempManualStatePath(t) readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil @@ -218,6 +234,7 @@ func TestAuthorize_Manual_StateMismatch(t *testing.T) { oauthEndpoint = origEndpoint randomStateFn = origState }) + useTempManualStatePath(t) readClientCredentials = func(string) (config.ClientCredentials, error) { return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil @@ -254,6 +271,150 @@ func TestAuthorize_Manual_StateMismatch(t *testing.T) { } } +func TestAuthorize_Manual_AuthCode(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + origState := randomStateFn + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + randomStateFn = origState + }) + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + stateCalled := false + randomStateFn = func() (string, error) { + stateCalled = true + return "state123", nil + } + + tokenSrv := newTokenServer(t) + defer tokenSrv.Close() + oauthEndpoint = oauth2EndpointForTest(tokenSrv.URL) + + rt, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthCode: "abc", + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("Authorize: %v", err) + } + + if rt != "rt" { + t.Fatalf("unexpected refresh token: %q", rt) + } + + if stateCalled { + t.Fatalf("unexpected state generation in auth-code flow") + } +} + +func TestAuthorize_Manual_AuthURL_RequireStateMissing(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + }) + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + + _, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthURL: "http://localhost:1/?code=abc", + RequireState: true, + Client: "default", + Timeout: 2 * time.Second, + }) + if err == nil { + t.Fatalf("expected error") + } + + if !errors.Is(err, errMissingState) { + t.Fatalf("expected missing state error, got: %v", err) + } +} + +func TestAuthorize_Manual_AuthURL_RequireStateMissingCache(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + }) + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + + _, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthURL: "http://localhost:1/?code=abc&state=state123", + RequireState: true, + Client: "default", + Timeout: 2 * time.Second, + }) + if err == nil { + t.Fatalf("expected error") + } + + if !errors.Is(err, errManualStateMissing) { + t.Fatalf("expected manual state missing error, got: %v", err) + } +} + +func TestAuthorize_Manual_AuthURL_RequireStateMissingForDifferentState(t *testing.T) { + origRead := readClientCredentials + origEndpoint := oauthEndpoint + + t.Cleanup(func() { + readClientCredentials = origRead + oauthEndpoint = origEndpoint + }) + useTempManualStatePath(t) + + readClientCredentials = func(string) (config.ClientCredentials, error) { + return config.ClientCredentials{ClientID: "id", ClientSecret: "secret"}, nil + } + oauthEndpoint = oauth2EndpointForTest("http://example.com") + + if err := saveManualState("default", []string{"s1"}, false, "state123"); err != nil { + t.Fatalf("save manual state: %v", err) + } + + _, err := Authorize(context.Background(), AuthorizeOptions{ + Scopes: []string{"s1"}, + Manual: true, + AuthURL: "http://localhost:1/?code=abc&state=DIFFERENT", + RequireState: true, + Client: "default", + Timeout: 2 * time.Second, + }) + if err == nil { + t.Fatalf("expected error") + } + + if !errors.Is(err, errManualStateMissing) { + t.Fatalf("expected manual state missing error, got: %v", err) + } +} + func TestAuthorize_ServerFlow_Success(t *testing.T) { origRead := readClientCredentials origEndpoint := oauthEndpoint