refactor: split sync event and idle loops
This commit is contained in:
parent
5f897ee277
commit
668d7e5762
@ -3,8 +3,6 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -12,7 +10,6 @@ import (
|
||||
"github.com/steipete/wacli/internal/store"
|
||||
"github.com/steipete/wacli/internal/wa"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
)
|
||||
|
||||
type SyncMode string
|
||||
@ -63,97 +60,10 @@ func (a *App) Sync(ctx context.Context, opts SyncOptions) (SyncResult, error) {
|
||||
enqueueMedia := func(chatJID, msgID string) {}
|
||||
if opts.DownloadMedia {
|
||||
mediaJobs = make(chan mediaJob, 512)
|
||||
enqueueMedia = func(chatJID, msgID string) {
|
||||
if strings.TrimSpace(chatJID) == "" || strings.TrimSpace(msgID) == "" {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case mediaJobs <- mediaJob{chatJID: chatJID, msgID: msgID}:
|
||||
default:
|
||||
// Avoid blocking the event handler.
|
||||
go func() {
|
||||
select {
|
||||
case mediaJobs <- mediaJob{chatJID: chatJID, msgID: msgID}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
enqueueMedia = newMediaEnqueuer(ctx, mediaJobs)
|
||||
}
|
||||
|
||||
var panicCount atomic.Int64
|
||||
handlerID := a.wa.AddEventHandler(func(evt interface{}) {
|
||||
// Recover from panics so unexpected message structures do not
|
||||
// crash the entire process (#52). Log a stack trace, the event
|
||||
// type, and a running counter so recoveries do not go silently
|
||||
// unnoticed (#178).
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
n := panicCount.Add(1)
|
||||
fmt.Fprintf(os.Stderr, "\nevent handler panic (recovered, total=%d) event=%T: %v\n%s\n",
|
||||
n, evt, r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
lastEvent.Store(time.Now().UTC().UnixNano())
|
||||
|
||||
switch v := evt.(type) {
|
||||
case *events.Message:
|
||||
pm := wa.ParseLiveMessage(v)
|
||||
if pm.ReactionToID != "" && pm.ReactionEmoji == "" && v.Message != nil && v.Message.GetEncReactionMessage() != nil {
|
||||
if reaction, err := a.wa.DecryptReaction(ctx, v); err == nil && reaction != nil {
|
||||
pm.ReactionEmoji = reaction.GetText()
|
||||
if pm.ReactionToID == "" {
|
||||
if key := reaction.GetKey(); key != nil {
|
||||
pm.ReactionToID = key.GetID()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := a.storeParsedMessage(ctx, pm); err == nil {
|
||||
messagesStored.Add(1)
|
||||
}
|
||||
if opts.DownloadMedia && pm.Media != nil && pm.ID != "" {
|
||||
enqueueMedia(pm.Chat.String(), pm.ID)
|
||||
}
|
||||
if messagesStored.Load()%25 == 0 {
|
||||
fmt.Fprintf(os.Stderr, "\rSynced %d messages...", messagesStored.Load())
|
||||
}
|
||||
case *events.HistorySync:
|
||||
fmt.Fprintf(os.Stderr, "\nProcessing history sync (%d conversations)...\n", len(v.Data.Conversations))
|
||||
for _, conv := range v.Data.Conversations {
|
||||
lastEvent.Store(time.Now().UTC().UnixNano())
|
||||
chatID := strings.TrimSpace(conv.GetID())
|
||||
if chatID == "" {
|
||||
continue
|
||||
}
|
||||
for _, m := range conv.Messages {
|
||||
lastEvent.Store(time.Now().UTC().UnixNano())
|
||||
if m.Message == nil {
|
||||
continue
|
||||
}
|
||||
pm := wa.ParseHistoryMessage(chatID, m.Message)
|
||||
if pm.ID == "" || pm.Chat.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if err := a.storeParsedMessage(ctx, pm); err == nil {
|
||||
messagesStored.Add(1)
|
||||
}
|
||||
if opts.DownloadMedia && pm.Media != nil && pm.ID != "" {
|
||||
enqueueMedia(pm.Chat.String(), pm.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\rSynced %d messages...", messagesStored.Load())
|
||||
case *events.Connected:
|
||||
fmt.Fprintln(os.Stderr, "\nConnected.")
|
||||
case *events.Disconnected:
|
||||
fmt.Fprintln(os.Stderr, "\nDisconnected.")
|
||||
select {
|
||||
case disconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
handlerID := a.addSyncEventHandler(ctx, opts, &messagesStored, &lastEvent, disconnected, enqueueMedia)
|
||||
defer a.wa.RemoveEventHandler(handlerID)
|
||||
|
||||
if err := a.Connect(ctx, opts.AllowQR, opts.OnQRCode); err != nil {
|
||||
@ -183,63 +93,10 @@ func (a *App) Sync(ctx context.Context, opts SyncOptions) (SyncResult, error) {
|
||||
}
|
||||
|
||||
if opts.Mode == SyncModeFollow {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintln(os.Stderr, "\nStopping sync.")
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, nil
|
||||
case <-disconnected:
|
||||
fmt.Fprintln(os.Stderr, "Reconnecting...")
|
||||
if err := a.reconnect(ctx, opts.MaxReconnect); err != nil {
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return a.runSyncFollow(ctx, opts.MaxReconnect, &messagesStored, disconnected)
|
||||
}
|
||||
|
||||
// Bootstrap/once: exit after idle.
|
||||
poll := 250 * time.Millisecond
|
||||
if opts.IdleExit >= 2*time.Second {
|
||||
poll = 1 * time.Second
|
||||
}
|
||||
ticker := time.NewTicker(poll)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintln(os.Stderr, "\nStopping sync.")
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, nil
|
||||
case <-disconnected:
|
||||
fmt.Fprintln(os.Stderr, "Reconnecting...")
|
||||
if err := a.reconnect(ctx, opts.MaxReconnect); err != nil {
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, err
|
||||
}
|
||||
case <-ticker.C:
|
||||
last := time.Unix(0, lastEvent.Load())
|
||||
if time.Since(last) >= opts.IdleExit {
|
||||
fmt.Fprintf(os.Stderr, "\nIdle for %s, exiting.\n", opts.IdleExit)
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconnect wraps ReconnectWithBackoff with an optional deadline.
|
||||
// If maxDuration is positive, reconnection gives up after that long.
|
||||
// A zero or negative value means retry indefinitely (until ctx is cancelled).
|
||||
func (a *App) reconnect(ctx context.Context, maxDuration time.Duration) error {
|
||||
rctx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if maxDuration > 0 {
|
||||
rctx, cancel = context.WithTimeout(ctx, maxDuration)
|
||||
defer cancel()
|
||||
}
|
||||
err := a.wa.ReconnectWithBackoff(rctx, 2*time.Second, 30*time.Second)
|
||||
if err != nil && ctx.Err() == nil {
|
||||
// Deadline hit but parent context is still alive — we gave up, not the user.
|
||||
return fmt.Errorf("could not reconnect after %s: %w", maxDuration, err)
|
||||
}
|
||||
return err
|
||||
return a.runSyncUntilIdle(ctx, opts.IdleExit, opts.MaxReconnect, &messagesStored, &lastEvent, disconnected)
|
||||
}
|
||||
|
||||
func chatKind(chat types.JID) string {
|
||||
|
||||
115
internal/app/sync_events.go
Normal file
115
internal/app/sync_events.go
Normal file
@ -0,0 +1,115 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/wacli/internal/wa"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
)
|
||||
|
||||
func newMediaEnqueuer(ctx context.Context, jobs chan<- mediaJob) func(chatJID, msgID string) {
|
||||
return func(chatJID, msgID string) {
|
||||
if strings.TrimSpace(chatJID) == "" || strings.TrimSpace(msgID) == "" {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case jobs <- mediaJob{chatJID: chatJID, msgID: msgID}:
|
||||
default:
|
||||
// Avoid blocking the event handler.
|
||||
go func() {
|
||||
select {
|
||||
case jobs <- mediaJob{chatJID: chatJID, msgID: msgID}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) addSyncEventHandler(ctx context.Context, opts SyncOptions, messagesStored, lastEvent *atomic.Int64, disconnected chan<- struct{}, enqueueMedia func(string, string)) uint32 {
|
||||
var panicCount atomic.Int64
|
||||
return a.wa.AddEventHandler(func(evt interface{}) {
|
||||
// Recover from panics so unexpected message structures do not crash the
|
||||
// process. Include event type, stack trace, and a running counter.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
n := panicCount.Add(1)
|
||||
fmt.Fprintf(os.Stderr, "\nevent handler panic (recovered, total=%d) event=%T: %v\n%s\n",
|
||||
n, evt, r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
lastEvent.Store(time.Now().UTC().UnixNano())
|
||||
|
||||
switch v := evt.(type) {
|
||||
case *events.Message:
|
||||
a.handleLiveSyncMessage(ctx, opts, v, messagesStored, enqueueMedia)
|
||||
case *events.HistorySync:
|
||||
a.handleHistorySync(ctx, opts, v, messagesStored, lastEvent, enqueueMedia)
|
||||
case *events.Connected:
|
||||
fmt.Fprintln(os.Stderr, "\nConnected.")
|
||||
case *events.Disconnected:
|
||||
fmt.Fprintln(os.Stderr, "\nDisconnected.")
|
||||
select {
|
||||
case disconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (a *App) handleLiveSyncMessage(ctx context.Context, opts SyncOptions, v *events.Message, messagesStored *atomic.Int64, enqueueMedia func(string, string)) {
|
||||
pm := wa.ParseLiveMessage(v)
|
||||
if pm.ReactionToID != "" && pm.ReactionEmoji == "" && v.Message != nil && v.Message.GetEncReactionMessage() != nil {
|
||||
if reaction, err := a.wa.DecryptReaction(ctx, v); err == nil && reaction != nil {
|
||||
pm.ReactionEmoji = reaction.GetText()
|
||||
if pm.ReactionToID == "" {
|
||||
if key := reaction.GetKey(); key != nil {
|
||||
pm.ReactionToID = key.GetID()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := a.storeParsedMessage(ctx, pm); err == nil {
|
||||
messagesStored.Add(1)
|
||||
}
|
||||
if opts.DownloadMedia && pm.Media != nil && pm.ID != "" {
|
||||
enqueueMedia(pm.Chat.String(), pm.ID)
|
||||
}
|
||||
if messagesStored.Load()%25 == 0 {
|
||||
fmt.Fprintf(os.Stderr, "\rSynced %d messages...", messagesStored.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) handleHistorySync(ctx context.Context, opts SyncOptions, v *events.HistorySync, messagesStored, lastEvent *atomic.Int64, enqueueMedia func(string, string)) {
|
||||
fmt.Fprintf(os.Stderr, "\nProcessing history sync (%d conversations)...\n", len(v.Data.Conversations))
|
||||
for _, conv := range v.Data.Conversations {
|
||||
lastEvent.Store(time.Now().UTC().UnixNano())
|
||||
chatID := strings.TrimSpace(conv.GetID())
|
||||
if chatID == "" {
|
||||
continue
|
||||
}
|
||||
for _, m := range conv.Messages {
|
||||
lastEvent.Store(time.Now().UTC().UnixNano())
|
||||
if m.Message == nil {
|
||||
continue
|
||||
}
|
||||
pm := wa.ParseHistoryMessage(chatID, m.Message)
|
||||
if pm.ID == "" || pm.Chat.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if err := a.storeParsedMessage(ctx, pm); err == nil {
|
||||
messagesStored.Add(1)
|
||||
}
|
||||
if opts.DownloadMedia && pm.Media != nil && pm.ID != "" {
|
||||
enqueueMedia(pm.Chat.String(), pm.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\rSynced %d messages...", messagesStored.Load())
|
||||
}
|
||||
68
internal/app/sync_idle.go
Normal file
68
internal/app/sync_idle.go
Normal file
@ -0,0 +1,68 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (a *App) runSyncFollow(ctx context.Context, maxReconnect time.Duration, messagesStored *atomic.Int64, disconnected <-chan struct{}) (SyncResult, error) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintln(os.Stderr, "\nStopping sync.")
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, nil
|
||||
case <-disconnected:
|
||||
fmt.Fprintln(os.Stderr, "Reconnecting...")
|
||||
if err := a.reconnect(ctx, maxReconnect); err != nil {
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) runSyncUntilIdle(ctx context.Context, idleExit, maxReconnect time.Duration, messagesStored, lastEvent *atomic.Int64, disconnected <-chan struct{}) (SyncResult, error) {
|
||||
poll := 250 * time.Millisecond
|
||||
if idleExit >= 2*time.Second {
|
||||
poll = 1 * time.Second
|
||||
}
|
||||
ticker := time.NewTicker(poll)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintln(os.Stderr, "\nStopping sync.")
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, nil
|
||||
case <-disconnected:
|
||||
fmt.Fprintln(os.Stderr, "Reconnecting...")
|
||||
if err := a.reconnect(ctx, maxReconnect); err != nil {
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, err
|
||||
}
|
||||
case <-ticker.C:
|
||||
last := time.Unix(0, lastEvent.Load())
|
||||
if time.Since(last) >= idleExit {
|
||||
fmt.Fprintf(os.Stderr, "\nIdle for %s, exiting.\n", idleExit)
|
||||
return SyncResult{MessagesStored: messagesStored.Load()}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconnect wraps ReconnectWithBackoff with an optional deadline. If maxDuration
|
||||
// is positive, reconnection gives up after that long; otherwise it retries until
|
||||
// ctx is cancelled.
|
||||
func (a *App) reconnect(ctx context.Context, maxDuration time.Duration) error {
|
||||
rctx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if maxDuration > 0 {
|
||||
rctx, cancel = context.WithTimeout(ctx, maxDuration)
|
||||
defer cancel()
|
||||
}
|
||||
err := a.wa.ReconnectWithBackoff(rctx, 2*time.Second, 30*time.Second)
|
||||
if err != nil && ctx.Err() == nil {
|
||||
return fmt.Errorf("could not reconnect after %s: %w", maxDuration, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user