fix: keep discrawl sync update explicit

This commit is contained in:
Peter Steinberger 2026-05-03 15:03:41 +01:00
parent 45f0133b62
commit c934c579b0
No known key found for this signature in database
12 changed files with 336 additions and 56 deletions

View File

@ -7,6 +7,9 @@ All notable changes to `discrawl` will be documented in this file.
### Fixes
- `discrawl` now handles SIGINT/SIGTERM by canceling active sync/import contexts so large SQLite and FTS writes can roll back and close cleanly instead of being terminated mid-transaction.
- `discrawl sync` now keeps Git snapshot refreshes explicit by default; use `--update=auto` or `--update=force` when you want a sync run to pull/import the shared snapshot before live Discord or desktop-cache deltas.
- Snapshot imports now emit phase/table/file progress and keep the sync lock file updated with the active phase, making long update/import runs diagnosable instead of looking hung.
- Recent-message scans are backed by a plain `messages(created_at, id)` index so archive freshness and short-window analysis queries avoid full-table scans.
### Maintenance

View File

@ -173,15 +173,20 @@ discrawl init --db ~/data/discrawl.db
Refreshes SQLite from one or both archive sources.
By default, `sync` runs both sources:
By default, `sync` runs both live/local sources and does not import the Git snapshot first:
- Discord bot-token sync for bot-visible guild data
- local Discord Desktop cache import for classifiable cached messages and proven DMs
Use `discrawl update` when you want to pull/import the shared Git snapshot. If you intentionally want a sync run to import the snapshot before live deltas, pass `--update=auto` to import only when stale or `--update=force` to pull/import before syncing. `--no-update` is accepted as an explicit no-op alias for the default.
Run one explicit `--full` pass when you want a complete historical guild archive. Use plain `sync` afterward for frequent latest-message and desktop-cache refreshes.
```bash
discrawl sync
discrawl sync --update=auto
discrawl sync --update=force
discrawl sync --no-update
discrawl sync --full
discrawl sync --full --all
discrawl sync --guild 123456789012345678
@ -207,7 +212,8 @@ Bot sync modes:
| Command | Use when | Behavior |
| --- | --- | --- |
| `discrawl sync` | routine refresh | imports any stale Git snapshot first, skips member refreshes, checks live top-level channels plus active threads, and only fetches new messages for channels with a stored latest cursor |
| `discrawl sync` | routine refresh | skips member refreshes, checks live top-level channels plus active threads, and only fetches new messages for channels with a stored latest cursor |
| `discrawl sync --update=auto` | hybrid Git/live refresh | imports a stale Git snapshot first, then runs the routine live refresh |
| `discrawl sync --all-channels` | repair pass | broad incremental sweep across every stored channel/thread, including archived threads |
| `discrawl sync --full` | historical backfill | crawls older history until channels are complete; can take a long time on large servers |
@ -218,7 +224,7 @@ Bot sync modes:
`--latest-only` is still accepted for explicit latest-only runs; it is now the default for untargeted `sync`. Use `--all-channels` to opt out of the fast default without doing a full historical crawl.
When `--channels` includes a forum channel id, `discrawl` expands that forum's threads and syncs their messages as part of the targeted run.
`--since` limits initial history/bootstrap and full-history backfill to messages at or after the given RFC3339 timestamp. It does not mark older history as complete, so a later `sync --full` without `--since` can continue the backfill.
Long runs now emit periodic progress logs to stderr so large backfills do not look hung.
Long runs now emit periodic progress logs to stderr so large backfills and Git snapshot imports do not look hung.
If in-flight channels stop completing for a while, `discrawl` now emits `message sync waiting` heartbeat logs with the oldest active channel, per-channel page activity, and skip/defer counters, and every run ends with a `message sync finished` summary.
Each channel crawl also has a bounded runtime budget, so a pathological channel is deferred and retried on the next sync instead of pinning a worker forever.
Full sync member refresh is best-effort and currently gives up after five minutes without a caller-supplied deadline, so message sync completion is not held hostage by a slow guild member crawl.
@ -454,9 +460,9 @@ discrawl subscribe --stale-after 15m https://github.com/example/discord-archive.
discrawl subscribe --no-auto-update https://github.com/example/discord-archive.git
```
Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually.
Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually. `discrawl sync` does not auto-import the share unless `--update=auto` or `--update=force` is provided, so routine live refreshes stay fast.
Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass.
Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync --update=auto` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass.
Git snapshots publish non-DM archive tables by default. Embedding queue state stays local to each machine, and Git-only readers can use FTS immediately without an embedding provider.

View File

@ -113,9 +113,19 @@ func (r *runtime) runSync(args []string) error {
latestOnly := fs.Bool("latest-only", false, "")
guildsFlag := fs.String("guilds", "", "")
guildFlag := fs.String("guild", "", "")
updateMode := fs.String("update", "", "")
noUpdate := fs.Bool("no-update", false, "")
if err := fs.Parse(args); err != nil {
return usageErr(err)
}
if *noUpdate && strings.TrimSpace(*updateMode) != "" && !strings.EqualFold(strings.TrimSpace(*updateMode), string(shareUpdateNever)) {
return usageErr(errors.New("use either --no-update or --update, not both"))
}
if strings.TrimSpace(*updateMode) != "" {
if _, err := parseShareUpdateMode(*updateMode); err != nil {
return usageErr(err)
}
}
sources, err := parseSyncSources(*source)
if err != nil {
return usageErr(err)
@ -151,6 +161,7 @@ func (r *runtime) runSync(args []string) error {
func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) error {
var apiStats *syncer.SyncStats
if sources.discord {
r.setSyncLockPhase("discord sync")
shouldClose := r.client == nil
if err := r.ensureDiscordServices(); err != nil {
return err
@ -166,6 +177,7 @@ func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) er
}
var wiretapStats *discorddesktop.Stats
if sources.wiretap {
r.setSyncLockPhase("wiretap import")
stats, err := discorddesktop.Import(r.ctx, r.store, discorddesktop.Options{
Path: r.cfg.Desktop.Path,
MaxFileBytes: r.cfg.Desktop.MaxFileBytes,

View File

@ -90,23 +90,24 @@ func Run(ctx context.Context, args []string, stdout, stderr io.Writer) error {
}
type runtime struct {
ctx context.Context
configPath string
cfg config.Config
stdout io.Writer
stderr io.Writer
json bool
plain bool
logger *slog.Logger
store *store.Store
client discordClient
syncer syncService
dbLockHeld bool
openStore func(context.Context, string) (*store.Store, error)
newDiscord func(config.Config) (discordClient, error)
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
now func() time.Time
ctx context.Context
configPath string
cfg config.Config
stdout io.Writer
stderr io.Writer
json bool
plain bool
logger *slog.Logger
store *store.Store
client discordClient
syncer syncService
dbLockHeld bool
lockStarted time.Time
openStore func(context.Context, string) (*store.Store, error)
newDiscord func(config.Config) (discordClient, error)
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
now func() time.Time
}
type discordClient interface {
@ -131,7 +132,11 @@ func (r *runtime) dispatch(rest []string) error {
case "init":
return r.runInit(rest[1:])
case "sync":
return r.withLocalStoreLocked(true, func() error { return r.runSync(rest[1:]) })
updateMode, err := syncShareUpdateMode(rest[1:])
if err != nil {
return usageErr(err)
}
return r.withLocalStoreUpdateLocked(updateMode, true, func() error { return r.runSync(rest[1:]) })
case "tail":
return r.withServicesLocked(true, func() error { return r.runTail(rest[1:]) })
case "wiretap":
@ -187,14 +192,18 @@ func (r *runtime) withServicesLocked(withDiscord bool, fn func() error) error {
}
func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) error {
return r.withLocalStoreDefaultLocked(autoShareUpdate, true, fn)
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), true, fn)
}
func (r *runtime) withLocalStoreDefault(autoShareUpdate bool, fn func() error) error {
return r.withLocalStoreDefaultLocked(autoShareUpdate, false, fn)
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), false, fn)
}
func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn func() error) error {
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), lockDB, fn)
}
func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB bool, fn func() error) error {
cfg, err := config.Load(r.configPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
@ -215,13 +224,13 @@ func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn f
r.cfg = cfg
if lockDB {
return r.withSyncLock(func() error {
return r.openLocalStore(dbPath, autoShareUpdate, fn)
return r.openLocalStore(dbPath, updateMode, fn)
})
}
return r.openLocalStore(dbPath, autoShareUpdate, fn)
return r.openLocalStore(dbPath, updateMode, fn)
}
func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func() error) error {
func (r *runtime) openLocalStore(dbPath string, updateMode shareUpdateMode, fn func() error) error {
storeFactory := r.openStore
if storeFactory == nil {
storeFactory = store.Open
@ -232,8 +241,8 @@ func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func()
return dbErr(err)
}
defer func() { _ = r.store.Close() }()
if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
if err := r.autoUpdateShare(); err != nil {
if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
if err := r.autoUpdateShare(updateMode); err != nil {
return err
}
}
@ -245,6 +254,10 @@ func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func()
}
func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bool, fn func() error) error {
return r.withServicesUpdateLocked(withDiscord, boolShareUpdateMode(autoShareUpdate), lockDB, fn)
}
func (r *runtime) withServicesUpdateLocked(withDiscord bool, updateMode shareUpdateMode, lockDB bool, fn func() error) error {
cfg, err := config.Load(r.configPath)
if err != nil {
return configErr(err)
@ -259,13 +272,13 @@ func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bo
r.cfg = cfg
if lockDB {
return r.withSyncLock(func() error {
return r.openServices(dbPath, withDiscord, autoShareUpdate, fn)
return r.openServices(dbPath, withDiscord, updateMode, fn)
})
}
return r.openServices(dbPath, withDiscord, autoShareUpdate, fn)
return r.openServices(dbPath, withDiscord, updateMode, fn)
}
func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool, fn func() error) error {
func (r *runtime) openServices(dbPath string, withDiscord bool, updateMode shareUpdateMode, fn func() error) error {
storeFactory := r.openStore
if storeFactory == nil {
storeFactory = store.Open
@ -276,8 +289,8 @@ func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool,
return dbErr(err)
}
defer func() { _ = r.store.Close() }()
if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
if err := r.autoUpdateShare(); err != nil {
if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
if err := r.autoUpdateShare(updateMode); err != nil {
return err
}
}
@ -321,24 +334,27 @@ func (r *runtime) ensureDiscordServices() error {
return nil
}
func (r *runtime) autoUpdateShare() error {
if !r.cfg.ShareEnabled() || !r.cfg.Share.AutoUpdate {
func (r *runtime) autoUpdateShare(mode shareUpdateMode) error {
if !r.cfg.ShareEnabled() || (mode == shareUpdateConfigured && !r.cfg.Share.AutoUpdate) {
return nil
}
staleAfter, err := time.ParseDuration(r.cfg.Share.StaleAfter)
if err != nil {
return configErr(fmt.Errorf("invalid share.stale_after: %w", err))
}
if !share.NeedsImport(r.ctx, r.store, staleAfter) {
if mode != shareUpdateForce && !share.NeedsImport(r.ctx, r.store, staleAfter) {
return nil
}
opts, err := r.shareOptions()
if err != nil {
return err
}
r.setSyncLockPhase("share pull")
r.logger.Info("share update pulling", "repo_path", opts.RepoPath, "remote", opts.Remote)
if err := share.Pull(r.ctx, opts); err != nil {
return err
}
r.setSyncLockPhase("share import")
_, _, err = share.ImportIfChanged(r.ctx, r.store, opts)
if errors.Is(err, share.ErrNoManifest) {
return nil
@ -355,5 +371,6 @@ func (r *runtime) shareOptions() (share.Options, error) {
RepoPath: repoPath,
Remote: r.cfg.Share.Remote,
Branch: r.cfg.Share.Branch,
Progress: r.shareProgress,
}, nil
}

View File

@ -600,7 +600,7 @@ func TestShareUpdateImportsNewRemoteSnapshot(t *testing.T) {
require.Contains(t, out.String(), "newer git snapshot arrived")
}
func TestSyncImportsGitShareBeforeLiveDiscord(t *testing.T) {
func TestSyncSkipsGitShareByDefaultAndCanImportBeforeLiveDiscord(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
remoteRepo := filepath.Join(dir, "remote.git")
@ -643,17 +643,33 @@ func TestSyncImportsGitShareBeforeLiveDiscord(t *testing.T) {
}
require.NoError(t, rt.dispatch([]string{"sync", "--all"}))
require.True(t, hybrid.sawGitMessage)
require.False(t, hybrid.sawGitMessage)
reader, err := store.Open(ctx, cfg.DBPath)
require.NoError(t, err)
defer func() { _ = reader.Close() }()
rows, err := reader.ListMessages(ctx, store.MessageListOptions{Channel: "general", IncludeEmpty: true})
require.NoError(t, err)
contents := make([]string, 0, len(rows))
for _, row := range rows {
contents = append(contents, row.Content)
}
require.NotContains(t, contents, "automatic updates work")
require.Contains(t, contents, "live discord filled the delta")
require.NoError(t, reader.Close())
hybrid.sawGitMessage = false
require.NoError(t, rt.dispatch([]string{"sync", "--all", "--update=auto"}))
require.True(t, hybrid.sawGitMessage)
reader, err = store.Open(ctx, cfg.DBPath)
require.NoError(t, err)
defer func() { _ = reader.Close() }()
rows, err = reader.ListMessages(ctx, store.MessageListOptions{Channel: "general", IncludeEmpty: true})
require.NoError(t, err)
contents = contents[:0]
for _, row := range rows {
contents = append(contents, row.Content)
}
require.Contains(t, contents, "automatic updates work")
require.Contains(t, contents, "live discord filled the delta")
}
@ -1543,6 +1559,17 @@ func TestHelpers(t *testing.T) {
require.Equal(t, []string{"a", "b"}, csvList("a,b,a"))
require.Equal(t, "x", (&cliError{code: 2, err: assertErr("x")}).Error())
mode, err := syncShareUpdateMode([]string{"--all"})
require.NoError(t, err)
require.Equal(t, shareUpdateNever, mode)
mode, err = syncShareUpdateMode([]string{"--update=auto"})
require.NoError(t, err)
require.Equal(t, shareUpdateAuto, mode)
mode, err = syncShareUpdateMode([]string{"--update", "force"})
require.NoError(t, err)
require.Equal(t, shareUpdateForce, mode)
_, err = syncShareUpdateMode([]string{"--update"})
require.Error(t, err)
require.Equal(t, 2, ExitCode(usageErr(assertErr("x"))))
require.Equal(t, 4, ExitCode(authErr(assertErr("x"))))
require.Equal(t, 5, ExitCode(dbErr(assertErr("x"))))
@ -1926,6 +1953,8 @@ func TestCommandUsageErrors(t *testing.T) {
require.Equal(t, 2, ExitCode(rt.runMessages([]string{"--days", "-1"})))
require.Equal(t, 2, ExitCode(rt.runMessages([]string{"--days", "1", "--since", "2026-03-01T00:00:00Z"})))
require.Equal(t, 2, ExitCode(rt.runSync([]string{"--all", "--guild", "g1"})))
require.Equal(t, 2, ExitCode(rt.runSync([]string{"--update", "bogus"})))
require.Equal(t, 2, ExitCode(rt.runSync([]string{"--update=force", "--no-update"})))
require.Equal(t, 2, ExitCode(rt.runChannels(nil)))
require.Equal(t, 2, ExitCode(rt.runStatus([]string{"extra"})))
require.NoError(t, (&runtime{stdout: &bytes.Buffer{}}).runDoctor(nil))

View File

@ -136,13 +136,15 @@ func (r *runtime) runSubscribe(args []string) error {
if err != nil {
return configErr(err)
}
opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch}
opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch, Progress: r.shareProgress}
if *withEmbeddings {
applyEmbeddingShareOptions(&opts, cfg)
}
r.setSyncLockPhase("share pull")
if err := share.Pull(r.ctx, opts); err != nil {
return err
}
r.setSyncLockPhase("share import")
manifest, imported, err := share.ImportIfChanged(r.ctx, s, opts)
if err != nil {
return err
@ -176,12 +178,15 @@ func (r *runtime) runUpdate(args []string) error {
if err != nil {
return err
}
opts.Progress = r.shareProgress
if *withEmbeddings {
applyEmbeddingShareOptions(&opts, r.cfg)
}
r.setSyncLockPhase("share pull")
if err := share.Pull(r.ctx, opts); err != nil {
return err
}
r.setSyncLockPhase("share import")
manifest, imported, err := share.ImportIfChanged(r.ctx, r.store, opts)
if err != nil {
return err

View File

@ -0,0 +1,110 @@
package cli
import (
"errors"
"fmt"
"strings"
"time"
"github.com/steipete/discrawl/internal/share"
)
type shareUpdateMode string
const (
shareUpdateConfigured shareUpdateMode = "configured"
shareUpdateAuto shareUpdateMode = "auto"
shareUpdateNever shareUpdateMode = "never"
shareUpdateForce shareUpdateMode = "force"
)
func boolShareUpdateMode(enabled bool) shareUpdateMode {
if enabled {
return shareUpdateConfigured
}
return shareUpdateNever
}
func parseShareUpdateMode(raw string) (shareUpdateMode, error) {
switch shareUpdateMode(strings.ToLower(strings.TrimSpace(raw))) {
case "", shareUpdateAuto:
return shareUpdateAuto, nil
case shareUpdateNever:
return shareUpdateNever, nil
case shareUpdateForce:
return shareUpdateForce, nil
default:
return "", fmt.Errorf("invalid --update %q; use auto, never, or force", raw)
}
}
func syncShareUpdateMode(args []string) (shareUpdateMode, error) {
mode := shareUpdateNever
sawNoUpdate := false
sawUpdate := false
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case arg == "--no-update":
sawNoUpdate = true
mode = shareUpdateNever
case arg == "--update":
if i+1 >= len(args) {
return "", errors.New("--update requires auto, never, or force")
}
parsed, err := parseShareUpdateMode(args[i+1])
if err != nil {
return "", err
}
sawUpdate = true
mode = parsed
i++
case strings.HasPrefix(arg, "--update="):
parsed, err := parseShareUpdateMode(strings.TrimPrefix(arg, "--update="))
if err != nil {
return "", err
}
sawUpdate = true
mode = parsed
}
}
if sawNoUpdate && sawUpdate && mode != shareUpdateNever {
return "", errors.New("use either --no-update or --update, not both")
}
return mode, nil
}
func (r *runtime) shareProgress(progress share.ImportProgress) {
if progress.Phase == "" {
return
}
phase := "share " + progress.Phase
if progress.Table != "" {
phase += " " + progress.Table
}
if progress.File != "" {
phase += " " + progress.File
}
r.setSyncLockPhase(phase)
attrs := []any{"phase", progress.Phase}
if progress.Table != "" {
attrs = append(attrs, "table", progress.Table)
}
if progress.Rows != 0 {
attrs = append(attrs, "rows", progress.Rows)
}
if progress.TotalRows != 0 {
attrs = append(attrs, "total_rows", progress.TotalRows)
}
if progress.File != "" {
attrs = append(attrs, "file", progress.File, "file_index", progress.FileIndex, "file_count", progress.FileCount)
}
r.logger.Info("share import progress", attrs...)
}
func (r *runtime) nowUTC() time.Time {
if r.now != nil {
return r.now().UTC()
}
return time.Now().UTC()
}

View File

@ -3,7 +3,10 @@ package cli
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/steipete/discrawl/internal/config"
)
@ -21,13 +24,37 @@ func (r *runtime) withSyncLock(fn func() error) error {
return err
}
r.dbLockHeld = true
r.lockStarted = r.nowUTC()
r.setSyncLockPhase("locked")
defer func() {
r.dbLockHeld = false
r.lockStarted = time.Time{}
_ = release()
}()
return fn()
}
func (r *runtime) setSyncLockPhase(phase string) {
if !r.dbLockHeld {
return
}
path, err := r.syncLockPath()
if err != nil {
return
}
started := r.lockStarted
if started.IsZero() {
started = r.nowUTC()
}
body := fmt.Sprintf("pid=%d\nstarted_at=%s\nupdated_at=%s\nphase=%s\n",
os.Getpid(),
started.Format(time.RFC3339Nano),
r.nowUTC().Format(time.RFC3339Nano),
phase,
)
_ = os.WriteFile(path, []byte(body), 0o600)
}
func (r *runtime) syncLockPath() (string, error) {
dbPath, err := config.ExpandPath(r.cfg.DBPath)
if err != nil {
@ -38,6 +65,12 @@ func (r *runtime) syncLockPath() (string, error) {
func syncLockErr(ctx context.Context, path string) error {
if ctx.Err() != nil {
if body, err := os.ReadFile(path); err == nil {
details := strings.TrimSpace(string(body))
if details != "" {
return fmt.Errorf("wait for sync lock %s (%s): %w", path, strings.ReplaceAll(details, "\n", ", "), ctx.Err())
}
}
return fmt.Errorf("wait for sync lock %s: %w", path, ctx.Err())
}
return nil

View File

@ -52,6 +52,17 @@ type Options struct {
EmbeddingProvider string
EmbeddingModel string
EmbeddingInputVersion string
Progress func(ImportProgress)
}
type ImportProgress struct {
Phase string
Table string
File string
FileIndex int
FileCount int
Rows int
TotalRows int
}
type Manifest struct {
@ -221,6 +232,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
if err != nil {
return Manifest{}, err
}
opts.reportProgress(ImportProgress{Phase: "start", TotalRows: manifestRowCount(manifest)})
restorePragmas, err := applyImportPragmas(ctx, s.DB())
if err != nil {
return Manifest{}, err
@ -242,11 +254,13 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
}
}()
for _, table := range []string{"message_fts", "member_fts"} {
opts.reportProgress(ImportProgress{Phase: "drop_fts", Table: table})
if _, err := tx.ExecContext(ctx, "drop table if exists "+table); err != nil {
return Manifest{}, fmt.Errorf("drop %s: %w", table, err)
}
}
for _, table := range slices.Backward(SnapshotTables) {
opts.reportProgress(ImportProgress{Phase: "clear", Table: table})
query, args := snapshotDeleteQuery(table)
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return Manifest{}, fmt.Errorf("clear %s: %w", table, err)
@ -256,10 +270,13 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
if err := ctx.Err(); err != nil {
return Manifest{}, err
}
if err := importTable(ctx, tx, opts.RepoPath, table); err != nil {
opts.reportProgress(ImportProgress{Phase: "table_start", Table: table.Name, TotalRows: table.Rows})
if err := importTable(ctx, tx, opts, table); err != nil {
return Manifest{}, err
}
opts.reportProgress(ImportProgress{Phase: "table_done", Table: table.Name, TotalRows: table.Rows})
}
opts.reportProgress(ImportProgress{Phase: "repair"})
if err := repairImportedGuildIDs(ctx, tx); err != nil {
return Manifest{}, err
}
@ -268,10 +285,12 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
return Manifest{}, err
}
}
opts.reportProgress(ImportProgress{Phase: "commit"})
if err := tx.Commit(); err != nil {
return Manifest{}, err
}
committed = true
opts.reportProgress(ImportProgress{Phase: "rebuild_fts"})
if err := s.RebuildSearchIndexes(ctx); err != nil {
return Manifest{}, err
}
@ -282,6 +301,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
return Manifest{}, err
}
pragmasRestored = true
opts.reportProgress(ImportProgress{Phase: "done", TotalRows: manifestRowCount(manifest)})
return manifest, nil
}
@ -335,6 +355,23 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
return imported, true, nil
}
func (opts Options) reportProgress(progress ImportProgress) {
if opts.Progress != nil {
opts.Progress(progress)
}
}
func manifestRowCount(manifest Manifest) int {
total := 0
for _, table := range manifest.Tables {
total += table.Rows
}
for _, embeddings := range manifest.Embeddings {
total += embeddings.Rows
}
return total
}
func ImportEmbeddings(ctx context.Context, s *store.Store, opts Options, manifest Manifest) error {
tx, err := s.DB().BeginTx(ctx, nil)
if err != nil {
@ -572,7 +609,7 @@ func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingM
}, nil
}
func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableManifest) error {
func importTable(ctx context.Context, tx *sql.Tx, opts Options, table TableManifest) error {
files := table.Files
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
files = []string{table.File}
@ -586,34 +623,38 @@ func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableMa
return fmt.Errorf("prepare import %s: %w", table.Name, err)
}
defer func() { _ = stmt.Close() }()
for _, rel := range files {
for i, rel := range files {
if err := ctx.Err(); err != nil {
return err
}
if err := importTableFile(ctx, stmt, repoPath, table, columns, rel); err != nil {
opts.reportProgress(ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: i + 1, FileCount: len(files), TotalRows: table.Rows})
rows, err := importTableFile(ctx, stmt, opts.RepoPath, table, columns, rel)
if err != nil {
return err
}
opts.reportProgress(ImportProgress{Phase: "file_done", Table: table.Name, File: rel, FileIndex: i + 1, FileCount: len(files), Rows: rows, TotalRows: table.Rows})
}
return nil
}
func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table TableManifest, columns []string, rel string) error {
func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table TableManifest, columns []string, rel string) (int, error) {
path := filepath.Join(repoPath, filepath.FromSlash(rel))
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("open %s: %w", rel, err)
return 0, fmt.Errorf("open %s: %w", rel, err)
}
defer func() { _ = file.Close() }()
gz, err := gzip.NewReader(file)
if err != nil {
return fmt.Errorf("read gzip %s: %w", rel, err)
return 0, fmt.Errorf("read gzip %s: %w", rel, err)
}
defer func() { _ = gz.Close() }()
dec := json.NewDecoder(gz)
dec.UseNumber()
count := 0
for {
if err := ctx.Err(); err != nil {
return err
return count, err
}
row := map[string]any{}
err := dec.Decode(&row)
@ -621,7 +662,7 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
break
}
if err != nil {
return fmt.Errorf("decode %s: %w", rel, err)
return count, fmt.Errorf("decode %s: %w", rel, err)
}
if isDirectMessageSnapshotRow(table.Name, row) {
continue
@ -631,10 +672,11 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
values[i] = importValue(row[column])
}
if _, err := stmt.ExecContext(ctx, values...); err != nil {
return fmt.Errorf("insert %s: %w", table.Name, err)
return count, fmt.Errorf("insert %s: %w", table.Name, err)
}
count++
}
return nil
return count, nil
}
func repairImportedGuildIDs(ctx context.Context, tx *sql.Tx) error {

View File

@ -35,10 +35,20 @@ func TestExportImportRoundTrip(t *testing.T) {
require.NoError(t, err)
defer func() { _ = dst.Close() }()
imported, changed, err := ImportIfChanged(ctx, dst, Options{RepoPath: repo, Branch: "main"})
var progress []ImportProgress
imported, changed, err := ImportIfChanged(ctx, dst, Options{
RepoPath: repo,
Branch: "main",
Progress: func(p ImportProgress) { progress = append(progress, p) },
})
require.NoError(t, err)
require.True(t, changed)
require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt)
require.Contains(t, progressPhases(progress), "start")
require.Contains(t, progressPhases(progress), "table_start")
require.Contains(t, progressPhases(progress), "file_done")
require.Contains(t, progressPhases(progress), "rebuild_fts")
require.Contains(t, progressPhases(progress), "done")
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "launch", Limit: 10})
require.NoError(t, err)
@ -673,7 +683,7 @@ func TestShareSmallHelpersAndValidation(t *testing.T) {
defer func() { _ = s.Close() }()
tx, err := s.DB().BeginTx(ctx, nil)
require.NoError(t, err)
require.ErrorContains(t, importTable(ctx, tx, t.TempDir(), TableManifest{Name: "messages", Columns: []string{"id"}}), "has no files")
require.ErrorContains(t, importTable(ctx, tx, Options{RepoPath: t.TempDir()}, TableManifest{Name: "messages", Columns: []string{"id"}}), "has no files")
require.NoError(t, tx.Rollback())
require.ErrorContains(t, ImportEmbeddings(ctx, s, Options{
@ -727,7 +737,7 @@ func TestLegacyManifestFileImportAndEmbeddingDecodeErrors(t *testing.T) {
})
tx, err := s.DB().BeginTx(ctx, nil)
require.NoError(t, err)
require.NoError(t, importTable(ctx, tx, repo, TableManifest{
require.NoError(t, importTable(ctx, tx, Options{RepoPath: repo}, TableManifest{
Name: "guilds",
File: tableRel,
Columns: []string{"id", "name", "icon", "raw_json", "updated_at"},
@ -935,3 +945,11 @@ func tableNames(manifest Manifest) []string {
}
return names
}
func progressPhases(progress []ImportProgress) []string {
phases := make([]string, 0, len(progress))
for _, item := range progress {
phases = append(phases, item.Phase)
}
return phases
}

View File

@ -413,6 +413,7 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error {
`create index if not exists idx_members_guild_id on members(guild_id);`,
`create index if not exists idx_messages_channel_id on messages(channel_id);`,
`create index if not exists idx_messages_guild_id on messages(guild_id);`,
`create index if not exists idx_messages_created_id on messages(created_at, id);`,
`create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`,
`create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`,
`create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`,
@ -488,6 +489,7 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error {
`create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`,
`create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`,
`create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`,
`create index if not exists idx_messages_created_id on messages(created_at, id);`,
`create index if not exists idx_mentions_guild_event on mention_events(guild_id, event_at, event_id);`,
`create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`,
`create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`,

View File

@ -1337,6 +1337,7 @@ func TestOpenCreatesQueryIndexes(t *testing.T) {
defer func() { _ = s.Close() }()
messageIndexes := indexNames(t, ctx, s.DB(), "messages")
require.Contains(t, messageIndexes, "idx_messages_created_id")
require.Contains(t, messageIndexes, "idx_messages_guild_created_id")
require.Contains(t, messageIndexes, "idx_messages_channel_created_id")
require.Contains(t, messageIndexes, "idx_messages_author_created_id")
@ -1361,6 +1362,7 @@ func TestOpenMigratesLegacyQueryIndexes(t *testing.T) {
"idx_messages_guild_created_id",
"idx_messages_channel_created_id",
"idx_messages_author_created_id",
"idx_messages_created_id",
"idx_mentions_guild_event",
"idx_mentions_channel_event",
} {
@ -1376,6 +1378,7 @@ func TestOpenMigratesLegacyQueryIndexes(t *testing.T) {
version, err := s.schemaVersion(ctx)
require.NoError(t, err)
require.Equal(t, storeSchemaVersion, version)
require.Contains(t, indexNames(t, ctx, s.DB(), "messages"), "idx_messages_created_id")
require.Contains(t, indexNames(t, ctx, s.DB(), "messages"), "idx_messages_channel_created_id")
require.Contains(t, indexNames(t, ctx, s.DB(), "mention_events"), "idx_mentions_guild_event")
}