diff --git a/internal/cmd/auth.go b/internal/cmd/auth.go index 0c4cfa7..a7d0ddd 100644 --- a/internal/cmd/auth.go +++ b/internal/cmd/auth.go @@ -24,6 +24,7 @@ var ( startManageServer = googleauth.StartManageServer checkRefreshToken = googleauth.CheckRefreshToken ensureKeychainAccess = secrets.EnsureKeychainAccess + fetchAuthorizedEmail = googleauth.EmailForRefreshToken ) func ensureKeychainAccessIfNeeded() error { @@ -37,6 +38,10 @@ func ensureKeychainAccessIfNeeded() error { return ensureKeychainAccess() } +func normalizeEmail(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + type AuthCmd struct { Credentials AuthCredentialsCmd `cmd:"" name:"credentials" help:"Store OAuth client credentials"` Add AuthAddCmd `cmd:"" name:"add" help:"Authorize and store a refresh token"` @@ -337,7 +342,7 @@ func (c *AuthAddCmd) Run(ctx context.Context) error { return fmt.Errorf("no services selected") } - scopes, err := googleauth.ScopesForServices(services) + scopes, err := googleauth.ScopesForManage(services) if err != nil { return err } @@ -357,6 +362,14 @@ func (c *AuthAddCmd) Run(ctx context.Context) error { return err } + authorizedEmail, err := fetchAuthorizedEmail(ctx, refreshToken, scopes, 15*time.Second) + if err != nil { + return fmt.Errorf("fetch authorized email: %w", err) + } + if normalizeEmail(authorizedEmail) != normalizeEmail(c.Email) { + return fmt.Errorf("authorized as %s, expected %s", authorizedEmail, c.Email) + } + store, err := openSecretsStore() if err != nil { return err @@ -367,8 +380,8 @@ func (c *AuthAddCmd) Run(ctx context.Context) error { } sort.Strings(serviceNames) - if err := store.SetToken(c.Email, secrets.Token{ - Email: c.Email, + if err := store.SetToken(authorizedEmail, secrets.Token{ + Email: authorizedEmail, Services: serviceNames, Scopes: scopes, RefreshToken: refreshToken, @@ -378,11 +391,11 @@ func (c *AuthAddCmd) Run(ctx context.Context) error { if outfmt.IsJSON(ctx) { return outfmt.WriteJSON(os.Stdout, map[string]any{ "stored": true, - "email": c.Email, + "email": authorizedEmail, "services": serviceNames, }) } - u.Out().Printf("email\t%s", c.Email) + u.Out().Printf("email\t%s", authorizedEmail) u.Out().Printf("services\t%s", strings.Join(serviceNames, ",")) return nil } diff --git a/internal/cmd/auth_add_keep_more_test.go b/internal/cmd/auth_add_keep_more_test.go index 7ed9c2b..154a29f 100644 --- a/internal/cmd/auth_add_keep_more_test.go +++ b/internal/cmd/auth_add_keep_more_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/steipete/gogcli/internal/googleauth" "github.com/steipete/gogcli/internal/outfmt" @@ -19,10 +20,12 @@ func TestAuthAddCmd_JSON_More(t *testing.T) { origOpen := openSecretsStore origAuth := authorizeGoogle origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail t.Cleanup(func() { openSecretsStore = origOpen authorizeGoogle = origAuth ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch }) store := newMemSecretsStore() @@ -33,6 +36,9 @@ func TestAuthAddCmd_JSON_More(t *testing.T) { } return "rt", nil } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { + return "a@b.com", nil + } ensureKeychainAccess = func() error { return nil } u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"}) diff --git a/internal/cmd/auth_add_test.go b/internal/cmd/auth_add_test.go index bb5a13e..042021a 100644 --- a/internal/cmd/auth_add_test.go +++ b/internal/cmd/auth_add_test.go @@ -6,6 +6,7 @@ import ( "errors" "strings" "testing" + "time" "github.com/steipete/gogcli/internal/googleauth" "github.com/steipete/gogcli/internal/secrets" @@ -15,10 +16,12 @@ func TestAuthAddCmd_JSON(t *testing.T) { origAuth := authorizeGoogle origOpen := openSecretsStore origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail t.Cleanup(func() { authorizeGoogle = origAuth openSecretsStore = origOpen ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch }) ensureKeychainAccess = func() error { return nil } @@ -31,6 +34,9 @@ func TestAuthAddCmd_JSON(t *testing.T) { gotOpts = opts return "rt", nil } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { + return "user@example.com", nil + } out := captureStdout(t, func() { _ = captureStderr(t, func() { @@ -80,10 +86,12 @@ func TestAuthAddCmd_KeychainError(t *testing.T) { origAuth := authorizeGoogle origOpen := openSecretsStore origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail t.Cleanup(func() { authorizeGoogle = origAuth openSecretsStore = origOpen ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch }) // Simulate keychain locked error @@ -96,6 +104,10 @@ func TestAuthAddCmd_KeychainError(t *testing.T) { authCalled = true return "rt", nil } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { + t.Fatal("fetchAuthorizedEmail should not be called when keychain check fails") + return "", nil + } store := newMemSecretsStore() openSecretsStore = func() (secrets.Store, error) { return store, nil } @@ -118,10 +130,12 @@ func TestAuthAddCmd_DefaultServices_UserPreset(t *testing.T) { origAuth := authorizeGoogle origOpen := openSecretsStore origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail t.Cleanup(func() { authorizeGoogle = origAuth openSecretsStore = origOpen ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch }) ensureKeychainAccess = func() error { return nil } @@ -134,6 +148,9 @@ func TestAuthAddCmd_DefaultServices_UserPreset(t *testing.T) { gotOpts = opts return "rt", nil } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { + return "user@example.com", nil + } _ = captureStdout(t, func() { _ = captureStderr(t, func() { @@ -179,3 +196,33 @@ func TestAuthAddCmd_KeepRejected(t *testing.T) { t.Fatalf("authorizeGoogle should not be called") } } + +func TestAuthAddCmd_EmailMismatch(t *testing.T) { + origAuth := authorizeGoogle + origOpen := openSecretsStore + origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail + t.Cleanup(func() { + authorizeGoogle = origAuth + openSecretsStore = origOpen + ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch + }) + + ensureKeychainAccess = func() error { return nil } + openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil } + authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) { + return "rt", nil + } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { + return "actual@example.com", nil + } + + err := Execute([]string{"auth", "add", "expected@example.com"}) + if err == nil { + t.Fatalf("expected mismatch error") + } + if !strings.Contains(err.Error(), "authorized as actual@example.com") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/cmd/auth_cmd_test.go b/internal/cmd/auth_cmd_test.go index 4bfed9a..5dcf740 100644 --- a/internal/cmd/auth_cmd_test.go +++ b/internal/cmd/auth_cmd_test.go @@ -25,7 +25,7 @@ func newMemSecretsStore() *memSecretsStore { return &memSecretsStore{tokens: make(map[string]secrets.Token)} } -func normalizeEmail(s string) string { +func normalizeEmailTest(s string) string { return strings.ToLower(strings.TrimSpace(s)) } @@ -39,7 +39,7 @@ func (s *memSecretsStore) Keys() ([]string, error) { } func (s *memSecretsStore) SetToken(email string, tok secrets.Token) error { - email = normalizeEmail(email) + email = normalizeEmailTest(email) if email == "" { return errors.New("missing email") } @@ -52,7 +52,7 @@ func (s *memSecretsStore) SetToken(email string, tok secrets.Token) error { } func (s *memSecretsStore) GetToken(email string) (secrets.Token, error) { - email = normalizeEmail(email) + email = normalizeEmailTest(email) if email == "" { return secrets.Token{}, errors.New("missing email") } @@ -63,7 +63,7 @@ func (s *memSecretsStore) GetToken(email string) (secrets.Token, error) { } func (s *memSecretsStore) DeleteToken(email string) error { - email = normalizeEmail(email) + email = normalizeEmailTest(email) if email == "" { return errors.New("missing email") } diff --git a/internal/cmd/auth_validation_more_test.go b/internal/cmd/auth_validation_more_test.go index 81c8cd2..2e20b35 100644 --- a/internal/cmd/auth_validation_more_test.go +++ b/internal/cmd/auth_validation_more_test.go @@ -231,15 +231,18 @@ func TestAuthAdd_TextOutput(t *testing.T) { origOpen := openSecretsStore origAuth := authorizeGoogle origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail t.Cleanup(func() { openSecretsStore = origOpen authorizeGoogle = origAuth ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch }) store := newMemStore() openSecretsStore = func() (secrets.Store, error) { return store, nil } authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) { return "rt", nil } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { return "a@b.com", nil } ensureKeychainAccess = func() error { return nil } var outBuf strings.Builder diff --git a/internal/cmd/execute_auth_add_test.go b/internal/cmd/execute_auth_add_test.go index 446312a..89f71dc 100644 --- a/internal/cmd/execute_auth_add_test.go +++ b/internal/cmd/execute_auth_add_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "testing" + "time" "github.com/steipete/gogcli/internal/googleauth" "github.com/steipete/gogcli/internal/secrets" @@ -13,10 +14,12 @@ func TestExecute_AuthAdd_JSON(t *testing.T) { origOpen := openSecretsStore origAuth := authorizeGoogle origKeychain := ensureKeychainAccess + origFetch := fetchAuthorizedEmail t.Cleanup(func() { openSecretsStore = origOpen authorizeGoogle = origAuth ensureKeychainAccess = origKeychain + fetchAuthorizedEmail = origFetch }) ensureKeychainAccess = func() error { return nil } @@ -31,6 +34,9 @@ func TestExecute_AuthAdd_JSON(t *testing.T) { gotOpts.Scopes = append([]string{}, opts.Scopes...) return "rt", nil } + fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { + return "a@b.com", nil + } out := captureStdout(t, func() { _ = captureStderr(t, func() { diff --git a/internal/googleauth/token_email.go b/internal/googleauth/token_email.go new file mode 100644 index 0000000..8496235 --- /dev/null +++ b/internal/googleauth/token_email.go @@ -0,0 +1,56 @@ +package googleauth + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "golang.org/x/oauth2" +) + +// EmailForRefreshToken exchanges a refresh token and returns the authorized email address. +func EmailForRefreshToken(ctx context.Context, refreshToken string, scopes []string, timeout time.Duration) (string, error) { + if strings.TrimSpace(refreshToken) == "" { + return "", errMissingToken + } + if timeout <= 0 { + timeout = 15 * time.Second + } + + creds, err := readClientCredentials() + if err != nil { + return "", fmt.Errorf("read credentials: %w", err) + } + + cfg := oauth2.Config{ + ClientID: creds.ClientID, + ClientSecret: creds.ClientSecret, + Endpoint: oauthEndpoint, + Scopes: scopes, + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: timeout}) + + ts := cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}) + tok, err := ts.Token() + if err != nil { + return "", fmt.Errorf("refresh access token: %w", err) + } + + if raw, ok := tok.Extra("id_token").(string); ok { + if email, err := emailFromIDToken(raw); err == nil { + return email, nil + } + } + + if strings.TrimSpace(tok.AccessToken) == "" { + return "", errMissingAccessToken + } + + return fetchUserEmailWithURL(ctx, tok.AccessToken, userinfoURL) +}