diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a982a3..a7e292b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ All notable changes to `discrawl` will be documented in this file. ### Fixes - `discrawl` now handles SIGINT/SIGTERM by canceling active sync/import contexts so large SQLite and FTS writes can roll back and close cleanly instead of being terminated mid-transaction. +- `discrawl sync` now keeps Git snapshot refreshes explicit by default; use `--update=auto` or `--update=force` when you want a sync run to pull/import the shared snapshot before live Discord or desktop-cache deltas. +- Snapshot imports now emit phase/table/file progress and keep the sync lock file updated with the active phase, making long update/import runs diagnosable instead of looking hung. +- Recent-message scans are backed by a plain `messages(created_at, id)` index so archive freshness and short-window analysis queries avoid full-table scans. ### Maintenance diff --git a/README.md b/README.md index cb0cde7..657ab8d 100644 --- a/README.md +++ b/README.md @@ -173,15 +173,20 @@ discrawl init --db ~/data/discrawl.db Refreshes SQLite from one or both archive sources. -By default, `sync` runs both sources: +By default, `sync` runs both live/local sources and does not import the Git snapshot first: - Discord bot-token sync for bot-visible guild data - local Discord Desktop cache import for classifiable cached messages and proven DMs +Use `discrawl update` when you want to pull/import the shared Git snapshot. If you intentionally want a sync run to import the snapshot before live deltas, pass `--update=auto` to import only when stale or `--update=force` to pull/import before syncing. `--no-update` is accepted as an explicit no-op alias for the default. + Run one explicit `--full` pass when you want a complete historical guild archive. Use plain `sync` afterward for frequent latest-message and desktop-cache refreshes. ```bash discrawl sync +discrawl sync --update=auto +discrawl sync --update=force +discrawl sync --no-update discrawl sync --full discrawl sync --full --all discrawl sync --guild 123456789012345678 @@ -207,7 +212,8 @@ Bot sync modes: | Command | Use when | Behavior | | --- | --- | --- | -| `discrawl sync` | routine refresh | imports any stale Git snapshot first, skips member refreshes, checks live top-level channels plus active threads, and only fetches new messages for channels with a stored latest cursor | +| `discrawl sync` | routine refresh | skips member refreshes, checks live top-level channels plus active threads, and only fetches new messages for channels with a stored latest cursor | +| `discrawl sync --update=auto` | hybrid Git/live refresh | imports a stale Git snapshot first, then runs the routine live refresh | | `discrawl sync --all-channels` | repair pass | broad incremental sweep across every stored channel/thread, including archived threads | | `discrawl sync --full` | historical backfill | crawls older history until channels are complete; can take a long time on large servers | @@ -218,7 +224,7 @@ Bot sync modes: `--latest-only` is still accepted for explicit latest-only runs; it is now the default for untargeted `sync`. Use `--all-channels` to opt out of the fast default without doing a full historical crawl. When `--channels` includes a forum channel id, `discrawl` expands that forum's threads and syncs their messages as part of the targeted run. `--since` limits initial history/bootstrap and full-history backfill to messages at or after the given RFC3339 timestamp. It does not mark older history as complete, so a later `sync --full` without `--since` can continue the backfill. -Long runs now emit periodic progress logs to stderr so large backfills do not look hung. +Long runs now emit periodic progress logs to stderr so large backfills and Git snapshot imports do not look hung. If in-flight channels stop completing for a while, `discrawl` now emits `message sync waiting` heartbeat logs with the oldest active channel, per-channel page activity, and skip/defer counters, and every run ends with a `message sync finished` summary. Each channel crawl also has a bounded runtime budget, so a pathological channel is deferred and retried on the next sync instead of pinning a worker forever. Full sync member refresh is best-effort and currently gives up after five minutes without a caller-supplied deadline, so message sync completion is not held hostage by a slow guild member crawl. @@ -454,9 +460,9 @@ discrawl subscribe --stale-after 15m https://github.com/example/discord-archive. discrawl subscribe --no-auto-update https://github.com/example/discord-archive.git ``` -Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually. +Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually. `discrawl sync` does not auto-import the share unless `--update=auto` or `--update=force` is provided, so routine live refreshes stay fast. -Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass. +Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync --update=auto` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass. Git snapshots publish non-DM archive tables by default. Embedding queue state stays local to each machine, and Git-only readers can use FTS immediately without an embedding provider. diff --git a/internal/cli/admin_commands.go b/internal/cli/admin_commands.go index f5058f4..ddbb82f 100644 --- a/internal/cli/admin_commands.go +++ b/internal/cli/admin_commands.go @@ -113,9 +113,19 @@ func (r *runtime) runSync(args []string) error { latestOnly := fs.Bool("latest-only", false, "") guildsFlag := fs.String("guilds", "", "") guildFlag := fs.String("guild", "", "") + updateMode := fs.String("update", "", "") + noUpdate := fs.Bool("no-update", false, "") if err := fs.Parse(args); err != nil { return usageErr(err) } + if *noUpdate && strings.TrimSpace(*updateMode) != "" && !strings.EqualFold(strings.TrimSpace(*updateMode), string(shareUpdateNever)) { + return usageErr(errors.New("use either --no-update or --update, not both")) + } + if strings.TrimSpace(*updateMode) != "" { + if _, err := parseShareUpdateMode(*updateMode); err != nil { + return usageErr(err) + } + } sources, err := parseSyncSources(*source) if err != nil { return usageErr(err) @@ -151,6 +161,7 @@ func (r *runtime) runSync(args []string) error { func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) error { var apiStats *syncer.SyncStats if sources.discord { + r.setSyncLockPhase("discord sync") shouldClose := r.client == nil if err := r.ensureDiscordServices(); err != nil { return err @@ -166,6 +177,7 @@ func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) er } var wiretapStats *discorddesktop.Stats if sources.wiretap { + r.setSyncLockPhase("wiretap import") stats, err := discorddesktop.Import(r.ctx, r.store, discorddesktop.Options{ Path: r.cfg.Desktop.Path, MaxFileBytes: r.cfg.Desktop.MaxFileBytes, diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 14be40e..67a3b6f 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -90,23 +90,24 @@ func Run(ctx context.Context, args []string, stdout, stderr io.Writer) error { } type runtime struct { - ctx context.Context - configPath string - cfg config.Config - stdout io.Writer - stderr io.Writer - json bool - plain bool - logger *slog.Logger - store *store.Store - client discordClient - syncer syncService - dbLockHeld bool - openStore func(context.Context, string) (*store.Store, error) - newDiscord func(config.Config) (discordClient, error) - newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService - newEmbed func(config.EmbeddingsConfig) (embed.Provider, error) - now func() time.Time + ctx context.Context + configPath string + cfg config.Config + stdout io.Writer + stderr io.Writer + json bool + plain bool + logger *slog.Logger + store *store.Store + client discordClient + syncer syncService + dbLockHeld bool + lockStarted time.Time + openStore func(context.Context, string) (*store.Store, error) + newDiscord func(config.Config) (discordClient, error) + newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService + newEmbed func(config.EmbeddingsConfig) (embed.Provider, error) + now func() time.Time } type discordClient interface { @@ -131,7 +132,11 @@ func (r *runtime) dispatch(rest []string) error { case "init": return r.runInit(rest[1:]) case "sync": - return r.withLocalStoreLocked(true, func() error { return r.runSync(rest[1:]) }) + updateMode, err := syncShareUpdateMode(rest[1:]) + if err != nil { + return usageErr(err) + } + return r.withLocalStoreUpdateLocked(updateMode, true, func() error { return r.runSync(rest[1:]) }) case "tail": return r.withServicesLocked(true, func() error { return r.runTail(rest[1:]) }) case "wiretap": @@ -187,14 +192,18 @@ func (r *runtime) withServicesLocked(withDiscord bool, fn func() error) error { } func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) error { - return r.withLocalStoreDefaultLocked(autoShareUpdate, true, fn) + return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), true, fn) } func (r *runtime) withLocalStoreDefault(autoShareUpdate bool, fn func() error) error { - return r.withLocalStoreDefaultLocked(autoShareUpdate, false, fn) + return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), false, fn) } func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn func() error) error { + return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), lockDB, fn) +} + +func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB bool, fn func() error) error { cfg, err := config.Load(r.configPath) if err != nil { if !errors.Is(err, os.ErrNotExist) { @@ -215,13 +224,13 @@ func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn f r.cfg = cfg if lockDB { return r.withSyncLock(func() error { - return r.openLocalStore(dbPath, autoShareUpdate, fn) + return r.openLocalStore(dbPath, updateMode, fn) }) } - return r.openLocalStore(dbPath, autoShareUpdate, fn) + return r.openLocalStore(dbPath, updateMode, fn) } -func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func() error) error { +func (r *runtime) openLocalStore(dbPath string, updateMode shareUpdateMode, fn func() error) error { storeFactory := r.openStore if storeFactory == nil { storeFactory = store.Open @@ -232,8 +241,8 @@ func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func() return dbErr(err) } defer func() { _ = r.store.Close() }() - if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" { - if err := r.autoUpdateShare(); err != nil { + if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" { + if err := r.autoUpdateShare(updateMode); err != nil { return err } } @@ -245,6 +254,10 @@ func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func() } func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bool, fn func() error) error { + return r.withServicesUpdateLocked(withDiscord, boolShareUpdateMode(autoShareUpdate), lockDB, fn) +} + +func (r *runtime) withServicesUpdateLocked(withDiscord bool, updateMode shareUpdateMode, lockDB bool, fn func() error) error { cfg, err := config.Load(r.configPath) if err != nil { return configErr(err) @@ -259,13 +272,13 @@ func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bo r.cfg = cfg if lockDB { return r.withSyncLock(func() error { - return r.openServices(dbPath, withDiscord, autoShareUpdate, fn) + return r.openServices(dbPath, withDiscord, updateMode, fn) }) } - return r.openServices(dbPath, withDiscord, autoShareUpdate, fn) + return r.openServices(dbPath, withDiscord, updateMode, fn) } -func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool, fn func() error) error { +func (r *runtime) openServices(dbPath string, withDiscord bool, updateMode shareUpdateMode, fn func() error) error { storeFactory := r.openStore if storeFactory == nil { storeFactory = store.Open @@ -276,8 +289,8 @@ func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool, return dbErr(err) } defer func() { _ = r.store.Close() }() - if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" { - if err := r.autoUpdateShare(); err != nil { + if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" { + if err := r.autoUpdateShare(updateMode); err != nil { return err } } @@ -321,24 +334,27 @@ func (r *runtime) ensureDiscordServices() error { return nil } -func (r *runtime) autoUpdateShare() error { - if !r.cfg.ShareEnabled() || !r.cfg.Share.AutoUpdate { +func (r *runtime) autoUpdateShare(mode shareUpdateMode) error { + if !r.cfg.ShareEnabled() || (mode == shareUpdateConfigured && !r.cfg.Share.AutoUpdate) { return nil } staleAfter, err := time.ParseDuration(r.cfg.Share.StaleAfter) if err != nil { return configErr(fmt.Errorf("invalid share.stale_after: %w", err)) } - if !share.NeedsImport(r.ctx, r.store, staleAfter) { + if mode != shareUpdateForce && !share.NeedsImport(r.ctx, r.store, staleAfter) { return nil } opts, err := r.shareOptions() if err != nil { return err } + r.setSyncLockPhase("share pull") + r.logger.Info("share update pulling", "repo_path", opts.RepoPath, "remote", opts.Remote) if err := share.Pull(r.ctx, opts); err != nil { return err } + r.setSyncLockPhase("share import") _, _, err = share.ImportIfChanged(r.ctx, r.store, opts) if errors.Is(err, share.ErrNoManifest) { return nil @@ -355,5 +371,6 @@ func (r *runtime) shareOptions() (share.Options, error) { RepoPath: repoPath, Remote: r.cfg.Share.Remote, Branch: r.cfg.Share.Branch, + Progress: r.shareProgress, }, nil } diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index dcb7913..d42f94f 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -600,7 +600,7 @@ func TestShareUpdateImportsNewRemoteSnapshot(t *testing.T) { require.Contains(t, out.String(), "newer git snapshot arrived") } -func TestSyncImportsGitShareBeforeLiveDiscord(t *testing.T) { +func TestSyncSkipsGitShareByDefaultAndCanImportBeforeLiveDiscord(t *testing.T) { ctx := context.Background() dir := t.TempDir() remoteRepo := filepath.Join(dir, "remote.git") @@ -643,17 +643,33 @@ func TestSyncImportsGitShareBeforeLiveDiscord(t *testing.T) { } require.NoError(t, rt.dispatch([]string{"sync", "--all"})) - require.True(t, hybrid.sawGitMessage) + require.False(t, hybrid.sawGitMessage) reader, err := store.Open(ctx, cfg.DBPath) require.NoError(t, err) - defer func() { _ = reader.Close() }() rows, err := reader.ListMessages(ctx, store.MessageListOptions{Channel: "general", IncludeEmpty: true}) require.NoError(t, err) contents := make([]string, 0, len(rows)) for _, row := range rows { contents = append(contents, row.Content) } + require.NotContains(t, contents, "automatic updates work") + require.Contains(t, contents, "live discord filled the delta") + require.NoError(t, reader.Close()) + + hybrid.sawGitMessage = false + require.NoError(t, rt.dispatch([]string{"sync", "--all", "--update=auto"})) + require.True(t, hybrid.sawGitMessage) + + reader, err = store.Open(ctx, cfg.DBPath) + require.NoError(t, err) + defer func() { _ = reader.Close() }() + rows, err = reader.ListMessages(ctx, store.MessageListOptions{Channel: "general", IncludeEmpty: true}) + require.NoError(t, err) + contents = contents[:0] + for _, row := range rows { + contents = append(contents, row.Content) + } require.Contains(t, contents, "automatic updates work") require.Contains(t, contents, "live discord filled the delta") } @@ -1543,6 +1559,17 @@ func TestHelpers(t *testing.T) { require.Equal(t, []string{"a", "b"}, csvList("a,b,a")) require.Equal(t, "x", (&cliError{code: 2, err: assertErr("x")}).Error()) + mode, err := syncShareUpdateMode([]string{"--all"}) + require.NoError(t, err) + require.Equal(t, shareUpdateNever, mode) + mode, err = syncShareUpdateMode([]string{"--update=auto"}) + require.NoError(t, err) + require.Equal(t, shareUpdateAuto, mode) + mode, err = syncShareUpdateMode([]string{"--update", "force"}) + require.NoError(t, err) + require.Equal(t, shareUpdateForce, mode) + _, err = syncShareUpdateMode([]string{"--update"}) + require.Error(t, err) require.Equal(t, 2, ExitCode(usageErr(assertErr("x")))) require.Equal(t, 4, ExitCode(authErr(assertErr("x")))) require.Equal(t, 5, ExitCode(dbErr(assertErr("x")))) @@ -1926,6 +1953,8 @@ func TestCommandUsageErrors(t *testing.T) { require.Equal(t, 2, ExitCode(rt.runMessages([]string{"--days", "-1"}))) require.Equal(t, 2, ExitCode(rt.runMessages([]string{"--days", "1", "--since", "2026-03-01T00:00:00Z"}))) require.Equal(t, 2, ExitCode(rt.runSync([]string{"--all", "--guild", "g1"}))) + require.Equal(t, 2, ExitCode(rt.runSync([]string{"--update", "bogus"}))) + require.Equal(t, 2, ExitCode(rt.runSync([]string{"--update=force", "--no-update"}))) require.Equal(t, 2, ExitCode(rt.runChannels(nil))) require.Equal(t, 2, ExitCode(rt.runStatus([]string{"extra"}))) require.NoError(t, (&runtime{stdout: &bytes.Buffer{}}).runDoctor(nil)) diff --git a/internal/cli/share_commands.go b/internal/cli/share_commands.go index ca73aad..176c67c 100644 --- a/internal/cli/share_commands.go +++ b/internal/cli/share_commands.go @@ -136,13 +136,15 @@ func (r *runtime) runSubscribe(args []string) error { if err != nil { return configErr(err) } - opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch} + opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch, Progress: r.shareProgress} if *withEmbeddings { applyEmbeddingShareOptions(&opts, cfg) } + r.setSyncLockPhase("share pull") if err := share.Pull(r.ctx, opts); err != nil { return err } + r.setSyncLockPhase("share import") manifest, imported, err := share.ImportIfChanged(r.ctx, s, opts) if err != nil { return err @@ -176,12 +178,15 @@ func (r *runtime) runUpdate(args []string) error { if err != nil { return err } + opts.Progress = r.shareProgress if *withEmbeddings { applyEmbeddingShareOptions(&opts, r.cfg) } + r.setSyncLockPhase("share pull") if err := share.Pull(r.ctx, opts); err != nil { return err } + r.setSyncLockPhase("share import") manifest, imported, err := share.ImportIfChanged(r.ctx, r.store, opts) if err != nil { return err diff --git a/internal/cli/share_update.go b/internal/cli/share_update.go new file mode 100644 index 0000000..078ebb0 --- /dev/null +++ b/internal/cli/share_update.go @@ -0,0 +1,110 @@ +package cli + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/steipete/discrawl/internal/share" +) + +type shareUpdateMode string + +const ( + shareUpdateConfigured shareUpdateMode = "configured" + shareUpdateAuto shareUpdateMode = "auto" + shareUpdateNever shareUpdateMode = "never" + shareUpdateForce shareUpdateMode = "force" +) + +func boolShareUpdateMode(enabled bool) shareUpdateMode { + if enabled { + return shareUpdateConfigured + } + return shareUpdateNever +} + +func parseShareUpdateMode(raw string) (shareUpdateMode, error) { + switch shareUpdateMode(strings.ToLower(strings.TrimSpace(raw))) { + case "", shareUpdateAuto: + return shareUpdateAuto, nil + case shareUpdateNever: + return shareUpdateNever, nil + case shareUpdateForce: + return shareUpdateForce, nil + default: + return "", fmt.Errorf("invalid --update %q; use auto, never, or force", raw) + } +} + +func syncShareUpdateMode(args []string) (shareUpdateMode, error) { + mode := shareUpdateNever + sawNoUpdate := false + sawUpdate := false + for i := 0; i < len(args); i++ { + arg := args[i] + switch { + case arg == "--no-update": + sawNoUpdate = true + mode = shareUpdateNever + case arg == "--update": + if i+1 >= len(args) { + return "", errors.New("--update requires auto, never, or force") + } + parsed, err := parseShareUpdateMode(args[i+1]) + if err != nil { + return "", err + } + sawUpdate = true + mode = parsed + i++ + case strings.HasPrefix(arg, "--update="): + parsed, err := parseShareUpdateMode(strings.TrimPrefix(arg, "--update=")) + if err != nil { + return "", err + } + sawUpdate = true + mode = parsed + } + } + if sawNoUpdate && sawUpdate && mode != shareUpdateNever { + return "", errors.New("use either --no-update or --update, not both") + } + return mode, nil +} + +func (r *runtime) shareProgress(progress share.ImportProgress) { + if progress.Phase == "" { + return + } + phase := "share " + progress.Phase + if progress.Table != "" { + phase += " " + progress.Table + } + if progress.File != "" { + phase += " " + progress.File + } + r.setSyncLockPhase(phase) + attrs := []any{"phase", progress.Phase} + if progress.Table != "" { + attrs = append(attrs, "table", progress.Table) + } + if progress.Rows != 0 { + attrs = append(attrs, "rows", progress.Rows) + } + if progress.TotalRows != 0 { + attrs = append(attrs, "total_rows", progress.TotalRows) + } + if progress.File != "" { + attrs = append(attrs, "file", progress.File, "file_index", progress.FileIndex, "file_count", progress.FileCount) + } + r.logger.Info("share import progress", attrs...) +} + +func (r *runtime) nowUTC() time.Time { + if r.now != nil { + return r.now().UTC() + } + return time.Now().UTC() +} diff --git a/internal/cli/sync_lock.go b/internal/cli/sync_lock.go index f42a144..16f5ac6 100644 --- a/internal/cli/sync_lock.go +++ b/internal/cli/sync_lock.go @@ -3,7 +3,10 @@ package cli import ( "context" "fmt" + "os" "path/filepath" + "strings" + "time" "github.com/steipete/discrawl/internal/config" ) @@ -21,13 +24,37 @@ func (r *runtime) withSyncLock(fn func() error) error { return err } r.dbLockHeld = true + r.lockStarted = r.nowUTC() + r.setSyncLockPhase("locked") defer func() { r.dbLockHeld = false + r.lockStarted = time.Time{} _ = release() }() return fn() } +func (r *runtime) setSyncLockPhase(phase string) { + if !r.dbLockHeld { + return + } + path, err := r.syncLockPath() + if err != nil { + return + } + started := r.lockStarted + if started.IsZero() { + started = r.nowUTC() + } + body := fmt.Sprintf("pid=%d\nstarted_at=%s\nupdated_at=%s\nphase=%s\n", + os.Getpid(), + started.Format(time.RFC3339Nano), + r.nowUTC().Format(time.RFC3339Nano), + phase, + ) + _ = os.WriteFile(path, []byte(body), 0o600) +} + func (r *runtime) syncLockPath() (string, error) { dbPath, err := config.ExpandPath(r.cfg.DBPath) if err != nil { @@ -38,6 +65,12 @@ func (r *runtime) syncLockPath() (string, error) { func syncLockErr(ctx context.Context, path string) error { if ctx.Err() != nil { + if body, err := os.ReadFile(path); err == nil { + details := strings.TrimSpace(string(body)) + if details != "" { + return fmt.Errorf("wait for sync lock %s (%s): %w", path, strings.ReplaceAll(details, "\n", ", "), ctx.Err()) + } + } return fmt.Errorf("wait for sync lock %s: %w", path, ctx.Err()) } return nil diff --git a/internal/share/share.go b/internal/share/share.go index 41c3b57..5f9b2d1 100644 --- a/internal/share/share.go +++ b/internal/share/share.go @@ -52,6 +52,17 @@ type Options struct { EmbeddingProvider string EmbeddingModel string EmbeddingInputVersion string + Progress func(ImportProgress) +} + +type ImportProgress struct { + Phase string + Table string + File string + FileIndex int + FileCount int + Rows int + TotalRows int } type Manifest struct { @@ -221,6 +232,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) if err != nil { return Manifest{}, err } + opts.reportProgress(ImportProgress{Phase: "start", TotalRows: manifestRowCount(manifest)}) restorePragmas, err := applyImportPragmas(ctx, s.DB()) if err != nil { return Manifest{}, err @@ -242,11 +254,13 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) } }() for _, table := range []string{"message_fts", "member_fts"} { + opts.reportProgress(ImportProgress{Phase: "drop_fts", Table: table}) if _, err := tx.ExecContext(ctx, "drop table if exists "+table); err != nil { return Manifest{}, fmt.Errorf("drop %s: %w", table, err) } } for _, table := range slices.Backward(SnapshotTables) { + opts.reportProgress(ImportProgress{Phase: "clear", Table: table}) query, args := snapshotDeleteQuery(table) if _, err := tx.ExecContext(ctx, query, args...); err != nil { return Manifest{}, fmt.Errorf("clear %s: %w", table, err) @@ -256,10 +270,13 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) if err := ctx.Err(); err != nil { return Manifest{}, err } - if err := importTable(ctx, tx, opts.RepoPath, table); err != nil { + opts.reportProgress(ImportProgress{Phase: "table_start", Table: table.Name, TotalRows: table.Rows}) + if err := importTable(ctx, tx, opts, table); err != nil { return Manifest{}, err } + opts.reportProgress(ImportProgress{Phase: "table_done", Table: table.Name, TotalRows: table.Rows}) } + opts.reportProgress(ImportProgress{Phase: "repair"}) if err := repairImportedGuildIDs(ctx, tx); err != nil { return Manifest{}, err } @@ -268,10 +285,12 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) return Manifest{}, err } } + opts.reportProgress(ImportProgress{Phase: "commit"}) if err := tx.Commit(); err != nil { return Manifest{}, err } committed = true + opts.reportProgress(ImportProgress{Phase: "rebuild_fts"}) if err := s.RebuildSearchIndexes(ctx); err != nil { return Manifest{}, err } @@ -282,6 +301,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) return Manifest{}, err } pragmasRestored = true + opts.reportProgress(ImportProgress{Phase: "done", TotalRows: manifestRowCount(manifest)}) return manifest, nil } @@ -335,6 +355,23 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes return imported, true, nil } +func (opts Options) reportProgress(progress ImportProgress) { + if opts.Progress != nil { + opts.Progress(progress) + } +} + +func manifestRowCount(manifest Manifest) int { + total := 0 + for _, table := range manifest.Tables { + total += table.Rows + } + for _, embeddings := range manifest.Embeddings { + total += embeddings.Rows + } + return total +} + func ImportEmbeddings(ctx context.Context, s *store.Store, opts Options, manifest Manifest) error { tx, err := s.DB().BeginTx(ctx, nil) if err != nil { @@ -572,7 +609,7 @@ func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingM }, nil } -func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableManifest) error { +func importTable(ctx context.Context, tx *sql.Tx, opts Options, table TableManifest) error { files := table.Files if len(files) == 0 && strings.TrimSpace(table.File) != "" { files = []string{table.File} @@ -586,34 +623,38 @@ func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableMa return fmt.Errorf("prepare import %s: %w", table.Name, err) } defer func() { _ = stmt.Close() }() - for _, rel := range files { + for i, rel := range files { if err := ctx.Err(); err != nil { return err } - if err := importTableFile(ctx, stmt, repoPath, table, columns, rel); err != nil { + opts.reportProgress(ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: i + 1, FileCount: len(files), TotalRows: table.Rows}) + rows, err := importTableFile(ctx, stmt, opts.RepoPath, table, columns, rel) + if err != nil { return err } + opts.reportProgress(ImportProgress{Phase: "file_done", Table: table.Name, File: rel, FileIndex: i + 1, FileCount: len(files), Rows: rows, TotalRows: table.Rows}) } return nil } -func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table TableManifest, columns []string, rel string) error { +func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table TableManifest, columns []string, rel string) (int, error) { path := filepath.Join(repoPath, filepath.FromSlash(rel)) file, err := os.Open(path) if err != nil { - return fmt.Errorf("open %s: %w", rel, err) + return 0, fmt.Errorf("open %s: %w", rel, err) } defer func() { _ = file.Close() }() gz, err := gzip.NewReader(file) if err != nil { - return fmt.Errorf("read gzip %s: %w", rel, err) + return 0, fmt.Errorf("read gzip %s: %w", rel, err) } defer func() { _ = gz.Close() }() dec := json.NewDecoder(gz) dec.UseNumber() + count := 0 for { if err := ctx.Err(); err != nil { - return err + return count, err } row := map[string]any{} err := dec.Decode(&row) @@ -621,7 +662,7 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table break } if err != nil { - return fmt.Errorf("decode %s: %w", rel, err) + return count, fmt.Errorf("decode %s: %w", rel, err) } if isDirectMessageSnapshotRow(table.Name, row) { continue @@ -631,10 +672,11 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table values[i] = importValue(row[column]) } if _, err := stmt.ExecContext(ctx, values...); err != nil { - return fmt.Errorf("insert %s: %w", table.Name, err) + return count, fmt.Errorf("insert %s: %w", table.Name, err) } + count++ } - return nil + return count, nil } func repairImportedGuildIDs(ctx context.Context, tx *sql.Tx) error { diff --git a/internal/share/share_test.go b/internal/share/share_test.go index cc4971c..787d6a9 100644 --- a/internal/share/share_test.go +++ b/internal/share/share_test.go @@ -35,10 +35,20 @@ func TestExportImportRoundTrip(t *testing.T) { require.NoError(t, err) defer func() { _ = dst.Close() }() - imported, changed, err := ImportIfChanged(ctx, dst, Options{RepoPath: repo, Branch: "main"}) + var progress []ImportProgress + imported, changed, err := ImportIfChanged(ctx, dst, Options{ + RepoPath: repo, + Branch: "main", + Progress: func(p ImportProgress) { progress = append(progress, p) }, + }) require.NoError(t, err) require.True(t, changed) require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt) + require.Contains(t, progressPhases(progress), "start") + require.Contains(t, progressPhases(progress), "table_start") + require.Contains(t, progressPhases(progress), "file_done") + require.Contains(t, progressPhases(progress), "rebuild_fts") + require.Contains(t, progressPhases(progress), "done") results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "launch", Limit: 10}) require.NoError(t, err) @@ -673,7 +683,7 @@ func TestShareSmallHelpersAndValidation(t *testing.T) { defer func() { _ = s.Close() }() tx, err := s.DB().BeginTx(ctx, nil) require.NoError(t, err) - require.ErrorContains(t, importTable(ctx, tx, t.TempDir(), TableManifest{Name: "messages", Columns: []string{"id"}}), "has no files") + require.ErrorContains(t, importTable(ctx, tx, Options{RepoPath: t.TempDir()}, TableManifest{Name: "messages", Columns: []string{"id"}}), "has no files") require.NoError(t, tx.Rollback()) require.ErrorContains(t, ImportEmbeddings(ctx, s, Options{ @@ -727,7 +737,7 @@ func TestLegacyManifestFileImportAndEmbeddingDecodeErrors(t *testing.T) { }) tx, err := s.DB().BeginTx(ctx, nil) require.NoError(t, err) - require.NoError(t, importTable(ctx, tx, repo, TableManifest{ + require.NoError(t, importTable(ctx, tx, Options{RepoPath: repo}, TableManifest{ Name: "guilds", File: tableRel, Columns: []string{"id", "name", "icon", "raw_json", "updated_at"}, @@ -935,3 +945,11 @@ func tableNames(manifest Manifest) []string { } return names } + +func progressPhases(progress []ImportProgress) []string { + phases := make([]string, 0, len(progress)) + for _, item := range progress { + phases = append(phases, item.Phase) + } + return phases +} diff --git a/internal/store/store.go b/internal/store/store.go index 822df3e..9d57b72 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -413,6 +413,7 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error { `create index if not exists idx_members_guild_id on members(guild_id);`, `create index if not exists idx_messages_channel_id on messages(channel_id);`, `create index if not exists idx_messages_guild_id on messages(guild_id);`, + `create index if not exists idx_messages_created_id on messages(created_at, id);`, `create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`, `create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`, `create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`, @@ -488,6 +489,7 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error { `create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`, `create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`, `create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`, + `create index if not exists idx_messages_created_id on messages(created_at, id);`, `create index if not exists idx_mentions_guild_event on mention_events(guild_id, event_at, event_id);`, `create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`, `create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`, diff --git a/internal/store/store_test.go b/internal/store/store_test.go index f8b1ed7..6e9e4b4 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -1337,6 +1337,7 @@ func TestOpenCreatesQueryIndexes(t *testing.T) { defer func() { _ = s.Close() }() messageIndexes := indexNames(t, ctx, s.DB(), "messages") + require.Contains(t, messageIndexes, "idx_messages_created_id") require.Contains(t, messageIndexes, "idx_messages_guild_created_id") require.Contains(t, messageIndexes, "idx_messages_channel_created_id") require.Contains(t, messageIndexes, "idx_messages_author_created_id") @@ -1361,6 +1362,7 @@ func TestOpenMigratesLegacyQueryIndexes(t *testing.T) { "idx_messages_guild_created_id", "idx_messages_channel_created_id", "idx_messages_author_created_id", + "idx_messages_created_id", "idx_mentions_guild_event", "idx_mentions_channel_event", } { @@ -1376,6 +1378,7 @@ func TestOpenMigratesLegacyQueryIndexes(t *testing.T) { version, err := s.schemaVersion(ctx) require.NoError(t, err) require.Equal(t, storeSchemaVersion, version) + require.Contains(t, indexNames(t, ctx, s.DB(), "messages"), "idx_messages_created_id") require.Contains(t, indexNames(t, ctx, s.DB(), "messages"), "idx_messages_channel_created_id") require.Contains(t, indexNames(t, ctx, s.DB(), "mention_events"), "idx_mentions_guild_event") }