feat: add command safety controls

This commit is contained in:
Peter Steinberger 2026-04-21 06:06:57 +01:00
parent 86a43def24
commit f6a94547d6
No known key found for this signature in database
18 changed files with 229 additions and 15 deletions

View File

@ -23,6 +23,9 @@ func newAuthCmd(flags *rootFlags) *cobra.Command {
Use: "auth",
Short: "Authenticate with WhatsApp (QR) and bootstrap sync",
RunE: func(cmd *cobra.Command, args []string) error {
if err := flags.requireWritable(); err != nil {
return err
}
ctx, stop := signalContext()
defer stop()
@ -146,6 +149,9 @@ func newAuthLogoutCmd(flags *rootFlags) *cobra.Command {
Use: "logout",
Short: "Logout (invalidate session)",
RunE: func(cmd *cobra.Command, args []string) error {
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -119,6 +119,9 @@ func newContactsRefreshCmd(flags *rootFlags) *cobra.Command {
Use: "refresh",
Short: "Import contacts from whatsmeow store into local DB",
RunE: func(cmd *cobra.Command, args []string) error {
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()
@ -138,6 +141,7 @@ func newContactsRefreshCmd(flags *rootFlags) *cobra.Command {
var count int
for jid, info := range cs {
jid = canonicalCLIJID(jid)
_ = a.DB().UpsertContact(
jid.String(),
jid.User,
@ -173,6 +177,9 @@ func newContactsAliasCmd(flags *rootFlags) *cobra.Command {
if strings.TrimSpace(jid) == "" || strings.TrimSpace(alias) == "" {
return fmt.Errorf("--jid and --alias are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()
a, lk, err := newApp(ctx, flags, false, false)
@ -198,6 +205,9 @@ func newContactsAliasCmd(flags *rootFlags) *cobra.Command {
if strings.TrimSpace(jid) == "" {
return fmt.Errorf("--jid is required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()
a, lk, err := newApp(ctx, flags, false, false)
@ -235,6 +245,9 @@ func newContactsTagsCmd(flags *rootFlags) *cobra.Command {
if strings.TrimSpace(jid) == "" || strings.TrimSpace(tag) == "" {
return fmt.Errorf("--jid and --tag are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()
a, lk, err := newApp(ctx, flags, false, false)
@ -261,6 +274,9 @@ func newContactsTagsCmd(flags *rootFlags) *cobra.Command {
if strings.TrimSpace(jid) == "" || strings.TrimSpace(tag) == "" {
return fmt.Errorf("--jid and --tag are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()
a, lk, err := newApp(ctx, flags, false, false)

View File

@ -82,6 +82,9 @@ func newGroupsInviteLinkRevokeCmd(flags *rootFlags) *cobra.Command {
if strings.TrimSpace(jidStr) == "" {
return fmt.Errorf("--jid is required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()
@ -125,6 +128,9 @@ func newGroupsJoinCmd(flags *rootFlags) *cobra.Command {
if strings.TrimSpace(code) == "" {
return fmt.Errorf("--code is required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -34,6 +34,9 @@ func newGroupsParticipantsActionCmd(flags *rootFlags, action string) *cobra.Comm
if strings.TrimSpace(group) == "" || len(users) == 0 {
return fmt.Errorf("--jid and at least one --user are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -5,6 +5,13 @@ import (
"go.mau.fi/whatsmeow/types"
)
func canonicalCLIJID(jid types.JID) types.JID {
if jid.Server == types.DefaultUserServer {
return jid.ToNonAD()
}
return jid
}
func persistGroupInfo(db *store.DB, info *types.GroupInfo) error {
if info == nil {
return nil
@ -22,7 +29,7 @@ func persistGroupInfo(db *store.DB, info *types.GroupInfo) error {
}
ps = append(ps, store.GroupParticipant{
GroupJID: info.JID.String(),
UserJID: p.JID.String(),
UserJID: canonicalCLIJID(p.JID).String(),
Role: role,
})
}

View File

@ -15,6 +15,9 @@ func newGroupsRefreshCmd(flags *rootFlags) *cobra.Command {
Use: "refresh",
Short: "Fetch joined groups (live) and update local DB",
RunE: func(cmd *cobra.Command, args []string) error {
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -33,6 +33,9 @@ func newHistoryBackfillCmd(flags *rootFlags) *cobra.Command {
if chat == "" {
return fmt.Errorf("--chat is required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, stop := signalContext()
defer stop()
@ -70,8 +73,8 @@ func newHistoryBackfillCmd(flags *rootFlags) *cobra.Command {
}
cmd.Flags().StringVar(&chat, "chat", "", "chat JID")
cmd.Flags().IntVar(&count, "count", 50, "number of messages to request per on-demand sync (recommended: 50)")
cmd.Flags().IntVar(&requests, "requests", 1, "number of on-demand requests to attempt")
cmd.Flags().IntVar(&count, "count", app.DefaultBackfillCount, "number of messages to request per on-demand sync")
cmd.Flags().IntVar(&requests, "requests", app.DefaultBackfillRequests, "number of on-demand requests to attempt")
cmd.Flags().DurationVar(&wait, "wait", 60*time.Second, "time to wait for an on-demand response per request")
cmd.Flags().DurationVar(&idleExit, "idle-exit", 5*time.Second, "exit after being idle (after backfill requests)")
return cmd

View File

@ -31,6 +31,9 @@ func newMediaDownloadCmd(flags *rootFlags) *cobra.Command {
if chat == "" || id == "" {
return fmt.Errorf("--chat and --id are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -58,6 +58,9 @@ func runPresence(flags *rootFlags, to string, state types.ChatPresence, media st
if strings.TrimSpace(to) == "" {
return fmt.Errorf("--to is required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/cobra"
@ -22,6 +23,8 @@ type rootFlags struct {
asJSON bool
fullOutput bool
timeout time.Duration
readOnly bool
lockWait time.Duration
}
func execute(args []string) error {
@ -39,6 +42,8 @@ func execute(args []string) error {
rootCmd.PersistentFlags().BoolVar(&flags.asJSON, "json", false, "output JSON instead of human-readable text")
rootCmd.PersistentFlags().BoolVar(&flags.fullOutput, "full", false, "disable truncation in table output")
rootCmd.PersistentFlags().DurationVar(&flags.timeout, "timeout", 5*time.Minute, "command timeout (non-sync commands)")
rootCmd.PersistentFlags().DurationVar(&flags.lockWait, "lock-wait", 0, "wait for the store lock before failing (write commands)")
rootCmd.PersistentFlags().BoolVar(&flags.readOnly, "read-only", false, "reject commands that intentionally write WhatsApp or the local store (or set WACLI_READONLY=1)")
rootCmd.AddCommand(newVersionCmd())
rootCmd.AddCommand(newDoctorCmd(&flags))
@ -71,7 +76,7 @@ func newApp(ctx context.Context, flags *rootFlags, needLock bool, allowUnauthed
var lk *lock.Lock
if needLock {
var err error
lk, err = lock.Acquire(storeDir)
lk, err = lock.AcquireWithTimeout(ctx, storeDir, flags.lockWait)
if err != nil {
return nil, nil, err
}
@ -93,6 +98,28 @@ func newApp(ctx context.Context, flags *rootFlags, needLock bool, allowUnauthed
return a, lk, nil
}
func (f *rootFlags) isReadOnly() bool {
if f == nil {
return false
}
if f.readOnly {
return true
}
switch strings.ToLower(strings.TrimSpace(os.Getenv("WACLI_READONLY"))) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
func (f *rootFlags) requireWritable() error {
if f.isReadOnly() {
return fmt.Errorf("read-only mode: command would intentionally modify WhatsApp or the local store")
}
return nil
}
func withTimeout(ctx context.Context, flags *rootFlags) (context.Context, context.CancelFunc) {
if flags.timeout <= 0 {
return context.WithCancel(ctx)

26
cmd/wacli/root_test.go Normal file
View File

@ -0,0 +1,26 @@
package main
import (
"strings"
"testing"
)
func TestRootFlagsReadOnlyFlag(t *testing.T) {
flags := &rootFlags{readOnly: true}
if !flags.isReadOnly() {
t.Fatal("isReadOnly = false, want true")
}
err := flags.requireWritable()
if err == nil || !strings.Contains(err.Error(), "read-only mode") {
t.Fatalf("requireWritable error = %v", err)
}
}
func TestRootFlagsReadOnlyEnv(t *testing.T) {
t.Setenv("WACLI_READONLY", "yes")
if !(&rootFlags{}).isReadOnly() {
t.Fatal("isReadOnly = false, want true")
}
}

View File

@ -42,6 +42,9 @@ func newSendTextCmd(flags *rootFlags) *cobra.Command {
if to == "" || message == "" {
return fmt.Errorf("--to and --message are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -24,6 +24,9 @@ func newSendFileCmd(flags *rootFlags) *cobra.Command {
if to == "" || filePath == "" {
return fmt.Errorf("--to and --file are required")
}
if err := flags.requireWritable(); err != nil {
return err
}
ctx, cancel := withTimeout(context.Background(), flags)
defer cancel()

View File

@ -23,6 +23,9 @@ func newSyncCmd(flags *rootFlags) *cobra.Command {
Use: "sync",
Short: "Sync messages (requires prior auth; never shows QR)",
RunE: func(cmd *cobra.Command, args []string) error {
if err := flags.requireWritable(); err != nil {
return err
}
ctx, stop := signalContext()
defer stop()

View File

@ -22,6 +22,13 @@ type BackfillOptions struct {
IdleExit time.Duration
}
const (
DefaultBackfillCount = 50
DefaultBackfillRequests = 1
MaxBackfillCount = 500
MaxBackfillRequests = 100
)
type BackfillResult struct {
ChatJID string
RequestsSent int
@ -47,17 +54,9 @@ func (a *App) BackfillHistory(ctx context.Context, opts BackfillOptions) (Backfi
}
chatStr = chat.String()
if opts.Count <= 0 {
opts.Count = 50
}
if opts.Requests <= 0 {
opts.Requests = 1
}
if opts.WaitPerRequest <= 0 {
opts.WaitPerRequest = 60 * time.Second
}
if opts.IdleExit <= 0 {
opts.IdleExit = 5 * time.Second
opts = normalizeBackfillOptions(opts)
if err := validateBackfillOptions(opts); err != nil {
return BackfillResult{}, err
}
if err := a.EnsureAuthed(); err != nil {
@ -190,3 +189,29 @@ func (a *App) BackfillHistory(ctx context.Context, opts BackfillOptions) (Backfi
MessagesSynced: syncRes.MessagesStored,
}, nil
}
func normalizeBackfillOptions(opts BackfillOptions) BackfillOptions {
if opts.Count <= 0 {
opts.Count = DefaultBackfillCount
}
if opts.Requests <= 0 {
opts.Requests = DefaultBackfillRequests
}
if opts.WaitPerRequest <= 0 {
opts.WaitPerRequest = 60 * time.Second
}
if opts.IdleExit <= 0 {
opts.IdleExit = 5 * time.Second
}
return opts
}
func validateBackfillOptions(opts BackfillOptions) error {
if opts.Count > MaxBackfillCount {
return fmt.Errorf("--count must be <= %d (got %d)", MaxBackfillCount, opts.Count)
}
if opts.Requests > MaxBackfillRequests {
return fmt.Errorf("--requests must be <= %d (got %d)", MaxBackfillRequests, opts.Requests)
}
return nil
}

View File

@ -2,6 +2,7 @@ package app
import (
"context"
"strings"
"testing"
"time"
@ -80,6 +81,38 @@ func TestBackfillHistoryAddsOlderMessages(t *testing.T) {
}
}
func TestNormalizeBackfillOptions(t *testing.T) {
opts := normalizeBackfillOptions(BackfillOptions{})
if opts.Count != DefaultBackfillCount {
t.Fatalf("Count = %d, want %d", opts.Count, DefaultBackfillCount)
}
if opts.Requests != DefaultBackfillRequests {
t.Fatalf("Requests = %d, want %d", opts.Requests, DefaultBackfillRequests)
}
if opts.WaitPerRequest <= 0 || opts.IdleExit <= 0 {
t.Fatalf("durations must default positive: %+v", opts)
}
}
func TestValidateBackfillOptionsCapsWork(t *testing.T) {
err := validateBackfillOptions(BackfillOptions{
Count: MaxBackfillCount + 1,
Requests: DefaultBackfillRequests,
})
if err == nil || !strings.Contains(err.Error(), "--count") {
t.Fatalf("count error = %v", err)
}
err = validateBackfillOptions(BackfillOptions{
Count: DefaultBackfillCount,
Requests: MaxBackfillRequests + 1,
})
if err == nil || !strings.Contains(err.Error(), "--requests") {
t.Fatalf("requests error = %v", err)
}
}
func storeUpsertMessage(chatJID, id string, ts time.Time, text string) store.UpsertMessageParams {
return store.UpsertMessageParams{
ChatJID: chatJID,

View File

@ -1,6 +1,7 @@
package lock
import (
"context"
"fmt"
"os"
"path/filepath"
@ -44,6 +45,32 @@ func Acquire(storeDir string) (*Lock, error) {
return &Lock{path: path, f: f}, nil
}
func AcquireWithTimeout(ctx context.Context, storeDir string, wait time.Duration) (*Lock, error) {
if wait <= 0 {
return Acquire(storeDir)
}
deadline := time.NewTimer(wait)
defer deadline.Stop()
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
var lastErr error
for {
lk, err := Acquire(storeDir)
if err == nil {
return lk, nil
}
lastErr = err
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-deadline.C:
return nil, fmt.Errorf("timed out waiting for store lock after %s: %w", wait, lastErr)
case <-ticker.C:
}
}
}
func (l *Lock) Release() error {
if l == nil || l.f == nil {
return nil

View File

@ -55,3 +55,20 @@ func TestLockBlocksOtherProcess(t *testing.T) {
t.Fatalf("expected helper to report locked; output=%q", strings.TrimSpace(got))
}
}
func TestAcquireWithTimeout(t *testing.T) {
dir := t.TempDir()
lk, err := Acquire(dir)
if err != nil {
t.Fatalf("acquire: %v", err)
}
defer lk.Release()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = AcquireWithTimeout(ctx, dir, 50*time.Millisecond)
if err == nil || !strings.Contains(err.Error(), "timed out waiting for store lock") {
t.Fatalf("AcquireWithTimeout error = %v", err)
}
}