fix(auth): verify account matches authorized email

This commit is contained in:
Peter Steinberger 2026-01-09 17:30:19 +01:00
parent b14c0ce908
commit 4af0cb83f1
7 changed files with 140 additions and 9 deletions

View File

@ -24,6 +24,7 @@ var (
startManageServer = googleauth.StartManageServer
checkRefreshToken = googleauth.CheckRefreshToken
ensureKeychainAccess = secrets.EnsureKeychainAccess
fetchAuthorizedEmail = googleauth.EmailForRefreshToken
)
func ensureKeychainAccessIfNeeded() error {
@ -37,6 +38,10 @@ func ensureKeychainAccessIfNeeded() error {
return ensureKeychainAccess()
}
func normalizeEmail(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
type AuthCmd struct {
Credentials AuthCredentialsCmd `cmd:"" name:"credentials" help:"Store OAuth client credentials"`
Add AuthAddCmd `cmd:"" name:"add" help:"Authorize and store a refresh token"`
@ -337,7 +342,7 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
return fmt.Errorf("no services selected")
}
scopes, err := googleauth.ScopesForServices(services)
scopes, err := googleauth.ScopesForManage(services)
if err != nil {
return err
}
@ -357,6 +362,14 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
return err
}
authorizedEmail, err := fetchAuthorizedEmail(ctx, refreshToken, scopes, 15*time.Second)
if err != nil {
return fmt.Errorf("fetch authorized email: %w", err)
}
if normalizeEmail(authorizedEmail) != normalizeEmail(c.Email) {
return fmt.Errorf("authorized as %s, expected %s", authorizedEmail, c.Email)
}
store, err := openSecretsStore()
if err != nil {
return err
@ -367,8 +380,8 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
}
sort.Strings(serviceNames)
if err := store.SetToken(c.Email, secrets.Token{
Email: c.Email,
if err := store.SetToken(authorizedEmail, secrets.Token{
Email: authorizedEmail,
Services: serviceNames,
Scopes: scopes,
RefreshToken: refreshToken,
@ -378,11 +391,11 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
if outfmt.IsJSON(ctx) {
return outfmt.WriteJSON(os.Stdout, map[string]any{
"stored": true,
"email": c.Email,
"email": authorizedEmail,
"services": serviceNames,
})
}
u.Out().Printf("email\t%s", c.Email)
u.Out().Printf("email\t%s", authorizedEmail)
u.Out().Printf("services\t%s", strings.Join(serviceNames, ","))
return nil
}

View File

@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/steipete/gogcli/internal/googleauth"
"github.com/steipete/gogcli/internal/outfmt"
@ -19,10 +20,12 @@ func TestAuthAddCmd_JSON_More(t *testing.T) {
origOpen := openSecretsStore
origAuth := authorizeGoogle
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
openSecretsStore = origOpen
authorizeGoogle = origAuth
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
store := newMemSecretsStore()
@ -33,6 +36,9 @@ func TestAuthAddCmd_JSON_More(t *testing.T) {
}
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
return "a@b.com", nil
}
ensureKeychainAccess = func() error { return nil }
u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"})

View File

@ -6,6 +6,7 @@ import (
"errors"
"strings"
"testing"
"time"
"github.com/steipete/gogcli/internal/googleauth"
"github.com/steipete/gogcli/internal/secrets"
@ -15,10 +16,12 @@ func TestAuthAddCmd_JSON(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
ensureKeychainAccess = func() error { return nil }
@ -31,6 +34,9 @@ func TestAuthAddCmd_JSON(t *testing.T) {
gotOpts = opts
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
return "user@example.com", nil
}
out := captureStdout(t, func() {
_ = captureStderr(t, func() {
@ -80,10 +86,12 @@ func TestAuthAddCmd_KeychainError(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
// Simulate keychain locked error
@ -96,6 +104,10 @@ func TestAuthAddCmd_KeychainError(t *testing.T) {
authCalled = true
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
t.Fatal("fetchAuthorizedEmail should not be called when keychain check fails")
return "", nil
}
store := newMemSecretsStore()
openSecretsStore = func() (secrets.Store, error) { return store, nil }
@ -118,10 +130,12 @@ func TestAuthAddCmd_DefaultServices_UserPreset(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
ensureKeychainAccess = func() error { return nil }
@ -134,6 +148,9 @@ func TestAuthAddCmd_DefaultServices_UserPreset(t *testing.T) {
gotOpts = opts
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
return "user@example.com", nil
}
_ = captureStdout(t, func() {
_ = captureStderr(t, func() {
@ -179,3 +196,33 @@ func TestAuthAddCmd_KeepRejected(t *testing.T) {
t.Fatalf("authorizeGoogle should not be called")
}
}
func TestAuthAddCmd_EmailMismatch(t *testing.T) {
origAuth := authorizeGoogle
origOpen := openSecretsStore
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
authorizeGoogle = origAuth
openSecretsStore = origOpen
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
ensureKeychainAccess = func() error { return nil }
openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil }
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) {
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
return "actual@example.com", nil
}
err := Execute([]string{"auth", "add", "expected@example.com"})
if err == nil {
t.Fatalf("expected mismatch error")
}
if !strings.Contains(err.Error(), "authorized as actual@example.com") {
t.Fatalf("unexpected error: %v", err)
}
}

View File

@ -25,7 +25,7 @@ func newMemSecretsStore() *memSecretsStore {
return &memSecretsStore{tokens: make(map[string]secrets.Token)}
}
func normalizeEmail(s string) string {
func normalizeEmailTest(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
@ -39,7 +39,7 @@ func (s *memSecretsStore) Keys() ([]string, error) {
}
func (s *memSecretsStore) SetToken(email string, tok secrets.Token) error {
email = normalizeEmail(email)
email = normalizeEmailTest(email)
if email == "" {
return errors.New("missing email")
}
@ -52,7 +52,7 @@ func (s *memSecretsStore) SetToken(email string, tok secrets.Token) error {
}
func (s *memSecretsStore) GetToken(email string) (secrets.Token, error) {
email = normalizeEmail(email)
email = normalizeEmailTest(email)
if email == "" {
return secrets.Token{}, errors.New("missing email")
}
@ -63,7 +63,7 @@ func (s *memSecretsStore) GetToken(email string) (secrets.Token, error) {
}
func (s *memSecretsStore) DeleteToken(email string) error {
email = normalizeEmail(email)
email = normalizeEmailTest(email)
if email == "" {
return errors.New("missing email")
}

View File

@ -231,15 +231,18 @@ func TestAuthAdd_TextOutput(t *testing.T) {
origOpen := openSecretsStore
origAuth := authorizeGoogle
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
openSecretsStore = origOpen
authorizeGoogle = origAuth
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
store := newMemStore()
openSecretsStore = func() (secrets.Store, error) { return store, nil }
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) { return "rt", nil }
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { return "a@b.com", nil }
ensureKeychainAccess = func() error { return nil }
var outBuf strings.Builder

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"testing"
"time"
"github.com/steipete/gogcli/internal/googleauth"
"github.com/steipete/gogcli/internal/secrets"
@ -13,10 +14,12 @@ func TestExecute_AuthAdd_JSON(t *testing.T) {
origOpen := openSecretsStore
origAuth := authorizeGoogle
origKeychain := ensureKeychainAccess
origFetch := fetchAuthorizedEmail
t.Cleanup(func() {
openSecretsStore = origOpen
authorizeGoogle = origAuth
ensureKeychainAccess = origKeychain
fetchAuthorizedEmail = origFetch
})
ensureKeychainAccess = func() error { return nil }
@ -31,6 +34,9 @@ func TestExecute_AuthAdd_JSON(t *testing.T) {
gotOpts.Scopes = append([]string{}, opts.Scopes...)
return "rt", nil
}
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
return "a@b.com", nil
}
out := captureStdout(t, func() {
_ = captureStderr(t, func() {

View File

@ -0,0 +1,56 @@
package googleauth
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/oauth2"
)
// EmailForRefreshToken exchanges a refresh token and returns the authorized email address.
func EmailForRefreshToken(ctx context.Context, refreshToken string, scopes []string, timeout time.Duration) (string, error) {
if strings.TrimSpace(refreshToken) == "" {
return "", errMissingToken
}
if timeout <= 0 {
timeout = 15 * time.Second
}
creds, err := readClientCredentials()
if err != nil {
return "", fmt.Errorf("read credentials: %w", err)
}
cfg := oauth2.Config{
ClientID: creds.ClientID,
ClientSecret: creds.ClientSecret,
Endpoint: oauthEndpoint,
Scopes: scopes,
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: timeout})
ts := cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken})
tok, err := ts.Token()
if err != nil {
return "", fmt.Errorf("refresh access token: %w", err)
}
if raw, ok := tok.Extra("id_token").(string); ok {
if email, err := emailFromIDToken(raw); err == nil {
return email, nil
}
}
if strings.TrimSpace(tok.AccessToken) == "" {
return "", errMissingAccessToken
}
return fetchUserEmailWithURL(ctx, tok.AccessToken, userinfoURL)
}