diff --git a/internal/discorddesktop/import.go b/internal/discorddesktop/import.go index 4c9b761..53bc203 100644 --- a/internal/discorddesktop/import.go +++ b/internal/discorddesktop/import.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/fs" + "maps" "os" "path/filepath" "regexp" @@ -31,8 +32,10 @@ const ( checkpointEveryFiles = 256 ) -var channelRouteRE = regexp.MustCompile(`/channels/(@me|[0-9]{12,24})/([0-9]{12,24})`) -var apiMessagesRouteRE = regexp.MustCompile(`/api/v[0-9]+/channels/[0-9]{12,24}/messages`) +var ( + channelRouteRE = regexp.MustCompile(`/channels/(@me|[0-9]{12,24})/([0-9]{12,24})`) + apiMessagesRouteRE = regexp.MustCompile(`/api/v[0-9]+/channels/[0-9]{12,24}/messages`) +) type Options struct { Path string @@ -616,9 +619,7 @@ func finalizeSnapshot(snap snapshot, channelLookup map[string]store.ChannelRecor } func mergeUnresolved(dst, src unresolvedMessages) { - for messageID, channelID := range src { - dst[messageID] = channelID - } + maps.Copy(dst, src) } func recordUnresolved(unresolved unresolvedMessages, totals scanTotals, stats *Stats) { @@ -701,12 +702,8 @@ func newSnapshot() snapshot { func newSnapshotWithContext(base snapshot) snapshot { snap := newSnapshot() - for channelID, guildID := range base.routes { - snap.routes[channelID] = guildID - } - for userID, label := range base.userLabels { - snap.userLabels[userID] = label - } + maps.Copy(snap.routes, base.routes) + maps.Copy(snap.userLabels, base.userLabels) return snap } @@ -714,19 +711,13 @@ func mergeSnapshotContext(base snapshot, next snapshot) { for channelID, guildID := range next.routes { collectChannelRoute(base, channelID, guildID) } - for userID, label := range next.userLabels { - base.userLabels[userID] = label - } - for channelID, channel := range next.channels { - base.channels[channelID] = channel - } + maps.Copy(base.userLabels, next.userLabels) + maps.Copy(base.channels, next.channels) } func copyChannelLookup(in map[string]store.ChannelRecord) map[string]store.ChannelRecord { out := make(map[string]store.ChannelRecord, len(in)) - for id, channel := range in { - out[id] = channel - } + maps.Copy(out, in) return out } diff --git a/internal/discorddesktop/import_helpers_test.go b/internal/discorddesktop/import_helpers_test.go new file mode 100644 index 0000000..749adb1 --- /dev/null +++ b/internal/discorddesktop/import_helpers_test.go @@ -0,0 +1,92 @@ +package discorddesktop + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/steipete/discrawl/internal/store" +) + +func TestFileFingerprintStatusHelpers(t *testing.T) { + base := fileFingerprint{Size: 123, ModUnixNS: 456} + require.True(t, sameFileFingerprint(base, fileFingerprint{Size: 123, ModUnixNS: 456, Status: fileStatusSkipped})) + require.False(t, sameFileFingerprint(base, fileFingerprint{Size: 124, ModUnixNS: 456})) + require.False(t, sameFileFingerprint(base, fileFingerprint{Size: 123, ModUnixNS: 457})) + + require.True(t, isImportedFingerprint(base)) + require.True(t, isImportedFingerprint(importedFingerprint(base))) + require.False(t, isImportedFingerprint(skippedFingerprint(base))) + require.Equal(t, fileStatusImported, importedFingerprint(base).Status) + require.Equal(t, fileStatusSkipped, skippedFingerprint(base).Status) + require.Equal(t, wiretapFileIndexScope, fileIndexScope(Options{})) + require.Equal(t, wiretapFileIndexScope, fileIndexScope(Options{FullCache: true})) +} + +func TestSnapshotCopyHelpers(t *testing.T) { + base := newSnapshot() + base.routes["111111111111111121"] = "999999999999999996" + base.userLabels["222222222222222232"] = userLabel{Name: "Alice"} + base.channels["111111111111111121"] = store.ChannelRecord{ID: "111111111111111121", GuildID: "999999999999999996", Name: "general"} + + snap := newSnapshotWithContext(base) + require.Equal(t, base.routes, snap.routes) + require.Equal(t, base.userLabels, snap.userLabels) + require.Empty(t, snap.channels) + + next := newSnapshot() + next.routes["111111111111111122"] = "999999999999999996" + next.userLabels["222222222222222233"] = userLabel{Name: "Bob"} + next.channels["111111111111111122"] = store.ChannelRecord{ID: "111111111111111122", GuildID: "999999999999999996", Name: "random"} + mergeSnapshotContext(base, next) + + require.Equal(t, "999999999999999996", base.routes["111111111111111122"]) + require.Equal(t, "Bob", base.userLabels["222222222222222233"].Name) + require.Equal(t, "random", base.channels["111111111111111122"].Name) + + lookup := copyChannelLookup(base.channels) + lookup["111111111111111122"] = store.ChannelRecord{ID: "changed"} + require.Equal(t, "random", base.channels["111111111111111122"].Name) +} + +func TestSnapshotWithoutMessageEvents(t *testing.T) { + snap := newSnapshot() + snap.messages["333333333333333346"] = store.MessageMutation{ + Record: store.MessageRecord{ID: "333333333333333346"}, + Options: store.WriteOptions{ + AppendEvent: true, + EnqueueEmbedding: true, + }, + } + stripped := snapshotWithoutMessageEvents(snap) + require.False(t, stripped.messages["333333333333333346"].Options.AppendEvent) + require.True(t, stripped.messages["333333333333333346"].Options.EnqueueEmbedding) + require.True(t, snap.messages["333333333333333346"].Options.AppendEvent) +} + +func TestRouteFilteredCacheHelpers(t *testing.T) { + require.Equal(t, fileSourceCacheData, sourceForPath("/tmp/discord", "/tmp/discord/Cache/Cache_Data/entry", "Cache/Cache_Data/entry")) + require.Equal(t, fileSourceCacheData, sourceForPath("/tmp/discord", "/tmp/discord/Service Worker/CacheStorage/cache/entry", "Service Worker/CacheStorage/cache/entry")) + require.Equal(t, fileSourceContext, sourceForPath("/tmp/discord", "/tmp/discord/Local Storage/leveldb/000001.log", "Local Storage/leveldb/000001.log")) +} + +func TestCacheFileHasRouteHint(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "route"), []byte("https://discord.com/api/v9/channels/111111111111111121/messages?limit=50"), 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "plain"), []byte("no discord route here"), 0o600)) + + root, err := os.OpenRoot(dir) + require.NoError(t, err) + defer func() { _ = root.Close() }() + + ok, err := cacheFileHasRouteHint(root, "route") + require.NoError(t, err) + require.True(t, ok) + ok, err = cacheFileHasRouteHint(root, "plain") + require.NoError(t, err) + require.False(t, ok) + _, err = cacheFileHasRouteHint(root, "missing") + require.Error(t, err) +} diff --git a/internal/discorddesktop/import_pipeline_test.go b/internal/discorddesktop/import_pipeline_test.go index c8bf974..79cd5b1 100644 --- a/internal/discorddesktop/import_pipeline_test.go +++ b/internal/discorddesktop/import_pipeline_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "testing" "github.com/stretchr/testify/require" @@ -71,9 +72,9 @@ func TestImportCheckpointsCacheBatches(t *testing.T) { for i := range checkpointEveryFiles + 1 { channelID := "111111111111111121" messageID := 333333333333333346 + i - body := []byte(fmt.Sprintf(`https://discord.com/channels/999999999999999996/%s + body := bytesf(`https://discord.com/channels/999999999999999996/%s {"id":"%d","channel_id":"%s","content":"checkpoint cache %d","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} -`, channelID, messageID, channelID, i)) +`, channelID, messageID, channelID, i) require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), body, 0o600)) } @@ -101,18 +102,18 @@ func TestImportUsesLaterCacheMetadataBeforeCheckpointingEarlierBatch(t *testing. channelID := "111111111111111121" guildID := "999999999999999996" - require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50 + require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50 {"id":"333333333333333346","channel_id":"%s","content":"needs later channel metadata","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} -`, channelID, channelID)), 0o600)) +`, channelID, channelID), 0o600)) for i := 1; i < checkpointEveryFiles; i++ { - require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), []byte(fmt.Sprintf( + require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), bytesf( "https://discord.com/api/v9/channels/%s/messages?limit=50\n", channelID, - )), 0o600)) + ), 0o600)) } - require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50 + require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50 {"id":"%s","guild_id":"%s","type":0,"name":"later-metadata"} -`, channelID, channelID, guildID)), 0o600)) +`, channelID, channelID, guildID), 0o600)) st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db")) require.NoError(t, err) @@ -147,20 +148,20 @@ func TestImportCheckpointsPartiallyResolvedRetryBatch(t *testing.T) { resolvedChannelID := "111111111111111121" unresolvedChannelID := "111111111111111122" guildID := "999999999999999996" - require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50 + require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/api/v10/channels/%s/messages?limit=50 https://discord.com/api/v9/channels/%s/messages?limit=50 {"id":"333333333333333346","channel_id":"%s","content":"partially resolved retry message","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} {"id":"333333333333333347","channel_id":"%s","content":"still unresolved retry message","timestamp":"2026-04-23T18:20:44Z","author":{"id":"222222222222222232","username":"alice"}} -`, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID)), 0o600)) +`, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID), 0o600)) for i := 1; i < checkpointEveryFiles; i++ { - require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), []byte(fmt.Sprintf( + require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), bytesf( "https://discord.com/api/v9/channels/%s/messages?limit=50\n", resolvedChannelID, - )), 0o600)) + ), 0o600)) } - require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50 + require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50 {"id":"%s","guild_id":"%s","type":0,"name":"partially-resolved"} -`, resolvedChannelID, resolvedChannelID, guildID)), 0o600)) +`, resolvedChannelID, resolvedChannelID, guildID), 0o600)) st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db")) require.NoError(t, err) @@ -196,9 +197,9 @@ func TestImportCheckpointsUnresolvableRouteBearingCacheMisses(t *testing.T) { require.NoError(t, os.MkdirAll(cachePath, 0o755)) channelID := "111111111111111121" - require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50 + require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50 {"id":"333333333333333346","channel_id":"%s","content":"permanent unresolved cache miss","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} -`, channelID, channelID)), 0o600)) +`, channelID, channelID), 0o600)) st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db")) require.NoError(t, err) @@ -229,11 +230,11 @@ func TestImportDoesNotAppendEventsForSkippedMixedBatch(t *testing.T) { guildID := "999999999999999996" resolvedChannelID := "111111111111111121" unresolvedChannelID := "111111111111111122" - require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/channels/%s/%s + require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/channels/%s/%s https://discord.com/api/v9/channels/%s/messages?limit=50 {"id":"333333333333333346","channel_id":"%s","content":"mixed resolved message","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} {"id":"333333333333333347","channel_id":"%s","content":"mixed unresolved message","timestamp":"2026-04-23T18:20:44Z","author":{"id":"222222222222222232","username":"alice"}} -`, guildID, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID)), 0o600)) +`, guildID, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID), 0o600)) st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db")) require.NoError(t, err) @@ -267,10 +268,10 @@ func TestImportDoesNotDuplicateEventsWhenSwitchingFullCacheModes(t *testing.T) { channelID := "111111111111111121" guildID := "999999999999999996" - require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/channels/%s/%s + require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/channels/%s/%s {"id":"%s","guild_id":"%s","type":0,"name":"mode-switch"} {"id":"333333333333333346","channel_id":"%s","content":"mode switch event once","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} -`, guildID, channelID, channelID, guildID, channelID)), 0o600)) +`, guildID, channelID, channelID, guildID, channelID), 0o600)) t.Run("full then default", func(t *testing.T) { st, err := store.Open(ctx, filepath.Join(dir, "full-first.db")) @@ -319,14 +320,14 @@ func TestImportFastCachePreservesKnownChannelMetadataAcrossBatches(t *testing.T) channelID := "111111111111111121" guildID := "999999999999999996" - require.NoError(t, os.WriteFile(filepath.Join(leveldbPath, "000001.log"), []byte(fmt.Sprintf( + require.NoError(t, os.WriteFile(filepath.Join(leveldbPath, "000001.log"), bytesf( `{"id":"%s","guild_id":"%s","type":11,"name":"known-thread","thread_metadata":{"archived":false}}`, channelID, guildID, - )), 0o600)) - require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_0"), []byte(fmt.Sprintf(`https://discord.com/channels/%s/%s + ), 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_0"), bytesf(`https://discord.com/channels/%s/%s {"id":"333333333333333346","channel_id":"%s","content":"thread metadata cache","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}} -`, guildID, channelID, channelID)), 0o600)) +`, guildID, channelID, channelID), 0o600)) st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db")) require.NoError(t, err) @@ -374,9 +375,13 @@ func TestImportFastCacheRouteFiltersServiceWorkerCacheStorage(t *testing.T) { func requireMessageCount(t *testing.T, ctx context.Context, st *store.Store, table string, expected int) { t.Helper() - _, rows, err := st.ReadOnlyQuery(ctx, fmt.Sprintf("select count(*) from %s", table)) + _, rows, err := st.ReadOnlyQuery(ctx, "select count(*) from "+table) require.NoError(t, err) require.Len(t, rows, 1) require.Len(t, rows[0], 1) - require.Equal(t, fmt.Sprint(expected), rows[0][0]) + require.Equal(t, strconv.Itoa(expected), rows[0][0]) +} + +func bytesf(format string, args ...any) []byte { + return fmt.Appendf(nil, format, args...) } diff --git a/internal/discorddesktop/import_run.go b/internal/discorddesktop/import_run.go index 0142836..cc35727 100644 --- a/internal/discorddesktop/import_run.go +++ b/internal/discorddesktop/import_run.go @@ -49,10 +49,7 @@ func (r *importRun) scanContext(candidates []fileCandidate) error { func (r *importRun) scanCacheBatches(candidates []fileCandidate) error { for start := 0; start < len(candidates); start += checkpointEveryFiles { - end := start + checkpointEveryFiles - if end > len(candidates) { - end = len(candidates) - } + end := min(start+checkpointEveryFiles, len(candidates)) batchCandidates := candidates[start:end] batch := newSnapshotWithContext(r.base) if err := scanCandidates(r.ctx, r.rootFS, r.opts, batchCandidates, batch, r.channelLookup, r.stats); err != nil { diff --git a/internal/discorddesktop/import_value_helpers_test.go b/internal/discorddesktop/import_value_helpers_test.go new file mode 100644 index 0000000..30943cc --- /dev/null +++ b/internal/discorddesktop/import_value_helpers_test.go @@ -0,0 +1,80 @@ +package discorddesktop + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrimitiveValueHelpers(t *testing.T) { + raw := map[string]any{ + "string": "value", + "blank": " ", + "int": 3, + "int64": int64(4), + "float": float64(5), + "json_number": json.Number("6"), + "numeric": "7", + "bad_numeric": "nope", + "truthy": true, + "array": []any{"one", "two"}, + } + + require.Equal(t, "value", stringField(raw, "string")) + require.Empty(t, stringField(raw, "blank")) + require.Equal(t, "6", stringField(raw, "json_number")) + require.Empty(t, stringField(raw, "int")) + require.Empty(t, stringField(raw, "missing")) + + for key, want := range map[string]int{ + "int": 3, + "float": 5, + "json_number": 6, + } { + got, ok := intField(raw, key) + require.True(t, ok, key) + require.Equal(t, want, got, key) + } + _, ok := intField(raw, "bad_numeric") + require.False(t, ok) + _, ok = intField(raw, "int64") + require.False(t, ok) + _, ok = intField(raw, "numeric") + require.False(t, ok) + _, ok = intField(raw, "missing") + require.False(t, ok) + + require.Equal(t, int64(3), int64Field(raw, "int")) + require.Equal(t, int64(4), int64Field(raw, "int64")) + require.Equal(t, int64(5), int64Field(raw, "float")) + require.Equal(t, int64(6), int64Field(raw, "json_number")) + require.Zero(t, int64Field(raw, "numeric")) + require.Zero(t, int64Field(raw, "bad_numeric")) + + require.True(t, boolField(raw, "truthy")) + require.False(t, boolField(raw, "missing")) + require.Equal(t, 2, lenArray(raw["array"])) + require.Zero(t, lenArray(raw["string"])) + require.Equal(t, "fallback", firstNonEmpty("", " ", "fallback", "later")) + require.Empty(t, firstNonEmpty("", " ")) +} + +func TestDiscordValueFormatHelpers(t *testing.T) { + require.Equal(t, "456789", shortID("123456789")) + require.Equal(t, "short", shortID("short")) + require.Equal(t, "Discord Direct Messages", guildName(DirectMessageGuildID)) + require.Equal(t, "Discord Desktop Guild 123456", guildName("123456")) + + require.Equal(t, "dm", kindForChannelType(1, true)) + require.Equal(t, "group_dm", kindForChannelType(3, true)) + require.Equal(t, "thread_public", kindForChannelType(11, false)) + require.Equal(t, "thread_private", kindForChannelType(12, false)) + require.Equal(t, "thread_announcement", kindForChannelType(10, false)) + require.Equal(t, "desktop", kindForChannelType(2, false)) + require.Equal(t, "desktop", kindForChannelType(4, false)) + require.Equal(t, "announcement", kindForChannelType(5, false)) + require.Equal(t, "forum", kindForChannelType(15, false)) + require.Equal(t, "desktop", kindForChannelType(16, false)) + require.Equal(t, "text", kindForChannelType(0, false)) +}