feat: add command safety controls
This commit is contained in:
parent
86a43def24
commit
f6a94547d6
@ -23,6 +23,9 @@ func newAuthCmd(flags *rootFlags) *cobra.Command {
|
||||
Use: "auth",
|
||||
Short: "Authenticate with WhatsApp (QR) and bootstrap sync",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, stop := signalContext()
|
||||
defer stop()
|
||||
|
||||
@ -146,6 +149,9 @@ func newAuthLogoutCmd(flags *rootFlags) *cobra.Command {
|
||||
Use: "logout",
|
||||
Short: "Logout (invalidate session)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@ -119,6 +119,9 @@ func newContactsRefreshCmd(flags *rootFlags) *cobra.Command {
|
||||
Use: "refresh",
|
||||
Short: "Import contacts from whatsmeow store into local DB",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
@ -138,6 +141,7 @@ func newContactsRefreshCmd(flags *rootFlags) *cobra.Command {
|
||||
|
||||
var count int
|
||||
for jid, info := range cs {
|
||||
jid = canonicalCLIJID(jid)
|
||||
_ = a.DB().UpsertContact(
|
||||
jid.String(),
|
||||
jid.User,
|
||||
@ -173,6 +177,9 @@ func newContactsAliasCmd(flags *rootFlags) *cobra.Command {
|
||||
if strings.TrimSpace(jid) == "" || strings.TrimSpace(alias) == "" {
|
||||
return fmt.Errorf("--jid and --alias are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
a, lk, err := newApp(ctx, flags, false, false)
|
||||
@ -198,6 +205,9 @@ func newContactsAliasCmd(flags *rootFlags) *cobra.Command {
|
||||
if strings.TrimSpace(jid) == "" {
|
||||
return fmt.Errorf("--jid is required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
a, lk, err := newApp(ctx, flags, false, false)
|
||||
@ -235,6 +245,9 @@ func newContactsTagsCmd(flags *rootFlags) *cobra.Command {
|
||||
if strings.TrimSpace(jid) == "" || strings.TrimSpace(tag) == "" {
|
||||
return fmt.Errorf("--jid and --tag are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
a, lk, err := newApp(ctx, flags, false, false)
|
||||
@ -261,6 +274,9 @@ func newContactsTagsCmd(flags *rootFlags) *cobra.Command {
|
||||
if strings.TrimSpace(jid) == "" || strings.TrimSpace(tag) == "" {
|
||||
return fmt.Errorf("--jid and --tag are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
a, lk, err := newApp(ctx, flags, false, false)
|
||||
|
||||
@ -82,6 +82,9 @@ func newGroupsInviteLinkRevokeCmd(flags *rootFlags) *cobra.Command {
|
||||
if strings.TrimSpace(jidStr) == "" {
|
||||
return fmt.Errorf("--jid is required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
@ -125,6 +128,9 @@ func newGroupsJoinCmd(flags *rootFlags) *cobra.Command {
|
||||
if strings.TrimSpace(code) == "" {
|
||||
return fmt.Errorf("--code is required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@ -34,6 +34,9 @@ func newGroupsParticipantsActionCmd(flags *rootFlags, action string) *cobra.Comm
|
||||
if strings.TrimSpace(group) == "" || len(users) == 0 {
|
||||
return fmt.Errorf("--jid and at least one --user are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@ -5,6 +5,13 @@ import (
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
)
|
||||
|
||||
func canonicalCLIJID(jid types.JID) types.JID {
|
||||
if jid.Server == types.DefaultUserServer {
|
||||
return jid.ToNonAD()
|
||||
}
|
||||
return jid
|
||||
}
|
||||
|
||||
func persistGroupInfo(db *store.DB, info *types.GroupInfo) error {
|
||||
if info == nil {
|
||||
return nil
|
||||
@ -22,7 +29,7 @@ func persistGroupInfo(db *store.DB, info *types.GroupInfo) error {
|
||||
}
|
||||
ps = append(ps, store.GroupParticipant{
|
||||
GroupJID: info.JID.String(),
|
||||
UserJID: p.JID.String(),
|
||||
UserJID: canonicalCLIJID(p.JID).String(),
|
||||
Role: role,
|
||||
})
|
||||
}
|
||||
|
||||
@ -15,6 +15,9 @@ func newGroupsRefreshCmd(flags *rootFlags) *cobra.Command {
|
||||
Use: "refresh",
|
||||
Short: "Fetch joined groups (live) and update local DB",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@ -33,6 +33,9 @@ func newHistoryBackfillCmd(flags *rootFlags) *cobra.Command {
|
||||
if chat == "" {
|
||||
return fmt.Errorf("--chat is required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, stop := signalContext()
|
||||
defer stop()
|
||||
@ -70,8 +73,8 @@ func newHistoryBackfillCmd(flags *rootFlags) *cobra.Command {
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&chat, "chat", "", "chat JID")
|
||||
cmd.Flags().IntVar(&count, "count", 50, "number of messages to request per on-demand sync (recommended: 50)")
|
||||
cmd.Flags().IntVar(&requests, "requests", 1, "number of on-demand requests to attempt")
|
||||
cmd.Flags().IntVar(&count, "count", app.DefaultBackfillCount, "number of messages to request per on-demand sync")
|
||||
cmd.Flags().IntVar(&requests, "requests", app.DefaultBackfillRequests, "number of on-demand requests to attempt")
|
||||
cmd.Flags().DurationVar(&wait, "wait", 60*time.Second, "time to wait for an on-demand response per request")
|
||||
cmd.Flags().DurationVar(&idleExit, "idle-exit", 5*time.Second, "exit after being idle (after backfill requests)")
|
||||
return cmd
|
||||
|
||||
@ -31,6 +31,9 @@ func newMediaDownloadCmd(flags *rootFlags) *cobra.Command {
|
||||
if chat == "" || id == "" {
|
||||
return fmt.Errorf("--chat and --id are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
@ -58,6 +58,9 @@ func runPresence(flags *rootFlags, to string, state types.ChatPresence, media st
|
||||
if strings.TrimSpace(to) == "" {
|
||||
return fmt.Errorf("--to is required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@ -22,6 +23,8 @@ type rootFlags struct {
|
||||
asJSON bool
|
||||
fullOutput bool
|
||||
timeout time.Duration
|
||||
readOnly bool
|
||||
lockWait time.Duration
|
||||
}
|
||||
|
||||
func execute(args []string) error {
|
||||
@ -39,6 +42,8 @@ func execute(args []string) error {
|
||||
rootCmd.PersistentFlags().BoolVar(&flags.asJSON, "json", false, "output JSON instead of human-readable text")
|
||||
rootCmd.PersistentFlags().BoolVar(&flags.fullOutput, "full", false, "disable truncation in table output")
|
||||
rootCmd.PersistentFlags().DurationVar(&flags.timeout, "timeout", 5*time.Minute, "command timeout (non-sync commands)")
|
||||
rootCmd.PersistentFlags().DurationVar(&flags.lockWait, "lock-wait", 0, "wait for the store lock before failing (write commands)")
|
||||
rootCmd.PersistentFlags().BoolVar(&flags.readOnly, "read-only", false, "reject commands that intentionally write WhatsApp or the local store (or set WACLI_READONLY=1)")
|
||||
|
||||
rootCmd.AddCommand(newVersionCmd())
|
||||
rootCmd.AddCommand(newDoctorCmd(&flags))
|
||||
@ -71,7 +76,7 @@ func newApp(ctx context.Context, flags *rootFlags, needLock bool, allowUnauthed
|
||||
var lk *lock.Lock
|
||||
if needLock {
|
||||
var err error
|
||||
lk, err = lock.Acquire(storeDir)
|
||||
lk, err = lock.AcquireWithTimeout(ctx, storeDir, flags.lockWait)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -93,6 +98,28 @@ func newApp(ctx context.Context, flags *rootFlags, needLock bool, allowUnauthed
|
||||
return a, lk, nil
|
||||
}
|
||||
|
||||
func (f *rootFlags) isReadOnly() bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
if f.readOnly {
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv("WACLI_READONLY"))) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (f *rootFlags) requireWritable() error {
|
||||
if f.isReadOnly() {
|
||||
return fmt.Errorf("read-only mode: command would intentionally modify WhatsApp or the local store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func withTimeout(ctx context.Context, flags *rootFlags) (context.Context, context.CancelFunc) {
|
||||
if flags.timeout <= 0 {
|
||||
return context.WithCancel(ctx)
|
||||
|
||||
26
cmd/wacli/root_test.go
Normal file
26
cmd/wacli/root_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRootFlagsReadOnlyFlag(t *testing.T) {
|
||||
flags := &rootFlags{readOnly: true}
|
||||
|
||||
if !flags.isReadOnly() {
|
||||
t.Fatal("isReadOnly = false, want true")
|
||||
}
|
||||
err := flags.requireWritable()
|
||||
if err == nil || !strings.Contains(err.Error(), "read-only mode") {
|
||||
t.Fatalf("requireWritable error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRootFlagsReadOnlyEnv(t *testing.T) {
|
||||
t.Setenv("WACLI_READONLY", "yes")
|
||||
|
||||
if !(&rootFlags{}).isReadOnly() {
|
||||
t.Fatal("isReadOnly = false, want true")
|
||||
}
|
||||
}
|
||||
@ -42,6 +42,9 @@ func newSendTextCmd(flags *rootFlags) *cobra.Command {
|
||||
if to == "" || message == "" {
|
||||
return fmt.Errorf("--to and --message are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
@ -24,6 +24,9 @@ func newSendFileCmd(flags *rootFlags) *cobra.Command {
|
||||
if to == "" || filePath == "" {
|
||||
return fmt.Errorf("--to and --file are required")
|
||||
}
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := withTimeout(context.Background(), flags)
|
||||
defer cancel()
|
||||
|
||||
@ -23,6 +23,9 @@ func newSyncCmd(flags *rootFlags) *cobra.Command {
|
||||
Use: "sync",
|
||||
Short: "Sync messages (requires prior auth; never shows QR)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := flags.requireWritable(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, stop := signalContext()
|
||||
defer stop()
|
||||
|
||||
|
||||
@ -22,6 +22,13 @@ type BackfillOptions struct {
|
||||
IdleExit time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultBackfillCount = 50
|
||||
DefaultBackfillRequests = 1
|
||||
MaxBackfillCount = 500
|
||||
MaxBackfillRequests = 100
|
||||
)
|
||||
|
||||
type BackfillResult struct {
|
||||
ChatJID string
|
||||
RequestsSent int
|
||||
@ -47,17 +54,9 @@ func (a *App) BackfillHistory(ctx context.Context, opts BackfillOptions) (Backfi
|
||||
}
|
||||
chatStr = chat.String()
|
||||
|
||||
if opts.Count <= 0 {
|
||||
opts.Count = 50
|
||||
}
|
||||
if opts.Requests <= 0 {
|
||||
opts.Requests = 1
|
||||
}
|
||||
if opts.WaitPerRequest <= 0 {
|
||||
opts.WaitPerRequest = 60 * time.Second
|
||||
}
|
||||
if opts.IdleExit <= 0 {
|
||||
opts.IdleExit = 5 * time.Second
|
||||
opts = normalizeBackfillOptions(opts)
|
||||
if err := validateBackfillOptions(opts); err != nil {
|
||||
return BackfillResult{}, err
|
||||
}
|
||||
|
||||
if err := a.EnsureAuthed(); err != nil {
|
||||
@ -190,3 +189,29 @@ func (a *App) BackfillHistory(ctx context.Context, opts BackfillOptions) (Backfi
|
||||
MessagesSynced: syncRes.MessagesStored,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeBackfillOptions(opts BackfillOptions) BackfillOptions {
|
||||
if opts.Count <= 0 {
|
||||
opts.Count = DefaultBackfillCount
|
||||
}
|
||||
if opts.Requests <= 0 {
|
||||
opts.Requests = DefaultBackfillRequests
|
||||
}
|
||||
if opts.WaitPerRequest <= 0 {
|
||||
opts.WaitPerRequest = 60 * time.Second
|
||||
}
|
||||
if opts.IdleExit <= 0 {
|
||||
opts.IdleExit = 5 * time.Second
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func validateBackfillOptions(opts BackfillOptions) error {
|
||||
if opts.Count > MaxBackfillCount {
|
||||
return fmt.Errorf("--count must be <= %d (got %d)", MaxBackfillCount, opts.Count)
|
||||
}
|
||||
if opts.Requests > MaxBackfillRequests {
|
||||
return fmt.Errorf("--requests must be <= %d (got %d)", MaxBackfillRequests, opts.Requests)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -80,6 +81,38 @@ func TestBackfillHistoryAddsOlderMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeBackfillOptions(t *testing.T) {
|
||||
opts := normalizeBackfillOptions(BackfillOptions{})
|
||||
|
||||
if opts.Count != DefaultBackfillCount {
|
||||
t.Fatalf("Count = %d, want %d", opts.Count, DefaultBackfillCount)
|
||||
}
|
||||
if opts.Requests != DefaultBackfillRequests {
|
||||
t.Fatalf("Requests = %d, want %d", opts.Requests, DefaultBackfillRequests)
|
||||
}
|
||||
if opts.WaitPerRequest <= 0 || opts.IdleExit <= 0 {
|
||||
t.Fatalf("durations must default positive: %+v", opts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateBackfillOptionsCapsWork(t *testing.T) {
|
||||
err := validateBackfillOptions(BackfillOptions{
|
||||
Count: MaxBackfillCount + 1,
|
||||
Requests: DefaultBackfillRequests,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "--count") {
|
||||
t.Fatalf("count error = %v", err)
|
||||
}
|
||||
|
||||
err = validateBackfillOptions(BackfillOptions{
|
||||
Count: DefaultBackfillCount,
|
||||
Requests: MaxBackfillRequests + 1,
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "--requests") {
|
||||
t.Fatalf("requests error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func storeUpsertMessage(chatJID, id string, ts time.Time, text string) store.UpsertMessageParams {
|
||||
return store.UpsertMessageParams{
|
||||
ChatJID: chatJID,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -44,6 +45,32 @@ func Acquire(storeDir string) (*Lock, error) {
|
||||
return &Lock{path: path, f: f}, nil
|
||||
}
|
||||
|
||||
func AcquireWithTimeout(ctx context.Context, storeDir string, wait time.Duration) (*Lock, error) {
|
||||
if wait <= 0 {
|
||||
return Acquire(storeDir)
|
||||
}
|
||||
deadline := time.NewTimer(wait)
|
||||
defer deadline.Stop()
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastErr error
|
||||
for {
|
||||
lk, err := Acquire(storeDir)
|
||||
if err == nil {
|
||||
return lk, nil
|
||||
}
|
||||
lastErr = err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-deadline.C:
|
||||
return nil, fmt.Errorf("timed out waiting for store lock after %s: %w", wait, lastErr)
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lock) Release() error {
|
||||
if l == nil || l.f == nil {
|
||||
return nil
|
||||
|
||||
@ -55,3 +55,20 @@ func TestLockBlocksOtherProcess(t *testing.T) {
|
||||
t.Fatalf("expected helper to report locked; output=%q", strings.TrimSpace(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireWithTimeout(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
lk, err := Acquire(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("acquire: %v", err)
|
||||
}
|
||||
defer lk.Release()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, err = AcquireWithTimeout(ctx, dir, 50*time.Millisecond)
|
||||
if err == nil || !strings.Contains(err.Error(), "timed out waiting for store lock") {
|
||||
t.Fatalf("AcquireWithTimeout error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user