fix: keep discrawl sync update explicit
This commit is contained in:
parent
45f0133b62
commit
c934c579b0
@ -7,6 +7,9 @@ All notable changes to `discrawl` will be documented in this file.
|
||||
### Fixes
|
||||
|
||||
- `discrawl` now handles SIGINT/SIGTERM by canceling active sync/import contexts so large SQLite and FTS writes can roll back and close cleanly instead of being terminated mid-transaction.
|
||||
- `discrawl sync` now keeps Git snapshot refreshes explicit by default; use `--update=auto` or `--update=force` when you want a sync run to pull/import the shared snapshot before live Discord or desktop-cache deltas.
|
||||
- Snapshot imports now emit phase/table/file progress and keep the sync lock file updated with the active phase, making long update/import runs diagnosable instead of looking hung.
|
||||
- Recent-message scans are backed by a plain `messages(created_at, id)` index so archive freshness and short-window analysis queries avoid full-table scans.
|
||||
|
||||
### Maintenance
|
||||
|
||||
|
||||
16
README.md
16
README.md
@ -173,15 +173,20 @@ discrawl init --db ~/data/discrawl.db
|
||||
|
||||
Refreshes SQLite from one or both archive sources.
|
||||
|
||||
By default, `sync` runs both sources:
|
||||
By default, `sync` runs both live/local sources and does not import the Git snapshot first:
|
||||
|
||||
- Discord bot-token sync for bot-visible guild data
|
||||
- local Discord Desktop cache import for classifiable cached messages and proven DMs
|
||||
|
||||
Use `discrawl update` when you want to pull/import the shared Git snapshot. If you intentionally want a sync run to import the snapshot before live deltas, pass `--update=auto` to import only when stale or `--update=force` to pull/import before syncing. `--no-update` is accepted as an explicit no-op alias for the default.
|
||||
|
||||
Run one explicit `--full` pass when you want a complete historical guild archive. Use plain `sync` afterward for frequent latest-message and desktop-cache refreshes.
|
||||
|
||||
```bash
|
||||
discrawl sync
|
||||
discrawl sync --update=auto
|
||||
discrawl sync --update=force
|
||||
discrawl sync --no-update
|
||||
discrawl sync --full
|
||||
discrawl sync --full --all
|
||||
discrawl sync --guild 123456789012345678
|
||||
@ -207,7 +212,8 @@ Bot sync modes:
|
||||
|
||||
| Command | Use when | Behavior |
|
||||
| --- | --- | --- |
|
||||
| `discrawl sync` | routine refresh | imports any stale Git snapshot first, skips member refreshes, checks live top-level channels plus active threads, and only fetches new messages for channels with a stored latest cursor |
|
||||
| `discrawl sync` | routine refresh | skips member refreshes, checks live top-level channels plus active threads, and only fetches new messages for channels with a stored latest cursor |
|
||||
| `discrawl sync --update=auto` | hybrid Git/live refresh | imports a stale Git snapshot first, then runs the routine live refresh |
|
||||
| `discrawl sync --all-channels` | repair pass | broad incremental sweep across every stored channel/thread, including archived threads |
|
||||
| `discrawl sync --full` | historical backfill | crawls older history until channels are complete; can take a long time on large servers |
|
||||
|
||||
@ -218,7 +224,7 @@ Bot sync modes:
|
||||
`--latest-only` is still accepted for explicit latest-only runs; it is now the default for untargeted `sync`. Use `--all-channels` to opt out of the fast default without doing a full historical crawl.
|
||||
When `--channels` includes a forum channel id, `discrawl` expands that forum's threads and syncs their messages as part of the targeted run.
|
||||
`--since` limits initial history/bootstrap and full-history backfill to messages at or after the given RFC3339 timestamp. It does not mark older history as complete, so a later `sync --full` without `--since` can continue the backfill.
|
||||
Long runs now emit periodic progress logs to stderr so large backfills do not look hung.
|
||||
Long runs now emit periodic progress logs to stderr so large backfills and Git snapshot imports do not look hung.
|
||||
If in-flight channels stop completing for a while, `discrawl` now emits `message sync waiting` heartbeat logs with the oldest active channel, per-channel page activity, and skip/defer counters, and every run ends with a `message sync finished` summary.
|
||||
Each channel crawl also has a bounded runtime budget, so a pathological channel is deferred and retried on the next sync instead of pinning a worker forever.
|
||||
Full sync member refresh is best-effort and currently gives up after five minutes without a caller-supplied deadline, so message sync completion is not held hostage by a slow guild member crawl.
|
||||
@ -454,9 +460,9 @@ discrawl subscribe --stale-after 15m https://github.com/example/discord-archive.
|
||||
discrawl subscribe --no-auto-update https://github.com/example/discord-archive.git
|
||||
```
|
||||
|
||||
Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually.
|
||||
Once `share.remote` is configured, read commands auto-fetch and import when the local share import is older than `share.stale_after` (default `15m`). `discrawl update` forces the same pull/import step manually. `discrawl sync` does not auto-import the share unless `--update=auto` or `--update=force` is provided, so routine live refreshes stay fast.
|
||||
|
||||
Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass.
|
||||
Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync --update=auto` and `discrawl messages --sync` import the Git snapshot first, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass.
|
||||
|
||||
Git snapshots publish non-DM archive tables by default. Embedding queue state stays local to each machine, and Git-only readers can use FTS immediately without an embedding provider.
|
||||
|
||||
|
||||
@ -113,9 +113,19 @@ func (r *runtime) runSync(args []string) error {
|
||||
latestOnly := fs.Bool("latest-only", false, "")
|
||||
guildsFlag := fs.String("guilds", "", "")
|
||||
guildFlag := fs.String("guild", "", "")
|
||||
updateMode := fs.String("update", "", "")
|
||||
noUpdate := fs.Bool("no-update", false, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
if *noUpdate && strings.TrimSpace(*updateMode) != "" && !strings.EqualFold(strings.TrimSpace(*updateMode), string(shareUpdateNever)) {
|
||||
return usageErr(errors.New("use either --no-update or --update, not both"))
|
||||
}
|
||||
if strings.TrimSpace(*updateMode) != "" {
|
||||
if _, err := parseShareUpdateMode(*updateMode); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
}
|
||||
sources, err := parseSyncSources(*source)
|
||||
if err != nil {
|
||||
return usageErr(err)
|
||||
@ -151,6 +161,7 @@ func (r *runtime) runSync(args []string) error {
|
||||
func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) error {
|
||||
var apiStats *syncer.SyncStats
|
||||
if sources.discord {
|
||||
r.setSyncLockPhase("discord sync")
|
||||
shouldClose := r.client == nil
|
||||
if err := r.ensureDiscordServices(); err != nil {
|
||||
return err
|
||||
@ -166,6 +177,7 @@ func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) er
|
||||
}
|
||||
var wiretapStats *discorddesktop.Stats
|
||||
if sources.wiretap {
|
||||
r.setSyncLockPhase("wiretap import")
|
||||
stats, err := discorddesktop.Import(r.ctx, r.store, discorddesktop.Options{
|
||||
Path: r.cfg.Desktop.Path,
|
||||
MaxFileBytes: r.cfg.Desktop.MaxFileBytes,
|
||||
|
||||
@ -90,23 +90,24 @@ func Run(ctx context.Context, args []string, stdout, stderr io.Writer) error {
|
||||
}
|
||||
|
||||
type runtime struct {
|
||||
ctx context.Context
|
||||
configPath string
|
||||
cfg config.Config
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
json bool
|
||||
plain bool
|
||||
logger *slog.Logger
|
||||
store *store.Store
|
||||
client discordClient
|
||||
syncer syncService
|
||||
dbLockHeld bool
|
||||
openStore func(context.Context, string) (*store.Store, error)
|
||||
newDiscord func(config.Config) (discordClient, error)
|
||||
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
|
||||
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
|
||||
now func() time.Time
|
||||
ctx context.Context
|
||||
configPath string
|
||||
cfg config.Config
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
json bool
|
||||
plain bool
|
||||
logger *slog.Logger
|
||||
store *store.Store
|
||||
client discordClient
|
||||
syncer syncService
|
||||
dbLockHeld bool
|
||||
lockStarted time.Time
|
||||
openStore func(context.Context, string) (*store.Store, error)
|
||||
newDiscord func(config.Config) (discordClient, error)
|
||||
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
|
||||
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type discordClient interface {
|
||||
@ -131,7 +132,11 @@ func (r *runtime) dispatch(rest []string) error {
|
||||
case "init":
|
||||
return r.runInit(rest[1:])
|
||||
case "sync":
|
||||
return r.withLocalStoreLocked(true, func() error { return r.runSync(rest[1:]) })
|
||||
updateMode, err := syncShareUpdateMode(rest[1:])
|
||||
if err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
return r.withLocalStoreUpdateLocked(updateMode, true, func() error { return r.runSync(rest[1:]) })
|
||||
case "tail":
|
||||
return r.withServicesLocked(true, func() error { return r.runTail(rest[1:]) })
|
||||
case "wiretap":
|
||||
@ -187,14 +192,18 @@ func (r *runtime) withServicesLocked(withDiscord bool, fn func() error) error {
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreLocked(autoShareUpdate bool, fn func() error) error {
|
||||
return r.withLocalStoreDefaultLocked(autoShareUpdate, true, fn)
|
||||
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), true, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreDefault(autoShareUpdate bool, fn func() error) error {
|
||||
return r.withLocalStoreDefaultLocked(autoShareUpdate, false, fn)
|
||||
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), false, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn func() error) error {
|
||||
return r.withLocalStoreUpdateLocked(boolShareUpdateMode(autoShareUpdate), lockDB, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withLocalStoreUpdateLocked(updateMode shareUpdateMode, lockDB bool, fn func() error) error {
|
||||
cfg, err := config.Load(r.configPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
@ -215,13 +224,13 @@ func (r *runtime) withLocalStoreDefaultLocked(autoShareUpdate, lockDB bool, fn f
|
||||
r.cfg = cfg
|
||||
if lockDB {
|
||||
return r.withSyncLock(func() error {
|
||||
return r.openLocalStore(dbPath, autoShareUpdate, fn)
|
||||
return r.openLocalStore(dbPath, updateMode, fn)
|
||||
})
|
||||
}
|
||||
return r.openLocalStore(dbPath, autoShareUpdate, fn)
|
||||
return r.openLocalStore(dbPath, updateMode, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func() error) error {
|
||||
func (r *runtime) openLocalStore(dbPath string, updateMode shareUpdateMode, fn func() error) error {
|
||||
storeFactory := r.openStore
|
||||
if storeFactory == nil {
|
||||
storeFactory = store.Open
|
||||
@ -232,8 +241,8 @@ func (r *runtime) openLocalStore(dbPath string, autoShareUpdate bool, fn func()
|
||||
return dbErr(err)
|
||||
}
|
||||
defer func() { _ = r.store.Close() }()
|
||||
if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
||||
if err := r.autoUpdateShare(); err != nil {
|
||||
if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
||||
if err := r.autoUpdateShare(updateMode); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -245,6 +254,10 @@ func (r *runtime) withServicesAuto(withDiscord, autoShareUpdate bool, fn func()
|
||||
}
|
||||
|
||||
func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bool, fn func() error) error {
|
||||
return r.withServicesUpdateLocked(withDiscord, boolShareUpdateMode(autoShareUpdate), lockDB, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) withServicesUpdateLocked(withDiscord bool, updateMode shareUpdateMode, lockDB bool, fn func() error) error {
|
||||
cfg, err := config.Load(r.configPath)
|
||||
if err != nil {
|
||||
return configErr(err)
|
||||
@ -259,13 +272,13 @@ func (r *runtime) withServicesAutoLocked(withDiscord, autoShareUpdate, lockDB bo
|
||||
r.cfg = cfg
|
||||
if lockDB {
|
||||
return r.withSyncLock(func() error {
|
||||
return r.openServices(dbPath, withDiscord, autoShareUpdate, fn)
|
||||
return r.openServices(dbPath, withDiscord, updateMode, fn)
|
||||
})
|
||||
}
|
||||
return r.openServices(dbPath, withDiscord, autoShareUpdate, fn)
|
||||
return r.openServices(dbPath, withDiscord, updateMode, fn)
|
||||
}
|
||||
|
||||
func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool, fn func() error) error {
|
||||
func (r *runtime) openServices(dbPath string, withDiscord bool, updateMode shareUpdateMode, fn func() error) error {
|
||||
storeFactory := r.openStore
|
||||
if storeFactory == nil {
|
||||
storeFactory = store.Open
|
||||
@ -276,8 +289,8 @@ func (r *runtime) openServices(dbPath string, withDiscord, autoShareUpdate bool,
|
||||
return dbErr(err)
|
||||
}
|
||||
defer func() { _ = r.store.Close() }()
|
||||
if autoShareUpdate && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
||||
if err := r.autoUpdateShare(); err != nil {
|
||||
if updateMode != shareUpdateNever && os.Getenv("DISCRAWL_NO_AUTO_UPDATE") != "1" {
|
||||
if err := r.autoUpdateShare(updateMode); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -321,24 +334,27 @@ func (r *runtime) ensureDiscordServices() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *runtime) autoUpdateShare() error {
|
||||
if !r.cfg.ShareEnabled() || !r.cfg.Share.AutoUpdate {
|
||||
func (r *runtime) autoUpdateShare(mode shareUpdateMode) error {
|
||||
if !r.cfg.ShareEnabled() || (mode == shareUpdateConfigured && !r.cfg.Share.AutoUpdate) {
|
||||
return nil
|
||||
}
|
||||
staleAfter, err := time.ParseDuration(r.cfg.Share.StaleAfter)
|
||||
if err != nil {
|
||||
return configErr(fmt.Errorf("invalid share.stale_after: %w", err))
|
||||
}
|
||||
if !share.NeedsImport(r.ctx, r.store, staleAfter) {
|
||||
if mode != shareUpdateForce && !share.NeedsImport(r.ctx, r.store, staleAfter) {
|
||||
return nil
|
||||
}
|
||||
opts, err := r.shareOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.setSyncLockPhase("share pull")
|
||||
r.logger.Info("share update pulling", "repo_path", opts.RepoPath, "remote", opts.Remote)
|
||||
if err := share.Pull(r.ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
r.setSyncLockPhase("share import")
|
||||
_, _, err = share.ImportIfChanged(r.ctx, r.store, opts)
|
||||
if errors.Is(err, share.ErrNoManifest) {
|
||||
return nil
|
||||
@ -355,5 +371,6 @@ func (r *runtime) shareOptions() (share.Options, error) {
|
||||
RepoPath: repoPath,
|
||||
Remote: r.cfg.Share.Remote,
|
||||
Branch: r.cfg.Share.Branch,
|
||||
Progress: r.shareProgress,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -600,7 +600,7 @@ func TestShareUpdateImportsNewRemoteSnapshot(t *testing.T) {
|
||||
require.Contains(t, out.String(), "newer git snapshot arrived")
|
||||
}
|
||||
|
||||
func TestSyncImportsGitShareBeforeLiveDiscord(t *testing.T) {
|
||||
func TestSyncSkipsGitShareByDefaultAndCanImportBeforeLiveDiscord(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
remoteRepo := filepath.Join(dir, "remote.git")
|
||||
@ -643,17 +643,33 @@ func TestSyncImportsGitShareBeforeLiveDiscord(t *testing.T) {
|
||||
}
|
||||
|
||||
require.NoError(t, rt.dispatch([]string{"sync", "--all"}))
|
||||
require.True(t, hybrid.sawGitMessage)
|
||||
require.False(t, hybrid.sawGitMessage)
|
||||
|
||||
reader, err := store.Open(ctx, cfg.DBPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = reader.Close() }()
|
||||
rows, err := reader.ListMessages(ctx, store.MessageListOptions{Channel: "general", IncludeEmpty: true})
|
||||
require.NoError(t, err)
|
||||
contents := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
contents = append(contents, row.Content)
|
||||
}
|
||||
require.NotContains(t, contents, "automatic updates work")
|
||||
require.Contains(t, contents, "live discord filled the delta")
|
||||
require.NoError(t, reader.Close())
|
||||
|
||||
hybrid.sawGitMessage = false
|
||||
require.NoError(t, rt.dispatch([]string{"sync", "--all", "--update=auto"}))
|
||||
require.True(t, hybrid.sawGitMessage)
|
||||
|
||||
reader, err = store.Open(ctx, cfg.DBPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = reader.Close() }()
|
||||
rows, err = reader.ListMessages(ctx, store.MessageListOptions{Channel: "general", IncludeEmpty: true})
|
||||
require.NoError(t, err)
|
||||
contents = contents[:0]
|
||||
for _, row := range rows {
|
||||
contents = append(contents, row.Content)
|
||||
}
|
||||
require.Contains(t, contents, "automatic updates work")
|
||||
require.Contains(t, contents, "live discord filled the delta")
|
||||
}
|
||||
@ -1543,6 +1559,17 @@ func TestHelpers(t *testing.T) {
|
||||
|
||||
require.Equal(t, []string{"a", "b"}, csvList("a,b,a"))
|
||||
require.Equal(t, "x", (&cliError{code: 2, err: assertErr("x")}).Error())
|
||||
mode, err := syncShareUpdateMode([]string{"--all"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, shareUpdateNever, mode)
|
||||
mode, err = syncShareUpdateMode([]string{"--update=auto"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, shareUpdateAuto, mode)
|
||||
mode, err = syncShareUpdateMode([]string{"--update", "force"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, shareUpdateForce, mode)
|
||||
_, err = syncShareUpdateMode([]string{"--update"})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 2, ExitCode(usageErr(assertErr("x"))))
|
||||
require.Equal(t, 4, ExitCode(authErr(assertErr("x"))))
|
||||
require.Equal(t, 5, ExitCode(dbErr(assertErr("x"))))
|
||||
@ -1926,6 +1953,8 @@ func TestCommandUsageErrors(t *testing.T) {
|
||||
require.Equal(t, 2, ExitCode(rt.runMessages([]string{"--days", "-1"})))
|
||||
require.Equal(t, 2, ExitCode(rt.runMessages([]string{"--days", "1", "--since", "2026-03-01T00:00:00Z"})))
|
||||
require.Equal(t, 2, ExitCode(rt.runSync([]string{"--all", "--guild", "g1"})))
|
||||
require.Equal(t, 2, ExitCode(rt.runSync([]string{"--update", "bogus"})))
|
||||
require.Equal(t, 2, ExitCode(rt.runSync([]string{"--update=force", "--no-update"})))
|
||||
require.Equal(t, 2, ExitCode(rt.runChannels(nil)))
|
||||
require.Equal(t, 2, ExitCode(rt.runStatus([]string{"extra"})))
|
||||
require.NoError(t, (&runtime{stdout: &bytes.Buffer{}}).runDoctor(nil))
|
||||
|
||||
@ -136,13 +136,15 @@ func (r *runtime) runSubscribe(args []string) error {
|
||||
if err != nil {
|
||||
return configErr(err)
|
||||
}
|
||||
opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch}
|
||||
opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch, Progress: r.shareProgress}
|
||||
if *withEmbeddings {
|
||||
applyEmbeddingShareOptions(&opts, cfg)
|
||||
}
|
||||
r.setSyncLockPhase("share pull")
|
||||
if err := share.Pull(r.ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
r.setSyncLockPhase("share import")
|
||||
manifest, imported, err := share.ImportIfChanged(r.ctx, s, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -176,12 +178,15 @@ func (r *runtime) runUpdate(args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Progress = r.shareProgress
|
||||
if *withEmbeddings {
|
||||
applyEmbeddingShareOptions(&opts, r.cfg)
|
||||
}
|
||||
r.setSyncLockPhase("share pull")
|
||||
if err := share.Pull(r.ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
r.setSyncLockPhase("share import")
|
||||
manifest, imported, err := share.ImportIfChanged(r.ctx, r.store, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
110
internal/cli/share_update.go
Normal file
110
internal/cli/share_update.go
Normal file
@ -0,0 +1,110 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/discrawl/internal/share"
|
||||
)
|
||||
|
||||
type shareUpdateMode string
|
||||
|
||||
const (
|
||||
shareUpdateConfigured shareUpdateMode = "configured"
|
||||
shareUpdateAuto shareUpdateMode = "auto"
|
||||
shareUpdateNever shareUpdateMode = "never"
|
||||
shareUpdateForce shareUpdateMode = "force"
|
||||
)
|
||||
|
||||
func boolShareUpdateMode(enabled bool) shareUpdateMode {
|
||||
if enabled {
|
||||
return shareUpdateConfigured
|
||||
}
|
||||
return shareUpdateNever
|
||||
}
|
||||
|
||||
func parseShareUpdateMode(raw string) (shareUpdateMode, error) {
|
||||
switch shareUpdateMode(strings.ToLower(strings.TrimSpace(raw))) {
|
||||
case "", shareUpdateAuto:
|
||||
return shareUpdateAuto, nil
|
||||
case shareUpdateNever:
|
||||
return shareUpdateNever, nil
|
||||
case shareUpdateForce:
|
||||
return shareUpdateForce, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid --update %q; use auto, never, or force", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func syncShareUpdateMode(args []string) (shareUpdateMode, error) {
|
||||
mode := shareUpdateNever
|
||||
sawNoUpdate := false
|
||||
sawUpdate := false
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case arg == "--no-update":
|
||||
sawNoUpdate = true
|
||||
mode = shareUpdateNever
|
||||
case arg == "--update":
|
||||
if i+1 >= len(args) {
|
||||
return "", errors.New("--update requires auto, never, or force")
|
||||
}
|
||||
parsed, err := parseShareUpdateMode(args[i+1])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sawUpdate = true
|
||||
mode = parsed
|
||||
i++
|
||||
case strings.HasPrefix(arg, "--update="):
|
||||
parsed, err := parseShareUpdateMode(strings.TrimPrefix(arg, "--update="))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sawUpdate = true
|
||||
mode = parsed
|
||||
}
|
||||
}
|
||||
if sawNoUpdate && sawUpdate && mode != shareUpdateNever {
|
||||
return "", errors.New("use either --no-update or --update, not both")
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
func (r *runtime) shareProgress(progress share.ImportProgress) {
|
||||
if progress.Phase == "" {
|
||||
return
|
||||
}
|
||||
phase := "share " + progress.Phase
|
||||
if progress.Table != "" {
|
||||
phase += " " + progress.Table
|
||||
}
|
||||
if progress.File != "" {
|
||||
phase += " " + progress.File
|
||||
}
|
||||
r.setSyncLockPhase(phase)
|
||||
attrs := []any{"phase", progress.Phase}
|
||||
if progress.Table != "" {
|
||||
attrs = append(attrs, "table", progress.Table)
|
||||
}
|
||||
if progress.Rows != 0 {
|
||||
attrs = append(attrs, "rows", progress.Rows)
|
||||
}
|
||||
if progress.TotalRows != 0 {
|
||||
attrs = append(attrs, "total_rows", progress.TotalRows)
|
||||
}
|
||||
if progress.File != "" {
|
||||
attrs = append(attrs, "file", progress.File, "file_index", progress.FileIndex, "file_count", progress.FileCount)
|
||||
}
|
||||
r.logger.Info("share import progress", attrs...)
|
||||
}
|
||||
|
||||
func (r *runtime) nowUTC() time.Time {
|
||||
if r.now != nil {
|
||||
return r.now().UTC()
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
@ -3,7 +3,10 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/discrawl/internal/config"
|
||||
)
|
||||
@ -21,13 +24,37 @@ func (r *runtime) withSyncLock(fn func() error) error {
|
||||
return err
|
||||
}
|
||||
r.dbLockHeld = true
|
||||
r.lockStarted = r.nowUTC()
|
||||
r.setSyncLockPhase("locked")
|
||||
defer func() {
|
||||
r.dbLockHeld = false
|
||||
r.lockStarted = time.Time{}
|
||||
_ = release()
|
||||
}()
|
||||
return fn()
|
||||
}
|
||||
|
||||
func (r *runtime) setSyncLockPhase(phase string) {
|
||||
if !r.dbLockHeld {
|
||||
return
|
||||
}
|
||||
path, err := r.syncLockPath()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
started := r.lockStarted
|
||||
if started.IsZero() {
|
||||
started = r.nowUTC()
|
||||
}
|
||||
body := fmt.Sprintf("pid=%d\nstarted_at=%s\nupdated_at=%s\nphase=%s\n",
|
||||
os.Getpid(),
|
||||
started.Format(time.RFC3339Nano),
|
||||
r.nowUTC().Format(time.RFC3339Nano),
|
||||
phase,
|
||||
)
|
||||
_ = os.WriteFile(path, []byte(body), 0o600)
|
||||
}
|
||||
|
||||
func (r *runtime) syncLockPath() (string, error) {
|
||||
dbPath, err := config.ExpandPath(r.cfg.DBPath)
|
||||
if err != nil {
|
||||
@ -38,6 +65,12 @@ func (r *runtime) syncLockPath() (string, error) {
|
||||
|
||||
func syncLockErr(ctx context.Context, path string) error {
|
||||
if ctx.Err() != nil {
|
||||
if body, err := os.ReadFile(path); err == nil {
|
||||
details := strings.TrimSpace(string(body))
|
||||
if details != "" {
|
||||
return fmt.Errorf("wait for sync lock %s (%s): %w", path, strings.ReplaceAll(details, "\n", ", "), ctx.Err())
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("wait for sync lock %s: %w", path, ctx.Err())
|
||||
}
|
||||
return nil
|
||||
|
||||
@ -52,6 +52,17 @@ type Options struct {
|
||||
EmbeddingProvider string
|
||||
EmbeddingModel string
|
||||
EmbeddingInputVersion string
|
||||
Progress func(ImportProgress)
|
||||
}
|
||||
|
||||
type ImportProgress struct {
|
||||
Phase string
|
||||
Table string
|
||||
File string
|
||||
FileIndex int
|
||||
FileCount int
|
||||
Rows int
|
||||
TotalRows int
|
||||
}
|
||||
|
||||
type Manifest struct {
|
||||
@ -221,6 +232,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
opts.reportProgress(ImportProgress{Phase: "start", TotalRows: manifestRowCount(manifest)})
|
||||
restorePragmas, err := applyImportPragmas(ctx, s.DB())
|
||||
if err != nil {
|
||||
return Manifest{}, err
|
||||
@ -242,11 +254,13 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
}
|
||||
}()
|
||||
for _, table := range []string{"message_fts", "member_fts"} {
|
||||
opts.reportProgress(ImportProgress{Phase: "drop_fts", Table: table})
|
||||
if _, err := tx.ExecContext(ctx, "drop table if exists "+table); err != nil {
|
||||
return Manifest{}, fmt.Errorf("drop %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
for _, table := range slices.Backward(SnapshotTables) {
|
||||
opts.reportProgress(ImportProgress{Phase: "clear", Table: table})
|
||||
query, args := snapshotDeleteQuery(table)
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return Manifest{}, fmt.Errorf("clear %s: %w", table, err)
|
||||
@ -256,10 +270,13 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
if err := ctx.Err(); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
if err := importTable(ctx, tx, opts.RepoPath, table); err != nil {
|
||||
opts.reportProgress(ImportProgress{Phase: "table_start", Table: table.Name, TotalRows: table.Rows})
|
||||
if err := importTable(ctx, tx, opts, table); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
opts.reportProgress(ImportProgress{Phase: "table_done", Table: table.Name, TotalRows: table.Rows})
|
||||
}
|
||||
opts.reportProgress(ImportProgress{Phase: "repair"})
|
||||
if err := repairImportedGuildIDs(ctx, tx); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
@ -268,10 +285,12 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
return Manifest{}, err
|
||||
}
|
||||
}
|
||||
opts.reportProgress(ImportProgress{Phase: "commit"})
|
||||
if err := tx.Commit(); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
committed = true
|
||||
opts.reportProgress(ImportProgress{Phase: "rebuild_fts"})
|
||||
if err := s.RebuildSearchIndexes(ctx); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
@ -282,6 +301,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
return Manifest{}, err
|
||||
}
|
||||
pragmasRestored = true
|
||||
opts.reportProgress(ImportProgress{Phase: "done", TotalRows: manifestRowCount(manifest)})
|
||||
return manifest, nil
|
||||
}
|
||||
|
||||
@ -335,6 +355,23 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes
|
||||
return imported, true, nil
|
||||
}
|
||||
|
||||
func (opts Options) reportProgress(progress ImportProgress) {
|
||||
if opts.Progress != nil {
|
||||
opts.Progress(progress)
|
||||
}
|
||||
}
|
||||
|
||||
func manifestRowCount(manifest Manifest) int {
|
||||
total := 0
|
||||
for _, table := range manifest.Tables {
|
||||
total += table.Rows
|
||||
}
|
||||
for _, embeddings := range manifest.Embeddings {
|
||||
total += embeddings.Rows
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func ImportEmbeddings(ctx context.Context, s *store.Store, opts Options, manifest Manifest) error {
|
||||
tx, err := s.DB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@ -572,7 +609,7 @@ func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingM
|
||||
}, nil
|
||||
}
|
||||
|
||||
func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableManifest) error {
|
||||
func importTable(ctx context.Context, tx *sql.Tx, opts Options, table TableManifest) error {
|
||||
files := table.Files
|
||||
if len(files) == 0 && strings.TrimSpace(table.File) != "" {
|
||||
files = []string{table.File}
|
||||
@ -586,34 +623,38 @@ func importTable(ctx context.Context, tx *sql.Tx, repoPath string, table TableMa
|
||||
return fmt.Errorf("prepare import %s: %w", table.Name, err)
|
||||
}
|
||||
defer func() { _ = stmt.Close() }()
|
||||
for _, rel := range files {
|
||||
for i, rel := range files {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := importTableFile(ctx, stmt, repoPath, table, columns, rel); err != nil {
|
||||
opts.reportProgress(ImportProgress{Phase: "file_start", Table: table.Name, File: rel, FileIndex: i + 1, FileCount: len(files), TotalRows: table.Rows})
|
||||
rows, err := importTableFile(ctx, stmt, opts.RepoPath, table, columns, rel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.reportProgress(ImportProgress{Phase: "file_done", Table: table.Name, File: rel, FileIndex: i + 1, FileCount: len(files), Rows: rows, TotalRows: table.Rows})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table TableManifest, columns []string, rel string) error {
|
||||
func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table TableManifest, columns []string, rel string) (int, error) {
|
||||
path := filepath.Join(repoPath, filepath.FromSlash(rel))
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %s: %w", rel, err)
|
||||
return 0, fmt.Errorf("open %s: %w", rel, err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
gz, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read gzip %s: %w", rel, err)
|
||||
return 0, fmt.Errorf("read gzip %s: %w", rel, err)
|
||||
}
|
||||
defer func() { _ = gz.Close() }()
|
||||
dec := json.NewDecoder(gz)
|
||||
dec.UseNumber()
|
||||
count := 0
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
return count, err
|
||||
}
|
||||
row := map[string]any{}
|
||||
err := dec.Decode(&row)
|
||||
@ -621,7 +662,7 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode %s: %w", rel, err)
|
||||
return count, fmt.Errorf("decode %s: %w", rel, err)
|
||||
}
|
||||
if isDirectMessageSnapshotRow(table.Name, row) {
|
||||
continue
|
||||
@ -631,10 +672,11 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
|
||||
values[i] = importValue(row[column])
|
||||
}
|
||||
if _, err := stmt.ExecContext(ctx, values...); err != nil {
|
||||
return fmt.Errorf("insert %s: %w", table.Name, err)
|
||||
return count, fmt.Errorf("insert %s: %w", table.Name, err)
|
||||
}
|
||||
count++
|
||||
}
|
||||
return nil
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func repairImportedGuildIDs(ctx context.Context, tx *sql.Tx) error {
|
||||
|
||||
@ -35,10 +35,20 @@ func TestExportImportRoundTrip(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = dst.Close() }()
|
||||
|
||||
imported, changed, err := ImportIfChanged(ctx, dst, Options{RepoPath: repo, Branch: "main"})
|
||||
var progress []ImportProgress
|
||||
imported, changed, err := ImportIfChanged(ctx, dst, Options{
|
||||
RepoPath: repo,
|
||||
Branch: "main",
|
||||
Progress: func(p ImportProgress) { progress = append(progress, p) },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt)
|
||||
require.Contains(t, progressPhases(progress), "start")
|
||||
require.Contains(t, progressPhases(progress), "table_start")
|
||||
require.Contains(t, progressPhases(progress), "file_done")
|
||||
require.Contains(t, progressPhases(progress), "rebuild_fts")
|
||||
require.Contains(t, progressPhases(progress), "done")
|
||||
|
||||
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "launch", Limit: 10})
|
||||
require.NoError(t, err)
|
||||
@ -673,7 +683,7 @@ func TestShareSmallHelpersAndValidation(t *testing.T) {
|
||||
defer func() { _ = s.Close() }()
|
||||
tx, err := s.DB().BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.ErrorContains(t, importTable(ctx, tx, t.TempDir(), TableManifest{Name: "messages", Columns: []string{"id"}}), "has no files")
|
||||
require.ErrorContains(t, importTable(ctx, tx, Options{RepoPath: t.TempDir()}, TableManifest{Name: "messages", Columns: []string{"id"}}), "has no files")
|
||||
require.NoError(t, tx.Rollback())
|
||||
|
||||
require.ErrorContains(t, ImportEmbeddings(ctx, s, Options{
|
||||
@ -727,7 +737,7 @@ func TestLegacyManifestFileImportAndEmbeddingDecodeErrors(t *testing.T) {
|
||||
})
|
||||
tx, err := s.DB().BeginTx(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, importTable(ctx, tx, repo, TableManifest{
|
||||
require.NoError(t, importTable(ctx, tx, Options{RepoPath: repo}, TableManifest{
|
||||
Name: "guilds",
|
||||
File: tableRel,
|
||||
Columns: []string{"id", "name", "icon", "raw_json", "updated_at"},
|
||||
@ -935,3 +945,11 @@ func tableNames(manifest Manifest) []string {
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func progressPhases(progress []ImportProgress) []string {
|
||||
phases := make([]string, 0, len(progress))
|
||||
for _, item := range progress {
|
||||
phases = append(phases, item.Phase)
|
||||
}
|
||||
return phases
|
||||
}
|
||||
|
||||
@ -413,6 +413,7 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error {
|
||||
`create index if not exists idx_members_guild_id on members(guild_id);`,
|
||||
`create index if not exists idx_messages_channel_id on messages(channel_id);`,
|
||||
`create index if not exists idx_messages_guild_id on messages(guild_id);`,
|
||||
`create index if not exists idx_messages_created_id on messages(created_at, id);`,
|
||||
`create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`,
|
||||
@ -488,6 +489,7 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error {
|
||||
`create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_created_id on messages(created_at, id);`,
|
||||
`create index if not exists idx_mentions_guild_event on mention_events(guild_id, event_at, event_id);`,
|
||||
`create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`,
|
||||
`create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`,
|
||||
|
||||
@ -1337,6 +1337,7 @@ func TestOpenCreatesQueryIndexes(t *testing.T) {
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
messageIndexes := indexNames(t, ctx, s.DB(), "messages")
|
||||
require.Contains(t, messageIndexes, "idx_messages_created_id")
|
||||
require.Contains(t, messageIndexes, "idx_messages_guild_created_id")
|
||||
require.Contains(t, messageIndexes, "idx_messages_channel_created_id")
|
||||
require.Contains(t, messageIndexes, "idx_messages_author_created_id")
|
||||
@ -1361,6 +1362,7 @@ func TestOpenMigratesLegacyQueryIndexes(t *testing.T) {
|
||||
"idx_messages_guild_created_id",
|
||||
"idx_messages_channel_created_id",
|
||||
"idx_messages_author_created_id",
|
||||
"idx_messages_created_id",
|
||||
"idx_mentions_guild_event",
|
||||
"idx_mentions_channel_event",
|
||||
} {
|
||||
@ -1376,6 +1378,7 @@ func TestOpenMigratesLegacyQueryIndexes(t *testing.T) {
|
||||
version, err := s.schemaVersion(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, storeSchemaVersion, version)
|
||||
require.Contains(t, indexNames(t, ctx, s.DB(), "messages"), "idx_messages_created_id")
|
||||
require.Contains(t, indexNames(t, ctx, s.DB(), "messages"), "idx_messages_channel_created_id")
|
||||
require.Contains(t, indexNames(t, ctx, s.DB(), "mention_events"), "idx_mentions_guild_event")
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user