fix(auth): verify account matches authorized email
This commit is contained in:
parent
b14c0ce908
commit
4af0cb83f1
@ -24,6 +24,7 @@ var (
|
||||
startManageServer = googleauth.StartManageServer
|
||||
checkRefreshToken = googleauth.CheckRefreshToken
|
||||
ensureKeychainAccess = secrets.EnsureKeychainAccess
|
||||
fetchAuthorizedEmail = googleauth.EmailForRefreshToken
|
||||
)
|
||||
|
||||
func ensureKeychainAccessIfNeeded() error {
|
||||
@ -37,6 +38,10 @@ func ensureKeychainAccessIfNeeded() error {
|
||||
return ensureKeychainAccess()
|
||||
}
|
||||
|
||||
func normalizeEmail(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
type AuthCmd struct {
|
||||
Credentials AuthCredentialsCmd `cmd:"" name:"credentials" help:"Store OAuth client credentials"`
|
||||
Add AuthAddCmd `cmd:"" name:"add" help:"Authorize and store a refresh token"`
|
||||
@ -337,7 +342,7 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
return fmt.Errorf("no services selected")
|
||||
}
|
||||
|
||||
scopes, err := googleauth.ScopesForServices(services)
|
||||
scopes, err := googleauth.ScopesForManage(services)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -357,6 +362,14 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
authorizedEmail, err := fetchAuthorizedEmail(ctx, refreshToken, scopes, 15*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch authorized email: %w", err)
|
||||
}
|
||||
if normalizeEmail(authorizedEmail) != normalizeEmail(c.Email) {
|
||||
return fmt.Errorf("authorized as %s, expected %s", authorizedEmail, c.Email)
|
||||
}
|
||||
|
||||
store, err := openSecretsStore()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -367,8 +380,8 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
}
|
||||
sort.Strings(serviceNames)
|
||||
|
||||
if err := store.SetToken(c.Email, secrets.Token{
|
||||
Email: c.Email,
|
||||
if err := store.SetToken(authorizedEmail, secrets.Token{
|
||||
Email: authorizedEmail,
|
||||
Services: serviceNames,
|
||||
Scopes: scopes,
|
||||
RefreshToken: refreshToken,
|
||||
@ -378,11 +391,11 @@ func (c *AuthAddCmd) Run(ctx context.Context) error {
|
||||
if outfmt.IsJSON(ctx) {
|
||||
return outfmt.WriteJSON(os.Stdout, map[string]any{
|
||||
"stored": true,
|
||||
"email": c.Email,
|
||||
"email": authorizedEmail,
|
||||
"services": serviceNames,
|
||||
})
|
||||
}
|
||||
u.Out().Printf("email\t%s", c.Email)
|
||||
u.Out().Printf("email\t%s", authorizedEmail)
|
||||
u.Out().Printf("services\t%s", strings.Join(serviceNames, ","))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/gogcli/internal/googleauth"
|
||||
"github.com/steipete/gogcli/internal/outfmt"
|
||||
@ -19,10 +20,12 @@ func TestAuthAddCmd_JSON_More(t *testing.T) {
|
||||
origOpen := openSecretsStore
|
||||
origAuth := authorizeGoogle
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
openSecretsStore = origOpen
|
||||
authorizeGoogle = origAuth
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
store := newMemSecretsStore()
|
||||
@ -33,6 +36,9 @@ func TestAuthAddCmd_JSON_More(t *testing.T) {
|
||||
}
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
|
||||
return "a@b.com", nil
|
||||
}
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
|
||||
u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: io.Discard, Color: "never"})
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/gogcli/internal/googleauth"
|
||||
"github.com/steipete/gogcli/internal/secrets"
|
||||
@ -15,10 +16,12 @@ func TestAuthAddCmd_JSON(t *testing.T) {
|
||||
origAuth := authorizeGoogle
|
||||
origOpen := openSecretsStore
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
authorizeGoogle = origAuth
|
||||
openSecretsStore = origOpen
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
@ -31,6 +34,9 @@ func TestAuthAddCmd_JSON(t *testing.T) {
|
||||
gotOpts = opts
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
|
||||
return "user@example.com", nil
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
@ -80,10 +86,12 @@ func TestAuthAddCmd_KeychainError(t *testing.T) {
|
||||
origAuth := authorizeGoogle
|
||||
origOpen := openSecretsStore
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
authorizeGoogle = origAuth
|
||||
openSecretsStore = origOpen
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
// Simulate keychain locked error
|
||||
@ -96,6 +104,10 @@ func TestAuthAddCmd_KeychainError(t *testing.T) {
|
||||
authCalled = true
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
|
||||
t.Fatal("fetchAuthorizedEmail should not be called when keychain check fails")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
store := newMemSecretsStore()
|
||||
openSecretsStore = func() (secrets.Store, error) { return store, nil }
|
||||
@ -118,10 +130,12 @@ func TestAuthAddCmd_DefaultServices_UserPreset(t *testing.T) {
|
||||
origAuth := authorizeGoogle
|
||||
origOpen := openSecretsStore
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
authorizeGoogle = origAuth
|
||||
openSecretsStore = origOpen
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
@ -134,6 +148,9 @@ func TestAuthAddCmd_DefaultServices_UserPreset(t *testing.T) {
|
||||
gotOpts = opts
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
|
||||
return "user@example.com", nil
|
||||
}
|
||||
|
||||
_ = captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
@ -179,3 +196,33 @@ func TestAuthAddCmd_KeepRejected(t *testing.T) {
|
||||
t.Fatalf("authorizeGoogle should not be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAddCmd_EmailMismatch(t *testing.T) {
|
||||
origAuth := authorizeGoogle
|
||||
origOpen := openSecretsStore
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
authorizeGoogle = origAuth
|
||||
openSecretsStore = origOpen
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
openSecretsStore = func() (secrets.Store, error) { return newMemSecretsStore(), nil }
|
||||
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) {
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
|
||||
return "actual@example.com", nil
|
||||
}
|
||||
|
||||
err := Execute([]string{"auth", "add", "expected@example.com"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected mismatch error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "authorized as actual@example.com") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,7 +25,7 @@ func newMemSecretsStore() *memSecretsStore {
|
||||
return &memSecretsStore{tokens: make(map[string]secrets.Token)}
|
||||
}
|
||||
|
||||
func normalizeEmail(s string) string {
|
||||
func normalizeEmailTest(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ func (s *memSecretsStore) Keys() ([]string, error) {
|
||||
}
|
||||
|
||||
func (s *memSecretsStore) SetToken(email string, tok secrets.Token) error {
|
||||
email = normalizeEmail(email)
|
||||
email = normalizeEmailTest(email)
|
||||
if email == "" {
|
||||
return errors.New("missing email")
|
||||
}
|
||||
@ -52,7 +52,7 @@ func (s *memSecretsStore) SetToken(email string, tok secrets.Token) error {
|
||||
}
|
||||
|
||||
func (s *memSecretsStore) GetToken(email string) (secrets.Token, error) {
|
||||
email = normalizeEmail(email)
|
||||
email = normalizeEmailTest(email)
|
||||
if email == "" {
|
||||
return secrets.Token{}, errors.New("missing email")
|
||||
}
|
||||
@ -63,7 +63,7 @@ func (s *memSecretsStore) GetToken(email string) (secrets.Token, error) {
|
||||
}
|
||||
|
||||
func (s *memSecretsStore) DeleteToken(email string) error {
|
||||
email = normalizeEmail(email)
|
||||
email = normalizeEmailTest(email)
|
||||
if email == "" {
|
||||
return errors.New("missing email")
|
||||
}
|
||||
|
||||
@ -231,15 +231,18 @@ func TestAuthAdd_TextOutput(t *testing.T) {
|
||||
origOpen := openSecretsStore
|
||||
origAuth := authorizeGoogle
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
openSecretsStore = origOpen
|
||||
authorizeGoogle = origAuth
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
store := newMemStore()
|
||||
openSecretsStore = func() (secrets.Store, error) { return store, nil }
|
||||
authorizeGoogle = func(context.Context, googleauth.AuthorizeOptions) (string, error) { return "rt", nil }
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) { return "a@b.com", nil }
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
|
||||
var outBuf strings.Builder
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/gogcli/internal/googleauth"
|
||||
"github.com/steipete/gogcli/internal/secrets"
|
||||
@ -13,10 +14,12 @@ func TestExecute_AuthAdd_JSON(t *testing.T) {
|
||||
origOpen := openSecretsStore
|
||||
origAuth := authorizeGoogle
|
||||
origKeychain := ensureKeychainAccess
|
||||
origFetch := fetchAuthorizedEmail
|
||||
t.Cleanup(func() {
|
||||
openSecretsStore = origOpen
|
||||
authorizeGoogle = origAuth
|
||||
ensureKeychainAccess = origKeychain
|
||||
fetchAuthorizedEmail = origFetch
|
||||
})
|
||||
|
||||
ensureKeychainAccess = func() error { return nil }
|
||||
@ -31,6 +34,9 @@ func TestExecute_AuthAdd_JSON(t *testing.T) {
|
||||
gotOpts.Scopes = append([]string{}, opts.Scopes...)
|
||||
return "rt", nil
|
||||
}
|
||||
fetchAuthorizedEmail = func(context.Context, string, []string, time.Duration) (string, error) {
|
||||
return "a@b.com", nil
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
|
||||
56
internal/googleauth/token_email.go
Normal file
56
internal/googleauth/token_email.go
Normal file
@ -0,0 +1,56 @@
|
||||
package googleauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// EmailForRefreshToken exchanges a refresh token and returns the authorized email address.
|
||||
func EmailForRefreshToken(ctx context.Context, refreshToken string, scopes []string, timeout time.Duration) (string, error) {
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return "", errMissingToken
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 15 * time.Second
|
||||
}
|
||||
|
||||
creds, err := readClientCredentials()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read credentials: %w", err)
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: creds.ClientID,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
Endpoint: oauthEndpoint,
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: timeout})
|
||||
|
||||
ts := cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken})
|
||||
tok, err := ts.Token()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh access token: %w", err)
|
||||
}
|
||||
|
||||
if raw, ok := tok.Extra("id_token").(string); ok {
|
||||
if email, err := emailFromIDToken(raw); err == nil {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(tok.AccessToken) == "" {
|
||||
return "", errMissingAccessToken
|
||||
}
|
||||
|
||||
return fetchUserEmailWithURL(ctx, tok.AccessToken, userinfoURL)
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user