fix(tracking): rotate email tracking keys
This commit is contained in:
parent
7b6b161236
commit
e98f44d665
@ -25,6 +25,7 @@
|
||||
### Fixed
|
||||
- Backup: split Gmail checkpoint commits by row count and plaintext byte size so large messages stay below GitHub's blob limit.
|
||||
- Auth: keep `gog auth list` and `gog auth tokens list` useful when one file-keyring token cannot be decrypted; unreadable entries are now reported instead of aborting the whole listing. (#377)
|
||||
- Email tracking: add versioned tracking-key rotation so new pixels use the current key while old tracking ids keep decrypting through prior keys. (#293)
|
||||
- Email tracking: deduplicate repeated pixel opens and cap recorded opens per IP per hour to reduce D1 abuse from replay or high-volume requests. (#294)
|
||||
- Email tracking: add daily Worker retention cleanup for open rows older than 90 days and cap admin `/opens` responses at 500 rows. (#292)
|
||||
- Email tracking: make `gmail track setup --deploy` reusable with existing D1 databases and valid temporary Wrangler configs.
|
||||
|
||||
@ -336,6 +336,8 @@ Generated from `gog schema --json`.
|
||||
- `gog gmail (mail,email) thread (threads,read) get (info,show) <threadId> [flags]` - Get a thread with all messages (optionally download attachments)
|
||||
- `gog gmail (mail,email) thread (threads,read) modify (update,edit,set) <threadId> [flags]` - Modify labels on all messages in a thread
|
||||
- `gog gmail (mail,email) track <command>` - Email open tracking
|
||||
- `gog gmail (mail,email) track key <command>` - Manage tracking encryption keys
|
||||
- `gog gmail (mail,email) track key rotate [flags]` - Rotate tracking encryption key
|
||||
- `gog gmail (mail,email) track opens [<tracking-id>] [flags]` - Query email opens
|
||||
- `gog gmail (mail,email) track setup [flags]` - Set up email tracking (deploy Cloudflare Worker)
|
||||
- `gog gmail (mail,email) track status` - Show tracking configuration status
|
||||
|
||||
@ -15,10 +15,12 @@ Location:
|
||||
|
||||
Expected bindings:
|
||||
- D1 database binding: `DB`
|
||||
- Secrets: `TRACKING_KEY`, `ADMIN_KEY`
|
||||
- Secrets: `TRACKING_KEY`, `TRACKING_KEY_V<N>`, `TRACKING_CURRENT_KEY_VERSION`, `ADMIN_KEY`
|
||||
|
||||
`wrangler.toml` is the local template; deployments set the real D1 database id.
|
||||
|
||||
`TRACKING_KEY` remains as the current-key fallback for legacy deployments and legacy unversioned tracking ids. New rotated deployments also set `TRACKING_KEY_V1`, `TRACKING_KEY_V2`, etc. The Worker reads the one-byte version prefix from new tracking ids, uses the matching `TRACKING_KEY_V<N>` when present, and falls back through active keys for older unversioned ids.
|
||||
|
||||
## Routes (high-level)
|
||||
|
||||
- Pixel:
|
||||
|
||||
@ -34,7 +34,8 @@ gog gmail track setup --worker-url https://gog-email-tracker.<acct>.workers.dev
|
||||
|
||||
This writes a local config file containing:
|
||||
- `worker_url` (base URL)
|
||||
- per-account tracking keys are stored in your keychain/keyring (not in the JSON file)
|
||||
- the active tracking key version
|
||||
- per-account tracking/admin keys are stored in your keychain/keyring (not in the JSON file)
|
||||
|
||||
Optional: auto-provision + deploy with wrangler:
|
||||
|
||||
@ -62,24 +63,44 @@ Provision secrets (use values printed by `gog gmail track setup`):
|
||||
|
||||
```sh
|
||||
pnpm exec wrangler secret put TRACKING_KEY
|
||||
pnpm exec wrangler secret put TRACKING_KEY_V1
|
||||
pnpm exec wrangler secret put TRACKING_CURRENT_KEY_VERSION
|
||||
pnpm exec wrangler secret put ADMIN_KEY
|
||||
```
|
||||
|
||||
Create and migrate D1:
|
||||
Create D1:
|
||||
|
||||
```sh
|
||||
pnpm exec wrangler d1 create gog-email-tracker
|
||||
pnpm exec wrangler d1 execute <db> --file schema.sql
|
||||
```
|
||||
|
||||
Update `wrangler.toml` to reference the D1 `database_id`, then deploy:
|
||||
Update `wrangler.toml` to reference the D1 `database_id`, then migrate and deploy:
|
||||
|
||||
```sh
|
||||
pnpm exec wrangler d1 execute gog-email-tracker --file schema.sql --remote
|
||||
pnpm exec wrangler deploy
|
||||
```
|
||||
|
||||
`wrangler.toml` includes a daily cron trigger for retention cleanup. After deploy, Cloudflare calls the Worker once per day and the Worker deletes open rows older than 90 days.
|
||||
|
||||
## Rotate tracking keys
|
||||
|
||||
Rotate the pixel encryption key without invalidating old tracking ids:
|
||||
|
||||
```sh
|
||||
gog gmail track key rotate
|
||||
```
|
||||
|
||||
The command generates the next key version, deploys all active `TRACKING_KEY_V<N>` secrets plus `TRACKING_CURRENT_KEY_VERSION`, then stores the new current version in local config. Legacy unversioned tracking ids still decrypt through the stored `TRACKING_KEY` fallback.
|
||||
|
||||
For local-only testing:
|
||||
|
||||
```sh
|
||||
gog gmail track key rotate --no-deploy
|
||||
```
|
||||
|
||||
Do not send newly tracked mail after `--no-deploy` until the Worker has the matching versioned secret, or new pixels will not decrypt.
|
||||
|
||||
## Send tracked mail
|
||||
|
||||
Tracked email constraints:
|
||||
@ -133,6 +154,7 @@ gog gmail track status
|
||||
|
||||
- `required: --worker-url`: run `gog gmail track setup --worker-url …` first (or pass `--worker-url` again).
|
||||
- `401`/`403` on `/opens`: admin key mismatch; redeploy secrets and re-run `track setup` if needed.
|
||||
- New tracked messages do not show opens after key rotation: verify the Worker has `TRACKING_KEY_V<N>` for the current local `gmail track status` version and `TRACKING_CURRENT_KEY_VERSION` matches it.
|
||||
- No opens recorded:
|
||||
- ensure the HTML body contains the injected pixel (view “original” in your mail client).
|
||||
- some clients block images by default; “open” only happens after images load.
|
||||
|
||||
@ -5,4 +5,5 @@ type GmailTrackCmd struct {
|
||||
Setup GmailTrackSetupCmd `cmd:"" help:"Set up email tracking (deploy Cloudflare Worker)"`
|
||||
Opens GmailTrackOpensCmd `cmd:"" help:"Query email opens"`
|
||||
Status GmailTrackStatusCmd `cmd:"" help:"Show tracking configuration status"`
|
||||
Key GmailTrackKeyCmd `cmd:"" help:"Manage tracking encryption keys"`
|
||||
}
|
||||
|
||||
@ -36,6 +36,9 @@ func TestGmailTrackSetupAndStatus(t *testing.T) {
|
||||
if !strings.Contains(out, "configured\ttrue") {
|
||||
t.Fatalf("unexpected setup output: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "tracking_key_version\t1") {
|
||||
t.Fatalf("missing setup key version: %q", out)
|
||||
}
|
||||
|
||||
statusOut := captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
@ -47,6 +50,68 @@ func TestGmailTrackSetupAndStatus(t *testing.T) {
|
||||
if !strings.Contains(statusOut, "configured\ttrue") {
|
||||
t.Fatalf("unexpected status output: %q", statusOut)
|
||||
}
|
||||
if !strings.Contains(statusOut, "tracking_key_version\t1") {
|
||||
t.Fatalf("missing status key version: %q", statusOut)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailTrackKeyRotateNoDeploy(t *testing.T) {
|
||||
setupTrackingEnv(t)
|
||||
|
||||
_ = captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
if err := Execute([]string{"--account", "a@b.com", "--no-input", "gmail", "track", "setup", "--worker-url", "https://example.com"}); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
rotateOut := captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
if err := Execute([]string{"--account", "a@b.com", "--no-input", "gmail", "track", "key", "rotate", "--no-deploy"}); err != nil {
|
||||
t.Fatalf("rotate: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
if !strings.Contains(rotateOut, "tracking_key_version\t2") {
|
||||
t.Fatalf("unexpected rotate output: %q", rotateOut)
|
||||
}
|
||||
if !strings.Contains(rotateOut, "tracking_key_versions\t1,2") {
|
||||
t.Fatalf("unexpected rotate versions: %q", rotateOut)
|
||||
}
|
||||
if !strings.Contains(rotateOut, "deployed\tfalse") {
|
||||
t.Fatalf("unexpected rotate deployed output: %q", rotateOut)
|
||||
}
|
||||
|
||||
statusOut := captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
if err := Execute([]string{"--account", "a@b.com", "gmail", "track", "status"}); err != nil {
|
||||
t.Fatalf("status: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
if !strings.Contains(statusOut, "tracking_key_version\t2") {
|
||||
t.Fatalf("missing rotated status key version: %q", statusOut)
|
||||
}
|
||||
|
||||
_ = captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
if err := Execute([]string{"--account", "a@b.com", "--no-input", "gmail", "track", "setup", "--worker-url", "https://example.com"}); err != nil {
|
||||
t.Fatalf("rerun setup: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
statusOut = captureStdout(t, func() {
|
||||
_ = captureStderr(t, func() {
|
||||
if err := Execute([]string{"--account", "a@b.com", "gmail", "track", "status"}); err != nil {
|
||||
t.Fatalf("status after setup rerun: %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
if !strings.Contains(statusOut, "tracking_key_version\t2") {
|
||||
t.Fatalf("setup rerun lost rotated key version: %q", statusOut)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailTrackStatus_NotConfigured(t *testing.T) {
|
||||
|
||||
146
internal/cmd/gmail_track_key.go
Normal file
146
internal/cmd/gmail_track_key.go
Normal file
@ -0,0 +1,146 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/steipete/gogcli/internal/tracking"
|
||||
"github.com/steipete/gogcli/internal/ui"
|
||||
)
|
||||
|
||||
type GmailTrackKeyCmd struct {
|
||||
Rotate GmailTrackKeyRotateCmd `cmd:"" help:"Rotate tracking encryption key"`
|
||||
}
|
||||
|
||||
type GmailTrackKeyRotateCmd struct {
|
||||
NoDeploy bool `name:"no-deploy" help:"Update local tracking keys without deploying the Worker"`
|
||||
WorkerDir string `name:"worker-dir" help:"Worker directory (default: internal/tracking/worker)"`
|
||||
}
|
||||
|
||||
func (c *GmailTrackKeyRotateCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
u := ui.FromContext(ctx)
|
||||
account, cfg, err := loadTrackingConfigForAccount(flags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("tracking not configured; run 'gog gmail track setup' first")
|
||||
}
|
||||
if strings.TrimSpace(cfg.AdminKey) == "" {
|
||||
return fmt.Errorf("tracking admin key not configured; run 'gog gmail track setup' again")
|
||||
}
|
||||
|
||||
currentVersion := cfg.TrackingCurrentKeyVersion
|
||||
if currentVersion <= 0 {
|
||||
currentVersion = 1
|
||||
}
|
||||
knownVersions := tracking.NormalizeTrackingKeyVersions(cfg.TrackingKeyVersions, currentVersion)
|
||||
keys, currentVersion, err := tracking.LoadTrackingKeys(account, knownVersions, currentVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load tracking keys: %w", err)
|
||||
}
|
||||
if len(keys) == 0 && strings.TrimSpace(cfg.TrackingKey) != "" {
|
||||
keys[1] = cfg.TrackingKey
|
||||
currentVersion = 1
|
||||
}
|
||||
|
||||
nextVersion := nextTrackingKeyVersion(keys, currentVersion)
|
||||
nextKey, err := tracking.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate tracking key: %w", err)
|
||||
}
|
||||
keys[nextVersion] = nextKey
|
||||
|
||||
versions := tracking.NormalizeTrackingKeyVersions(mapKeys(keys), nextVersion)
|
||||
request := map[string]any{
|
||||
"account": account,
|
||||
"worker_name": cfg.WorkerName,
|
||||
"database_name": cfg.DatabaseName,
|
||||
"tracking_current_key_version": nextVersion,
|
||||
"tracking_key_versions": versions,
|
||||
"deploy": !c.NoDeploy,
|
||||
}
|
||||
if err := dryRunExit(ctx, flags, "gmail.track.key.rotate", request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !c.NoDeploy {
|
||||
workerName := strings.TrimSpace(cfg.WorkerName)
|
||||
if workerName == "" {
|
||||
return fmt.Errorf("tracking worker name not configured; run 'gog gmail track setup' again")
|
||||
}
|
||||
dbName := strings.TrimSpace(cfg.DatabaseName)
|
||||
if dbName == "" {
|
||||
dbName = workerName
|
||||
}
|
||||
workerDir := c.WorkerDir
|
||||
if workerDir == "" {
|
||||
workerDir = filepath.Join("internal", "tracking", "worker")
|
||||
}
|
||||
|
||||
dbID, deployErr := tracking.DeployWorker(ctx, u.Err(), tracking.DeployOptions{
|
||||
WorkerDir: workerDir,
|
||||
WorkerName: workerName,
|
||||
DatabaseName: dbName,
|
||||
TrackingKeys: keys,
|
||||
TrackingCurrentVersion: nextVersion,
|
||||
AdminKey: cfg.AdminKey,
|
||||
})
|
||||
if deployErr != nil {
|
||||
return deployErr
|
||||
}
|
||||
cfg.DatabaseID = dbID
|
||||
cfg.DatabaseName = dbName
|
||||
cfg.WorkerName = workerName
|
||||
}
|
||||
|
||||
if err := tracking.SaveTrackingKeys(account, keys, nextVersion, cfg.AdminKey); err != nil {
|
||||
return fmt.Errorf("save tracking keys: %w", err)
|
||||
}
|
||||
|
||||
cfg.SecretsInKeyring = true
|
||||
cfg.TrackingKey = ""
|
||||
cfg.TrackingCurrentKeyVersion = nextVersion
|
||||
cfg.TrackingKeyVersions = versions
|
||||
if err := tracking.SaveConfig(account, cfg); err != nil {
|
||||
return fmt.Errorf("save tracking config: %w", err)
|
||||
}
|
||||
|
||||
u.Out().Printf("tracking_key_rotated\ttrue")
|
||||
u.Out().Printf("tracking_key_version\t%d", nextVersion)
|
||||
u.Out().Printf("tracking_key_versions\t%s", formatTrackingKeyVersions(versions))
|
||||
u.Out().Printf("deployed\t%t", !c.NoDeploy)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func nextTrackingKeyVersion(keys map[int]string, currentVersion int) int {
|
||||
next := currentVersion
|
||||
for version := range keys {
|
||||
if version > next {
|
||||
next = version
|
||||
}
|
||||
}
|
||||
|
||||
return next + 1
|
||||
}
|
||||
|
||||
func mapKeys(values map[int]string) []int {
|
||||
keys := make([]int, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func formatTrackingKeyVersions(versions []int) string {
|
||||
parts := make([]string, 0, len(versions))
|
||||
for _, version := range versions {
|
||||
parts = append(parts, fmt.Sprintf("%d", version))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
@ -83,7 +83,29 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
return usage("required: --worker-url")
|
||||
}
|
||||
|
||||
explicitTrackingKey := strings.TrimSpace(c.TrackingKey) != ""
|
||||
key := strings.TrimSpace(c.TrackingKey)
|
||||
currentVersion := cfg.TrackingCurrentKeyVersion
|
||||
if currentVersion <= 0 {
|
||||
currentVersion = 1
|
||||
}
|
||||
|
||||
trackingKeys := map[int]string{}
|
||||
if !explicitTrackingKey {
|
||||
versions := tracking.NormalizeTrackingKeyVersions(cfg.TrackingKeyVersions, currentVersion)
|
||||
if len(versions) > 0 {
|
||||
loadedKeys, loadedCurrentVersion, loadErr := tracking.LoadTrackingKeys(account, versions, currentVersion)
|
||||
if loadErr != nil {
|
||||
return fmt.Errorf("load tracking keys: %w", loadErr)
|
||||
}
|
||||
|
||||
if len(loadedKeys) > 0 {
|
||||
trackingKeys = loadedKeys
|
||||
currentVersion = loadedCurrentVersion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(cfg.TrackingKey)
|
||||
}
|
||||
@ -93,6 +115,15 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
return fmt.Errorf("generate tracking key: %w", err)
|
||||
}
|
||||
}
|
||||
if len(trackingKeys) == 0 || explicitTrackingKey {
|
||||
currentVersion = 1
|
||||
trackingKeys = map[int]string{currentVersion: key}
|
||||
}
|
||||
if strings.TrimSpace(trackingKeys[currentVersion]) == "" {
|
||||
trackingKeys[currentVersion] = key
|
||||
}
|
||||
key = trackingKeys[currentVersion]
|
||||
versions := tracking.NormalizeTrackingKeyVersions(mapKeys(trackingKeys), currentVersion)
|
||||
|
||||
adminKey := strings.TrimSpace(c.AdminKey)
|
||||
if adminKey == "" {
|
||||
@ -111,19 +142,21 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
|
||||
// Avoid touching keyring and avoid provisioning/deploying in dry-run mode.
|
||||
if err := dryRunExit(ctx, flags, "gmail.track.setup", map[string]any{
|
||||
"account": account,
|
||||
"worker_url": c.WorkerURL,
|
||||
"worker_name": workerName,
|
||||
"database_name": c.DatabaseName,
|
||||
"deploy": c.Deploy,
|
||||
"worker_dir": c.WorkerDir,
|
||||
"tracking_key_set": strings.TrimSpace(key) != "",
|
||||
"admin_key_set": strings.TrimSpace(adminKey) != "",
|
||||
"account": account,
|
||||
"worker_url": c.WorkerURL,
|
||||
"worker_name": workerName,
|
||||
"database_name": c.DatabaseName,
|
||||
"deploy": c.Deploy,
|
||||
"worker_dir": c.WorkerDir,
|
||||
"tracking_key_set": strings.TrimSpace(key) != "",
|
||||
"tracking_key_version": currentVersion,
|
||||
"tracking_key_versions": versions,
|
||||
"admin_key_set": strings.TrimSpace(adminKey) != "",
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tracking.SaveSecrets(account, key, adminKey); err != nil {
|
||||
if err := tracking.SaveTrackingKeys(account, trackingKeys, currentVersion, adminKey); err != nil {
|
||||
return fmt.Errorf("save tracking secrets: %w", err)
|
||||
}
|
||||
|
||||
@ -133,15 +166,19 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
cfg.DatabaseName = c.DatabaseName
|
||||
cfg.SecretsInKeyring = true
|
||||
cfg.TrackingKey = ""
|
||||
cfg.TrackingKeyVersions = versions
|
||||
cfg.TrackingCurrentKeyVersion = currentVersion
|
||||
cfg.AdminKey = ""
|
||||
|
||||
if c.Deploy {
|
||||
dbID, deployErr := tracking.DeployWorker(ctx, u.Err(), tracking.DeployOptions{
|
||||
WorkerDir: c.WorkerDir,
|
||||
WorkerName: workerName,
|
||||
DatabaseName: c.DatabaseName,
|
||||
TrackingKey: key,
|
||||
AdminKey: adminKey,
|
||||
WorkerDir: c.WorkerDir,
|
||||
WorkerName: workerName,
|
||||
DatabaseName: c.DatabaseName,
|
||||
TrackingKey: key,
|
||||
TrackingKeys: trackingKeys,
|
||||
TrackingCurrentVersion: currentVersion,
|
||||
AdminKey: adminKey,
|
||||
})
|
||||
if deployErr != nil {
|
||||
return deployErr
|
||||
@ -162,6 +199,7 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
u.Out().Printf("worker_url\t%s", cfg.WorkerURL)
|
||||
u.Out().Printf("worker_name\t%s", cfg.WorkerName)
|
||||
u.Out().Printf("database_name\t%s", cfg.DatabaseName)
|
||||
u.Out().Printf("tracking_key_version\t%d", cfg.TrackingCurrentKeyVersion)
|
||||
if cfg.DatabaseID != "" {
|
||||
u.Out().Printf("database_id\t%s", cfg.DatabaseID)
|
||||
}
|
||||
@ -172,11 +210,19 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
u.Err().Printf(" - cd %s", c.WorkerDir)
|
||||
u.Err().Println(" - use these values when prompted:")
|
||||
u.Err().Printf(" TRACKING_KEY=%s", key)
|
||||
for _, version := range versions {
|
||||
u.Err().Printf(" TRACKING_KEY_V%d=%s", version, trackingKeys[version])
|
||||
}
|
||||
u.Err().Printf(" TRACKING_CURRENT_KEY_VERSION=%d", currentVersion)
|
||||
u.Err().Printf(" ADMIN_KEY=%s", adminKey)
|
||||
u.Err().Printf(" - wrangler d1 create %s", c.DatabaseName)
|
||||
u.Err().Println(" - wrangler d1 execute <db> --file schema.sql --remote")
|
||||
u.Err().Printf(" - set wrangler.toml name=%s + database_id", cfg.WorkerName)
|
||||
u.Err().Println(" - wrangler d1 execute <db> --file schema.sql --remote")
|
||||
u.Err().Println(" - wrangler secret put TRACKING_KEY")
|
||||
for _, version := range versions {
|
||||
u.Err().Printf(" - wrangler secret put TRACKING_KEY_V%d", version)
|
||||
}
|
||||
u.Err().Println(" - wrangler secret put TRACKING_CURRENT_KEY_VERSION")
|
||||
u.Err().Println(" - wrangler secret put ADMIN_KEY")
|
||||
u.Err().Println(" - wrangler deploy")
|
||||
}
|
||||
|
||||
@ -39,6 +39,9 @@ func (c *GmailTrackStatusCmd) Run(ctx context.Context, flags *RootFlags) error {
|
||||
if strings.TrimSpace(cfg.DatabaseID) != "" {
|
||||
u.Out().Printf("database_id\t%s", cfg.DatabaseID)
|
||||
}
|
||||
if cfg.TrackingCurrentKeyVersion > 0 {
|
||||
u.Out().Printf("tracking_key_version\t%d", cfg.TrackingCurrentKeyVersion)
|
||||
}
|
||||
u.Out().Printf("admin_configured\t%t", strings.TrimSpace(cfg.AdminKey) != "")
|
||||
|
||||
return nil
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -18,14 +19,16 @@ const trackingConfigVersion = 1
|
||||
|
||||
// Config holds tracking configuration for a single account.
|
||||
type Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
WorkerURL string `json:"worker_url"`
|
||||
WorkerName string `json:"worker_name,omitempty"`
|
||||
DatabaseName string `json:"database_name,omitempty"`
|
||||
DatabaseID string `json:"database_id,omitempty"`
|
||||
SecretsInKeyring bool `json:"secrets_in_keyring,omitempty"`
|
||||
TrackingKey string `json:"tracking_key,omitempty"`
|
||||
AdminKey string `json:"admin_key,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
WorkerURL string `json:"worker_url"`
|
||||
WorkerName string `json:"worker_name,omitempty"`
|
||||
DatabaseName string `json:"database_name,omitempty"`
|
||||
DatabaseID string `json:"database_id,omitempty"`
|
||||
SecretsInKeyring bool `json:"secrets_in_keyring,omitempty"`
|
||||
TrackingKey string `json:"tracking_key,omitempty"`
|
||||
TrackingKeyVersions []int `json:"tracking_key_versions,omitempty"`
|
||||
TrackingCurrentKeyVersion int `json:"tracking_current_key_version,omitempty"`
|
||||
AdminKey string `json:"admin_key,omitempty"`
|
||||
}
|
||||
|
||||
type fileConfig struct {
|
||||
@ -194,11 +197,43 @@ func hydrateConfig(account string, cfg *Config) (*Config, error) {
|
||||
if strings.TrimSpace(adminKey) != "" {
|
||||
cfg.AdminKey = adminKey
|
||||
}
|
||||
|
||||
if cfg.TrackingCurrentKeyVersion > 0 || len(cfg.TrackingKeyVersions) > 0 {
|
||||
versions := NormalizeTrackingKeyVersions(cfg.TrackingKeyVersions, cfg.TrackingCurrentKeyVersion)
|
||||
|
||||
keys, currentVersion, keyErr := LoadTrackingKeys(account, versions, cfg.TrackingCurrentKeyVersion)
|
||||
if keyErr != nil {
|
||||
return nil, keyErr
|
||||
}
|
||||
|
||||
if strings.TrimSpace(keys[currentVersion]) != "" {
|
||||
cfg.TrackingKey = keys[currentVersion]
|
||||
cfg.TrackingCurrentKeyVersion = currentVersion
|
||||
cfg.TrackingKeyVersions = NormalizeTrackingKeyVersions(versions, currentVersion)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func NormalizeTrackingKeyVersions(versions []int, currentVersion int) []int {
|
||||
normalized := make([]int, 0, len(versions)+1)
|
||||
for _, version := range versions {
|
||||
if version > 0 && version <= 255 {
|
||||
normalized = append(normalized, version)
|
||||
}
|
||||
}
|
||||
|
||||
if currentVersion > 0 && currentVersion <= 255 {
|
||||
normalized = append(normalized, currentVersion)
|
||||
}
|
||||
|
||||
slices.Sort(normalized)
|
||||
|
||||
return slices.Compact(normalized)
|
||||
}
|
||||
|
||||
func shouldLoadTrackingSecrets(cfg *Config) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
|
||||
@ -8,9 +8,15 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
var errCiphertextTooShort = errors.New("ciphertext too short")
|
||||
var (
|
||||
errCiphertextTooShort = errors.New("ciphertext too short")
|
||||
errInvalidTrackingKeyVersion = errors.New("invalid tracking key version")
|
||||
errMissingCurrentTrackingKeyValue = errors.New("missing current tracking key version")
|
||||
errNoTrackingKeys = errors.New("no tracking keys configured")
|
||||
)
|
||||
|
||||
// PixelPayload is encrypted into the tracking pixel URL
|
||||
// to be decrypted by the worker.
|
||||
@ -20,8 +26,23 @@ type PixelPayload struct {
|
||||
SentAt int64 `json:"t"`
|
||||
}
|
||||
|
||||
// Encrypt encrypts a PixelPayload into a URL-safe base64 blob using AES-GCM
|
||||
// Encrypt encrypts a PixelPayload into a legacy URL-safe base64 blob using AES-GCM.
|
||||
func Encrypt(payload *PixelPayload, keyBase64 string) (string, error) {
|
||||
return encryptPayload(payload, keyBase64, 0)
|
||||
}
|
||||
|
||||
// EncryptWithVersion encrypts a PixelPayload and prefixes the ciphertext with
|
||||
// a one-byte key version so future rotations can select the right key.
|
||||
func EncryptWithVersion(payload *PixelPayload, keyBase64 string, version int) (string, error) {
|
||||
versionByte, err := trackingKeyVersionByte(version)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return encryptPayload(payload, keyBase64, versionByte)
|
||||
}
|
||||
|
||||
func encryptPayload(payload *PixelPayload, keyBase64 string, version byte) (string, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(keyBase64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode key: %w", err)
|
||||
@ -47,24 +68,92 @@ func Encrypt(payload *PixelPayload, keyBase64 string) (string, error) {
|
||||
return "", fmt.Errorf("nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := aead.Seal(nonce, nonce, plaintext, nil)
|
||||
prefix := nonce
|
||||
if version > 0 {
|
||||
prefix = make([]byte, 0, 1+len(nonce)+len(plaintext)+aead.Overhead())
|
||||
prefix = append(prefix, version)
|
||||
prefix = append(prefix, nonce...)
|
||||
}
|
||||
|
||||
ciphertext := aead.Seal(prefix, nonce, plaintext, nil)
|
||||
|
||||
// URL-safe base64 encode
|
||||
return base64.RawURLEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a URL-safe base64 blob using AES-GCM
|
||||
// Decrypt decrypts a URL-safe base64 blob using AES-GCM.
|
||||
func Decrypt(blob string, keyBase64 string) (*PixelPayload, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(keyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode key: %w", err)
|
||||
}
|
||||
return DecryptWithKeys(blob, map[int]string{1: keyBase64})
|
||||
}
|
||||
|
||||
// DecryptWithKeys decrypts versioned and legacy tracking blobs with active keys.
|
||||
func DecryptWithKeys(blob string, keysByVersion map[int]string) (*PixelPayload, error) {
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(blob)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode blob: %w", err)
|
||||
}
|
||||
|
||||
versions := trackingKeyVersions(keysByVersion)
|
||||
if len(versions) == 0 {
|
||||
return nil, errNoTrackingKeys
|
||||
}
|
||||
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, errCiphertextTooShort
|
||||
}
|
||||
|
||||
versionedOrder := prioritizeVersion(versions, int(ciphertext[0]))
|
||||
|
||||
versionedPayload, versionedErr := tryDecryptVersions(ciphertext, keysByVersion, versionedOrder, 1)
|
||||
if versionedErr == nil {
|
||||
return versionedPayload, nil
|
||||
}
|
||||
|
||||
payload, err := tryDecryptVersions(ciphertext, keysByVersion, versions, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func tryDecryptVersions(ciphertext []byte, keysByVersion map[int]string, versions []int, nonceOffset int) (*PixelPayload, error) {
|
||||
var lastErr error
|
||||
|
||||
for _, version := range versions {
|
||||
key := keysByVersion[version]
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
plaintext, err := decryptRaw(ciphertext, key, nonceOffset)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
var payload PixelPayload
|
||||
if err := json.Unmarshal(plaintext, &payload); err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = errNoTrackingKeys
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func decryptRaw(ciphertext []byte, keyBase64 string, nonceOffset int) ([]byte, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(keyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new cipher: %w", err)
|
||||
@ -75,24 +164,55 @@ func Decrypt(blob string, keyBase64 string) (*PixelPayload, error) {
|
||||
return nil, fmt.Errorf("new gcm: %w", err)
|
||||
}
|
||||
|
||||
if len(ciphertext) < aead.NonceSize() {
|
||||
if len(ciphertext) < nonceOffset+aead.NonceSize() {
|
||||
return nil, errCiphertextTooShort
|
||||
}
|
||||
|
||||
nonce := ciphertext[:aead.NonceSize()]
|
||||
ciphertext = ciphertext[aead.NonceSize():]
|
||||
nonce := ciphertext[nonceOffset : nonceOffset+aead.NonceSize()]
|
||||
ciphertext = ciphertext[nonceOffset+aead.NonceSize():]
|
||||
|
||||
plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err)
|
||||
return nil, fmt.Errorf("open payload: %w", err)
|
||||
}
|
||||
|
||||
var payload PixelPayload
|
||||
if err := json.Unmarshal(plaintext, &payload); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal payload: %w", err)
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func trackingKeyVersions(keysByVersion map[int]string) []int {
|
||||
versions := make([]int, 0, len(keysByVersion))
|
||||
for version, key := range keysByVersion {
|
||||
if version > 0 && version <= 255 && key != "" {
|
||||
versions = append(versions, version)
|
||||
}
|
||||
}
|
||||
|
||||
return &payload, nil
|
||||
sort.Ints(versions)
|
||||
|
||||
return versions
|
||||
}
|
||||
|
||||
func trackingKeyVersionByte(version int) (byte, error) {
|
||||
if version < 1 || version > 255 {
|
||||
return 0, fmt.Errorf("%w: %d", errInvalidTrackingKeyVersion, version)
|
||||
}
|
||||
|
||||
return byte(version), nil // #nosec G115 -- version is constrained above.
|
||||
}
|
||||
|
||||
func prioritizeVersion(versions []int, preferred int) []int {
|
||||
if preferred < 1 || preferred > 255 {
|
||||
return versions
|
||||
}
|
||||
|
||||
prioritized := append([]int{}, versions...)
|
||||
for i, version := range prioritized {
|
||||
if version == preferred {
|
||||
return append([]int{version}, append(prioritized[:i], prioritized[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
|
||||
return prioritized
|
||||
}
|
||||
|
||||
// GenerateKey generates a new 256-bit AES key as base64
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package tracking
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -78,3 +79,66 @@ func TestDecryptWithWrongKeyFails(t *testing.T) {
|
||||
t.Error("Expected error when decrypting with wrong key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptWithVersionDecryptsWithActiveKeys(t *testing.T) {
|
||||
oldKey, _ := GenerateKey()
|
||||
newKey, _ := GenerateKey()
|
||||
payload := &PixelPayload{
|
||||
Recipient: "test@example.com",
|
||||
SubjectHash: "abc123",
|
||||
SentAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
encrypted, err := EncryptWithVersion(payload, newKey, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptWithVersion failed: %v", err)
|
||||
}
|
||||
|
||||
raw, err := base64.RawURLEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("decode encrypted blob: %v", err)
|
||||
}
|
||||
|
||||
if got := int(raw[0]); got != 2 {
|
||||
t.Fatalf("version prefix = %d, want 2", got)
|
||||
}
|
||||
|
||||
decrypted, err := DecryptWithKeys(encrypted, map[int]string{
|
||||
1: oldKey,
|
||||
2: newKey,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptWithKeys failed: %v", err)
|
||||
}
|
||||
|
||||
if *decrypted != *payload {
|
||||
t.Fatalf("decrypted payload = %#v, want %#v", decrypted, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithKeysAcceptsLegacyBlobs(t *testing.T) {
|
||||
oldKey, _ := GenerateKey()
|
||||
newKey, _ := GenerateKey()
|
||||
payload := &PixelPayload{
|
||||
Recipient: "test@example.com",
|
||||
SubjectHash: "abc123",
|
||||
SentAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
encrypted, err := Encrypt(payload, oldKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt failed: %v", err)
|
||||
}
|
||||
|
||||
decrypted, err := DecryptWithKeys(encrypted, map[int]string{
|
||||
1: oldKey,
|
||||
2: newKey,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptWithKeys legacy failed: %v", err)
|
||||
}
|
||||
|
||||
if *decrypted != *payload {
|
||||
t.Fatalf("decrypted payload = %#v, want %#v", decrypted, payload)
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,11 +18,13 @@ type DeployLogger interface {
|
||||
}
|
||||
|
||||
type DeployOptions struct {
|
||||
WorkerDir string
|
||||
WorkerName string
|
||||
DatabaseName string
|
||||
TrackingKey string
|
||||
AdminKey string
|
||||
WorkerDir string
|
||||
WorkerName string
|
||||
DatabaseName string
|
||||
TrackingKey string
|
||||
TrackingKeys map[int]string
|
||||
TrackingCurrentVersion int
|
||||
AdminKey string
|
||||
}
|
||||
|
||||
var (
|
||||
@ -96,7 +98,22 @@ func DeployWorker(ctx context.Context, logger DeployLogger, opts DeployOptions)
|
||||
return "", runErr
|
||||
}
|
||||
|
||||
if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(opts.TrackingKey+"\n"), "secret", "put", "TRACKING_KEY", "--name", opts.WorkerName); runErr != nil {
|
||||
trackingKeys, currentVersion, err := normalizeDeployTrackingKeys(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(trackingKeys[currentVersion]+"\n"), "secret", "put", "TRACKING_KEY", "--name", opts.WorkerName); runErr != nil {
|
||||
return "", runErr
|
||||
}
|
||||
|
||||
for _, version := range trackingKeyVersions(trackingKeys) {
|
||||
if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(trackingKeys[version]+"\n"), "secret", "put", fmt.Sprintf("TRACKING_KEY_V%d", version), "--name", opts.WorkerName); runErr != nil {
|
||||
return "", runErr
|
||||
}
|
||||
}
|
||||
|
||||
if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(fmt.Sprintf("%d\n", currentVersion)), "secret", "put", "TRACKING_CURRENT_KEY_VERSION", "--name", opts.WorkerName); runErr != nil {
|
||||
return "", runErr
|
||||
}
|
||||
|
||||
@ -115,6 +132,41 @@ func DeployWorker(ctx context.Context, logger DeployLogger, opts DeployOptions)
|
||||
return dbID, nil
|
||||
}
|
||||
|
||||
func normalizeDeployTrackingKeys(opts DeployOptions) (map[int]string, int, error) {
|
||||
trackingKeys := map[int]string{}
|
||||
|
||||
for version, key := range opts.TrackingKeys {
|
||||
if version < 1 || version > 255 {
|
||||
return nil, 0, fmt.Errorf("%w: %d", errInvalidTrackingKeyVersion, version)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(key) == "" {
|
||||
return nil, 0, errMissingTrackingKey
|
||||
}
|
||||
|
||||
trackingKeys[version] = key
|
||||
}
|
||||
|
||||
currentVersion := opts.TrackingCurrentVersion
|
||||
if currentVersion <= 0 {
|
||||
currentVersion = 1
|
||||
}
|
||||
|
||||
if len(trackingKeys) == 0 && strings.TrimSpace(opts.TrackingKey) != "" {
|
||||
trackingKeys[currentVersion] = opts.TrackingKey
|
||||
}
|
||||
|
||||
if len(trackingKeys) == 0 {
|
||||
return nil, 0, errMissingTrackingKey
|
||||
}
|
||||
|
||||
if strings.TrimSpace(trackingKeys[currentVersion]) == "" {
|
||||
return nil, 0, fmt.Errorf("%w: %d", errMissingCurrentTrackingKeyValue, currentVersion)
|
||||
}
|
||||
|
||||
return trackingKeys, currentVersion, nil
|
||||
}
|
||||
|
||||
func ensureD1Database(ctx context.Context, workerDir, dbName string) (string, error) {
|
||||
out, err := runWranglerCommandOutput(ctx, workerDir, nil, "d1", "create", dbName)
|
||||
if err != nil {
|
||||
|
||||
@ -23,7 +23,7 @@ func GeneratePixelURL(cfg *Config, recipient, subject string) (string, string, e
|
||||
SentAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
blob, err := Encrypt(payload, cfg.TrackingKey)
|
||||
blob, err := encryptTrackingPayload(payload, cfg)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("encrypt payload: %w", err)
|
||||
}
|
||||
@ -33,6 +33,14 @@ func GeneratePixelURL(cfg *Config, recipient, subject string) (string, string, e
|
||||
return pixelURL, blob, nil
|
||||
}
|
||||
|
||||
func encryptTrackingPayload(payload *PixelPayload, cfg *Config) (string, error) {
|
||||
if cfg.TrackingCurrentKeyVersion > 0 {
|
||||
return EncryptWithVersion(payload, cfg.TrackingKey, cfg.TrackingCurrentKeyVersion)
|
||||
}
|
||||
|
||||
return Encrypt(payload, cfg.TrackingKey)
|
||||
}
|
||||
|
||||
// GeneratePixelHTML returns HTML img tag for the tracking pixel
|
||||
func GeneratePixelHTML(pixelURL string) string {
|
||||
return fmt.Sprintf(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package tracking
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@ -31,6 +32,30 @@ func TestGeneratePixelURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePixelURLUsesCurrentKeyVersion(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
cfg := &Config{
|
||||
Enabled: true,
|
||||
WorkerURL: "https://test.workers.dev",
|
||||
TrackingKey: key,
|
||||
TrackingCurrentKeyVersion: 2,
|
||||
}
|
||||
|
||||
_, blob, err := GeneratePixelURL(cfg, "test@example.com", "Hello World")
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePixelURL failed: %v", err)
|
||||
}
|
||||
|
||||
raw, err := base64.RawURLEncoding.DecodeString(blob)
|
||||
if err != nil {
|
||||
t.Fatalf("decode blob: %v", err)
|
||||
}
|
||||
|
||||
if got := int(raw[0]); got != 2 {
|
||||
t.Fatalf("version prefix = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePixelURLNotConfigured(t *testing.T) {
|
||||
cfg := &Config{Enabled: false}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ package tracking
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/99designs/keyring"
|
||||
@ -36,7 +37,47 @@ func SaveSecrets(account, trackingKey, adminKey string) error {
|
||||
return errMissingAdminKey
|
||||
}
|
||||
|
||||
if err := secrets.SetSecret(scopedSecretKey(account, trackingKeySecretSuffix), []byte(trackingKey)); err != nil {
|
||||
if err := SaveTrackingKeys(account, map[int]string{1: trackingKey}, 1, adminKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SaveTrackingKeys(account string, trackingKeys map[int]string, currentVersion int, adminKey string) error {
|
||||
account = normalizeAccount(account)
|
||||
if account == "" {
|
||||
return errMissingAccount
|
||||
}
|
||||
|
||||
if len(trackingKeys) == 0 {
|
||||
return errMissingTrackingKey
|
||||
}
|
||||
|
||||
if adminKey == "" {
|
||||
return errMissingAdminKey
|
||||
}
|
||||
|
||||
for version, trackingKey := range trackingKeys {
|
||||
if version < 1 || version > 255 {
|
||||
return fmt.Errorf("%w: %d", errInvalidTrackingKeyVersion, version)
|
||||
}
|
||||
|
||||
if trackingKey == "" {
|
||||
return errMissingTrackingKey
|
||||
}
|
||||
|
||||
if err := secrets.SetSecret(scopedSecretKey(account, versionedTrackingKeySecretSuffix(version)), []byte(trackingKey)); err != nil {
|
||||
return fmt.Errorf("store tracking key v%d: %w", version, err)
|
||||
}
|
||||
}
|
||||
|
||||
currentKey := trackingKeys[currentVersion]
|
||||
if currentKey == "" {
|
||||
return fmt.Errorf("%w: %d", errMissingCurrentTrackingKeyValue, currentVersion)
|
||||
}
|
||||
|
||||
if err := secrets.SetSecret(scopedSecretKey(account, trackingKeySecretSuffix), []byte(currentKey)); err != nil {
|
||||
return fmt.Errorf("store tracking key: %w", err)
|
||||
}
|
||||
|
||||
@ -47,6 +88,57 @@ func SaveSecrets(account, trackingKey, adminKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadTrackingKeys(account string, knownVersions []int, currentVersion int) (map[int]string, int, error) {
|
||||
account = normalizeAccount(account)
|
||||
if account == "" {
|
||||
return nil, 0, errMissingAccount
|
||||
}
|
||||
|
||||
versions := NormalizeTrackingKeyVersions(knownVersions, currentVersion)
|
||||
if len(versions) == 0 {
|
||||
versions = []int{1}
|
||||
}
|
||||
|
||||
keys := map[int]string{}
|
||||
|
||||
for _, version := range versions {
|
||||
key, err := readSecretWithFallback(scopedSecretKey(account, versionedTrackingKeySecretSuffix(version)), "")
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read tracking key v%d: %w", version, err)
|
||||
}
|
||||
|
||||
if key != "" {
|
||||
keys[version] = key
|
||||
}
|
||||
}
|
||||
|
||||
if keys[1] == "" {
|
||||
legacyKey, err := readSecretWithFallback(scopedSecretKey(account, trackingKeySecretSuffix), legacyTrackingKeySecretKey)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read tracking key: %w", err)
|
||||
}
|
||||
|
||||
if legacyKey != "" {
|
||||
keys[1] = legacyKey
|
||||
}
|
||||
}
|
||||
|
||||
if currentVersion <= 0 {
|
||||
currentVersion = 1
|
||||
}
|
||||
|
||||
if keys[currentVersion] == "" {
|
||||
currentVersion = 0
|
||||
for version := range keys {
|
||||
if version > currentVersion {
|
||||
currentVersion = version
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keys, currentVersion, nil
|
||||
}
|
||||
|
||||
func LoadSecrets(account string) (trackingKey, adminKey string, err error) {
|
||||
account = normalizeAccount(account)
|
||||
if account == "" {
|
||||
@ -76,6 +168,10 @@ func readSecretWithFallback(primary, legacy string) (string, error) {
|
||||
return "", fmt.Errorf("read secret: %w", err)
|
||||
}
|
||||
|
||||
if legacy == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
legacyVal, legacyErr := secrets.GetSecret(legacy)
|
||||
if legacyErr == nil {
|
||||
return string(legacyVal), nil
|
||||
@ -92,3 +188,7 @@ func scopedSecretKey(account, suffix string) string {
|
||||
account = strings.ReplaceAll(account, " ", "")
|
||||
return fmt.Sprintf("tracking/%s/%s", account, suffix)
|
||||
}
|
||||
|
||||
func versionedTrackingKeySecretSuffix(version int) string {
|
||||
return trackingKeySecretSuffix + "_v" + strconv.Itoa(version)
|
||||
}
|
||||
|
||||
@ -33,6 +33,40 @@ func TestSaveAndLoadSecrets(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndLoadTrackingKeys(t *testing.T) {
|
||||
setupTrackingKeyringEnv(t)
|
||||
|
||||
keys := map[int]string{
|
||||
1: "track-v1",
|
||||
2: "track-v2",
|
||||
}
|
||||
if err := SaveTrackingKeys("a@b.com", keys, 2, "admin"); err != nil {
|
||||
t.Fatalf("SaveTrackingKeys: %v", err)
|
||||
}
|
||||
|
||||
loaded, currentVersion, err := LoadTrackingKeys("a@b.com", []int{1, 2}, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadTrackingKeys: %v", err)
|
||||
}
|
||||
|
||||
if currentVersion != 2 {
|
||||
t.Fatalf("current version = %d, want 2", currentVersion)
|
||||
}
|
||||
|
||||
if loaded[1] != "track-v1" || loaded[2] != "track-v2" {
|
||||
t.Fatalf("unexpected tracking keys: %#v", loaded)
|
||||
}
|
||||
|
||||
track, admin, err := LoadSecrets("a@b.com")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSecrets: %v", err)
|
||||
}
|
||||
|
||||
if track != "track-v2" || admin != "admin" {
|
||||
t.Fatalf("unexpected current secrets: %q %q", track, admin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSecrets_LegacyFallback(t *testing.T) {
|
||||
setupTrackingKeyringEnv(t)
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { importKey, encrypt, decrypt } from './crypto';
|
||||
import { importKey, encrypt, decrypt, encryptWithVersion, decryptWithKeys } from './crypto';
|
||||
|
||||
describe('crypto', () => {
|
||||
const testKey = 'MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDE='; // 32 bytes base64
|
||||
const rotatedKey = 'MTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTI='; // 32 bytes base64
|
||||
|
||||
it('encrypts and decrypts payload', async () => {
|
||||
const key = await importKey(testKey);
|
||||
@ -28,4 +29,34 @@ describe('crypto', () => {
|
||||
|
||||
await expect(decrypt('invalid', key)).rejects.toThrow();
|
||||
});
|
||||
|
||||
it('decrypts versioned payloads with active keys', async () => {
|
||||
const key = await importKey(rotatedKey);
|
||||
const payload = { r: 'test@example.com', s: 'abc123', t: 1704067200 };
|
||||
|
||||
const encrypted = await encryptWithVersion(payload, key, 2);
|
||||
const base64 = encrypted.replace(/-/g, '+').replace(/_/g, '/');
|
||||
const padded = base64 + '='.repeat((4 - base64.length % 4) % 4);
|
||||
const raw = Uint8Array.from(atob(padded), c => c.charCodeAt(0));
|
||||
const decrypted = await decryptWithKeys(encrypted, {
|
||||
1: testKey,
|
||||
2: rotatedKey,
|
||||
});
|
||||
|
||||
expect(raw[0]).toBe(2);
|
||||
expect(decrypted).toEqual(payload);
|
||||
});
|
||||
|
||||
it('decrypts legacy payloads with rotated key sets', async () => {
|
||||
const key = await importKey(testKey);
|
||||
const payload = { r: 'test@example.com', s: 'abc123', t: 1704067200 };
|
||||
|
||||
const encrypted = await encrypt(payload, key);
|
||||
const decrypted = await decryptWithKeys(encrypted, {
|
||||
1: testKey,
|
||||
2: rotatedKey,
|
||||
});
|
||||
|
||||
expect(decrypted).toEqual(payload);
|
||||
});
|
||||
});
|
||||
|
||||
@ -3,6 +3,8 @@ import type { PixelPayload } from './types';
|
||||
const ALGORITHM = 'AES-GCM';
|
||||
const IV_LENGTH = 12;
|
||||
|
||||
export type TrackingKeys = Record<number, string>;
|
||||
|
||||
export async function importKey(base64Key: string): Promise<CryptoKey> {
|
||||
const keyBytes = Uint8Array.from(atob(base64Key), c => c.charCodeAt(0));
|
||||
return crypto.subtle.importKey(
|
||||
@ -15,25 +17,54 @@ export async function importKey(base64Key: string): Promise<CryptoKey> {
|
||||
}
|
||||
|
||||
export async function decrypt(blob: string, key: CryptoKey): Promise<PixelPayload> {
|
||||
// URL-safe base64 decode
|
||||
const base64 = blob.replace(/-/g, '+').replace(/_/g, '/');
|
||||
const padded = base64 + '='.repeat((4 - base64.length % 4) % 4);
|
||||
const combined = Uint8Array.from(atob(padded), c => c.charCodeAt(0));
|
||||
const combined = decodeBlob(blob);
|
||||
const decrypted = await decryptRaw(combined, key, 0);
|
||||
return parsePayload(decrypted);
|
||||
}
|
||||
|
||||
const iv = combined.slice(0, IV_LENGTH);
|
||||
const ciphertext = combined.slice(IV_LENGTH);
|
||||
export async function decryptWithKeys(blob: string, keys: TrackingKeys): Promise<PixelPayload> {
|
||||
const combined = decodeBlob(blob);
|
||||
const versions = Object.keys(keys)
|
||||
.map(v => Number.parseInt(v, 10))
|
||||
.filter(v => Number.isFinite(v) && v > 0 && v <= 255 && keys[v]?.trim())
|
||||
.sort((a, b) => a - b);
|
||||
if (versions.length === 0 || combined.length === 0) {
|
||||
throw new Error('missing tracking keys');
|
||||
}
|
||||
|
||||
const decrypted = await crypto.subtle.decrypt(
|
||||
{ name: ALGORITHM, iv },
|
||||
key,
|
||||
ciphertext
|
||||
);
|
||||
const versionedOrder = prioritizeVersion(versions, combined[0]);
|
||||
const versioned = await tryDecryptVersions(combined, keys, versionedOrder, 1);
|
||||
if (versioned) {
|
||||
return versioned;
|
||||
}
|
||||
|
||||
const text = new TextDecoder().decode(decrypted);
|
||||
return JSON.parse(text) as PixelPayload;
|
||||
const legacy = await tryDecryptVersions(combined, keys, versions, 0);
|
||||
if (legacy) {
|
||||
return legacy;
|
||||
}
|
||||
|
||||
throw new Error('decrypt failed');
|
||||
}
|
||||
|
||||
export async function encrypt(payload: PixelPayload, key: CryptoKey): Promise<string> {
|
||||
return encryptRaw(payload, key, 0);
|
||||
}
|
||||
|
||||
export async function encryptWithVersion(payload: PixelPayload, key: CryptoKey, version: number): Promise<string> {
|
||||
if (!Number.isInteger(version) || version < 1 || version > 255) {
|
||||
throw new Error(`invalid key version: ${version}`);
|
||||
}
|
||||
|
||||
return encryptRaw(payload, key, version);
|
||||
}
|
||||
|
||||
function decodeBlob(blob: string): Uint8Array {
|
||||
const base64 = blob.replace(/-/g, '+').replace(/_/g, '/');
|
||||
const padded = base64 + '='.repeat((4 - base64.length % 4) % 4);
|
||||
return Uint8Array.from(atob(padded), c => c.charCodeAt(0));
|
||||
}
|
||||
|
||||
async function encryptRaw(payload: PixelPayload, key: CryptoKey, version: number): Promise<string> {
|
||||
const iv = crypto.getRandomValues(new Uint8Array(IV_LENGTH));
|
||||
const encoded = new TextEncoder().encode(JSON.stringify(payload));
|
||||
|
||||
@ -43,11 +74,71 @@ export async function encrypt(payload: PixelPayload, key: CryptoKey): Promise<st
|
||||
encoded
|
||||
);
|
||||
|
||||
const combined = new Uint8Array(IV_LENGTH + ciphertext.byteLength);
|
||||
combined.set(iv);
|
||||
combined.set(new Uint8Array(ciphertext), IV_LENGTH);
|
||||
const prefixLength = version > 0 ? 1 : 0;
|
||||
const combined = new Uint8Array(prefixLength + IV_LENGTH + ciphertext.byteLength);
|
||||
if (version > 0) {
|
||||
combined[0] = version;
|
||||
}
|
||||
combined.set(iv, prefixLength);
|
||||
combined.set(new Uint8Array(ciphertext), prefixLength + IV_LENGTH);
|
||||
|
||||
// URL-safe base64 encode
|
||||
const base64 = btoa(String.fromCharCode(...combined));
|
||||
return base64.replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
|
||||
}
|
||||
|
||||
async function tryDecryptVersions(
|
||||
combined: Uint8Array,
|
||||
keys: TrackingKeys,
|
||||
versions: number[],
|
||||
nonceOffset: number
|
||||
): Promise<PixelPayload | null> {
|
||||
for (const version of versions) {
|
||||
const key = keys[version];
|
||||
if (!key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const importedKey = await importKey(key);
|
||||
const decrypted = await decryptRaw(combined, importedKey, nonceOffset);
|
||||
return parsePayload(decrypted);
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
async function decryptRaw(combined: Uint8Array, key: CryptoKey, nonceOffset: number): Promise<ArrayBuffer> {
|
||||
if (combined.length < nonceOffset + IV_LENGTH) {
|
||||
throw new Error('ciphertext too short');
|
||||
}
|
||||
|
||||
const iv = combined.slice(nonceOffset, nonceOffset + IV_LENGTH);
|
||||
const ciphertext = combined.slice(nonceOffset + IV_LENGTH);
|
||||
|
||||
return crypto.subtle.decrypt(
|
||||
{ name: ALGORITHM, iv },
|
||||
key,
|
||||
ciphertext
|
||||
);
|
||||
}
|
||||
|
||||
function parsePayload(payload: ArrayBuffer): PixelPayload {
|
||||
const text = new TextDecoder().decode(payload);
|
||||
return JSON.parse(text) as PixelPayload;
|
||||
}
|
||||
|
||||
function prioritizeVersion(versions: number[], preferred: number): number[] {
|
||||
if (!Number.isInteger(preferred) || preferred < 1 || preferred > 255) {
|
||||
return versions;
|
||||
}
|
||||
|
||||
const index = versions.indexOf(preferred);
|
||||
if (index < 0) {
|
||||
return versions;
|
||||
}
|
||||
|
||||
return [versions[index], ...versions.slice(0, index), ...versions.slice(index + 1)];
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import worker from './index';
|
||||
import { encrypt, importKey } from './crypto';
|
||||
import { encrypt, encryptWithVersion, importKey } from './crypto';
|
||||
|
||||
const testKey = 'MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDE=';
|
||||
|
||||
@ -117,6 +117,11 @@ async function encryptedBlob(): Promise<string> {
|
||||
return encrypt({ r: 'to@example.com', s: 'abcdef', t: Math.floor(Date.now() / 1000) - 10 }, key);
|
||||
}
|
||||
|
||||
async function encryptedVersionedBlob(): Promise<string> {
|
||||
const key = await importKey(testKey);
|
||||
return encryptWithVersion({ r: 'to@example.com', s: 'abcdef', t: Math.floor(Date.now() / 1000) - 10 }, key, 2);
|
||||
}
|
||||
|
||||
describe('tracking worker pixel rate limiting', () => {
|
||||
it('deduplicates repeated opens for the same tracking id, ip, and user agent', async () => {
|
||||
const db = new FakeD1();
|
||||
@ -129,6 +134,23 @@ describe('tracking worker pixel rate limiting', () => {
|
||||
expect(db.rows).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('records versioned tracking pixels with versioned Worker keys', async () => {
|
||||
const db = new FakeD1();
|
||||
const env = {
|
||||
DB: db as unknown as D1Database,
|
||||
TRACKING_KEY: 'wrong-current-key',
|
||||
TRACKING_KEY_V2: testKey,
|
||||
TRACKING_CURRENT_KEY_VERSION: '2',
|
||||
ADMIN_KEY: 'admin',
|
||||
};
|
||||
const blob = await encryptedVersionedBlob();
|
||||
|
||||
await worker.fetch(await pixelRequest(blob), env);
|
||||
|
||||
expect(db.rows).toHaveLength(1);
|
||||
expect(db.rows[0].recipient).toBe('to@example.com');
|
||||
});
|
||||
|
||||
it('silently skips inserts after the per-IP hourly cap', async () => {
|
||||
const db = new FakeD1();
|
||||
const env = { DB: db as unknown as D1Database, TRACKING_KEY: testKey, ADMIN_KEY: 'admin' };
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { Env, PixelPayload } from './types';
|
||||
import { importKey, decrypt } from './crypto';
|
||||
import { decryptWithKeys, type TrackingKeys } from './crypto';
|
||||
import { detectBot } from './bot';
|
||||
import { pixelResponse } from './pixel';
|
||||
|
||||
@ -52,11 +52,10 @@ async function handlePixel(request: Request, env: Env, path: string): Promise<Re
|
||||
// Extract blob from /p/:blob.gif
|
||||
const blob = path.slice(3, -4); // Remove '/p/' and '.gif'
|
||||
|
||||
const key = await importKey(env.TRACKING_KEY);
|
||||
let payload: PixelPayload;
|
||||
|
||||
try {
|
||||
payload = await decrypt(blob, key);
|
||||
payload = await decryptWithKeys(blob, trackingKeysFromEnv(env));
|
||||
} catch {
|
||||
// Still return pixel even if decryption fails (don't break email display)
|
||||
return pixelResponse();
|
||||
@ -160,11 +159,10 @@ async function purgeExpiredOpens(env: Env): Promise<void> {
|
||||
async function handleQuery(request: Request, env: Env, path: string): Promise<Response> {
|
||||
const blob = path.slice(3); // Remove '/q/'
|
||||
|
||||
const key = await importKey(env.TRACKING_KEY);
|
||||
let payload: PixelPayload;
|
||||
|
||||
try {
|
||||
payload = await decrypt(blob, key);
|
||||
payload = await decryptWithKeys(blob, trackingKeysFromEnv(env));
|
||||
} catch {
|
||||
return new Response('Invalid tracking ID', { status: 400 });
|
||||
}
|
||||
@ -259,3 +257,28 @@ function parseAdminLimit(raw: string | null): number {
|
||||
|
||||
return Math.min(parsed, MAX_ADMIN_LIMIT);
|
||||
}
|
||||
|
||||
function trackingKeysFromEnv(env: Env): TrackingKeys {
|
||||
const keys: TrackingKeys = {};
|
||||
for (const [name, value] of Object.entries(env)) {
|
||||
const match = /^TRACKING_KEY_V([1-9][0-9]*)$/.exec(name);
|
||||
if (!match || typeof value !== 'string' || value.trim() === '') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const version = Number.parseInt(match[1], 10);
|
||||
if (version >= 1 && version <= 255) {
|
||||
keys[version] = value;
|
||||
}
|
||||
}
|
||||
|
||||
const currentVersion = Number.parseInt(env.TRACKING_CURRENT_KEY_VERSION || '', 10);
|
||||
const legacyVersion = Number.isFinite(currentVersion) && currentVersion >= 1 && currentVersion <= 255
|
||||
? currentVersion
|
||||
: 1;
|
||||
if (env.TRACKING_KEY && !keys[legacyVersion]) {
|
||||
keys[legacyVersion] = env.TRACKING_KEY;
|
||||
}
|
||||
|
||||
return keys;
|
||||
}
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
export interface Env {
|
||||
DB: D1Database;
|
||||
TRACKING_KEY: string;
|
||||
TRACKING_CURRENT_KEY_VERSION?: string;
|
||||
ADMIN_KEY: string;
|
||||
[key: `TRACKING_KEY_V${number}`]: string | undefined;
|
||||
}
|
||||
|
||||
export interface PixelPayload {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user