test: cover wiretap cache checkpoint helpers

This commit is contained in:
Peter Steinberger 2026-05-05 01:43:13 +01:00
parent 78fcca8204
commit 6ea543b4c6
No known key found for this signature in database
5 changed files with 215 additions and 50 deletions

View File

@ -9,6 +9,7 @@ import (
"fmt"
"io"
"io/fs"
"maps"
"os"
"path/filepath"
"regexp"
@ -31,8 +32,10 @@ const (
checkpointEveryFiles = 256
)
// Route patterns compiled once at package scope.
var (
	// channelRouteRE matches "/channels/<first>/<second>". The first capture
	// group is either "@me" or a 12-24 digit ID; the second capture group is
	// a 12-24 digit ID.
	channelRouteRE = regexp.MustCompile(`/channels/(@me|[0-9]{12,24})/([0-9]{12,24})`)
	// apiMessagesRouteRE matches "/api/v<N>/channels/<12-24 digit ID>/messages"
	// for any numeric API version.
	apiMessagesRouteRE = regexp.MustCompile(`/api/v[0-9]+/channels/[0-9]{12,24}/messages`)
)
type Options struct {
Path string
@ -616,9 +619,7 @@ func finalizeSnapshot(snap snapshot, channelLookup map[string]store.ChannelRecor
}
// mergeUnresolved copies every entry of src into dst. An entry already
// present in dst under the same key is overwritten by the value from src.
// dst must be non-nil whenever src is non-empty, because writing to a nil
// map panics.
func mergeUnresolved(dst, src unresolvedMessages) {
	maps.Copy(dst, src)
}
func recordUnresolved(unresolved unresolvedMessages, totals scanTotals, stats *Stats) {
@ -701,12 +702,8 @@ func newSnapshot() snapshot {
// newSnapshotWithContext returns a snapshot created by newSnapshot and then
// seeded with copies of base's routes and user labels. Only those two maps
// are carried over; every other field is left exactly as newSnapshot
// produced it. The returned maps are independent of base, so later writes to
// the result do not affect base.
func newSnapshotWithContext(base snapshot) snapshot {
	snap := newSnapshot()
	maps.Copy(snap.routes, base.routes)
	maps.Copy(snap.userLabels, base.userLabels)
	return snap
}
@ -714,19 +711,13 @@ func mergeSnapshotContext(base snapshot, next snapshot) {
for channelID, guildID := range next.routes {
collectChannelRoute(base, channelID, guildID)
}
for userID, label := range next.userLabels {
base.userLabels[userID] = label
}
for channelID, channel := range next.channels {
base.channels[channelID] = channel
}
maps.Copy(base.userLabels, next.userLabels)
maps.Copy(base.channels, next.channels)
}
// copyChannelLookup returns a new map holding the same key/value pairs as in.
// The result is always non-nil (even for a nil or empty input) and is
// pre-sized to len(in); modifying it does not affect in.
func copyChannelLookup(in map[string]store.ChannelRecord) map[string]store.ChannelRecord {
	out := make(map[string]store.ChannelRecord, len(in))
	maps.Copy(out, in)
	return out
}

View File

@ -0,0 +1,92 @@
package discorddesktop
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/steipete/discrawl/internal/store"
)
// TestFileFingerprintStatusHelpers checks fingerprint comparison, the
// imported/skipped status helpers, and the file index scope returned for
// different Options.
func TestFileFingerprintStatusHelpers(t *testing.T) {
	origin := fileFingerprint{Size: 123, ModUnixNS: 456}

	// A differing Status still compares as the same file; a differing Size
	// or ModUnixNS does not.
	comparisons := []struct {
		other fileFingerprint
		same  bool
	}{
		{fileFingerprint{Size: 123, ModUnixNS: 456, Status: fileStatusSkipped}, true},
		{fileFingerprint{Size: 124, ModUnixNS: 456}, false},
		{fileFingerprint{Size: 123, ModUnixNS: 457}, false},
	}
	for _, tc := range comparisons {
		if tc.same {
			require.True(t, sameFileFingerprint(origin, tc.other))
		} else {
			require.False(t, sameFileFingerprint(origin, tc.other))
		}
	}

	imported := importedFingerprint(origin)
	skipped := skippedFingerprint(origin)

	// A fingerprint with no explicit Status is expected to count as imported.
	require.True(t, isImportedFingerprint(origin))
	require.True(t, isImportedFingerprint(imported))
	require.False(t, isImportedFingerprint(skipped))
	require.Equal(t, fileStatusImported, imported.Status)
	require.Equal(t, fileStatusSkipped, skipped.Status)

	// The index scope is expected to be the same with and without FullCache.
	for _, opts := range []Options{{}, {FullCache: true}} {
		require.Equal(t, wiretapFileIndexScope, fileIndexScope(opts))
	}
}
// TestSnapshotCopyHelpers covers the three snapshot copy helpers:
// newSnapshotWithContext (seeds routes and user labels only),
// mergeSnapshotContext (folds routes, labels and channels into the base),
// and copyChannelLookup (returns an independent map).
func TestSnapshotCopyHelpers(t *testing.T) {
	const (
		guildID       = "999999999999999996"
		firstChannel  = "111111111111111121"
		secondChannel = "111111111111111122"
		aliceID       = "222222222222222232"
		bobID         = "222222222222222233"
	)

	origin := newSnapshot()
	origin.routes[firstChannel] = guildID
	origin.userLabels[aliceID] = userLabel{Name: "Alice"}
	origin.channels[firstChannel] = store.ChannelRecord{ID: firstChannel, GuildID: guildID, Name: "general"}

	// Routes and labels are carried over; channels are not.
	seeded := newSnapshotWithContext(origin)
	require.Equal(t, origin.routes, seeded.routes)
	require.Equal(t, origin.userLabels, seeded.userLabels)
	require.Empty(t, seeded.channels)

	incoming := newSnapshot()
	incoming.routes[secondChannel] = guildID
	incoming.userLabels[bobID] = userLabel{Name: "Bob"}
	incoming.channels[secondChannel] = store.ChannelRecord{ID: secondChannel, GuildID: guildID, Name: "random"}

	mergeSnapshotContext(origin, incoming)
	require.Equal(t, guildID, origin.routes[secondChannel])
	require.Equal(t, "Bob", origin.userLabels[bobID].Name)
	require.Equal(t, "random", origin.channels[secondChannel].Name)

	// Writing to the copy must leave the source map untouched.
	clone := copyChannelLookup(origin.channels)
	clone[secondChannel] = store.ChannelRecord{ID: "changed"}
	require.Equal(t, "random", origin.channels[secondChannel].Name)
}
// TestSnapshotWithoutMessageEvents verifies that snapshotWithoutMessageEvents
// clears AppendEvent on the returned snapshot's messages, keeps
// EnqueueEmbedding, and leaves the input snapshot's message unchanged.
func TestSnapshotWithoutMessageEvents(t *testing.T) {
	const messageID = "333333333333333346"

	original := newSnapshot()
	original.messages[messageID] = store.MessageMutation{
		Record:  store.MessageRecord{ID: messageID},
		Options: store.WriteOptions{AppendEvent: true, EnqueueEmbedding: true},
	}

	result := snapshotWithoutMessageEvents(original)

	require.False(t, result.messages[messageID].Options.AppendEvent)
	require.True(t, result.messages[messageID].Options.EnqueueEmbedding)
	// The source snapshot must still request the event.
	require.True(t, original.messages[messageID].Options.AppendEvent)
}
func TestRouteFilteredCacheHelpers(t *testing.T) {
require.Equal(t, fileSourceCacheData, sourceForPath("/tmp/discord", "/tmp/discord/Cache/Cache_Data/entry", "Cache/Cache_Data/entry"))
require.Equal(t, fileSourceCacheData, sourceForPath("/tmp/discord", "/tmp/discord/Service Worker/CacheStorage/cache/entry", "Service Worker/CacheStorage/cache/entry"))
require.Equal(t, fileSourceContext, sourceForPath("/tmp/discord", "/tmp/discord/Local Storage/leveldb/000001.log", "Local Storage/leveldb/000001.log"))
}
// TestCacheFileHasRouteHint writes two fixture files into a temporary
// directory and checks that cacheFileHasRouteHint reports true for the file
// containing a Discord API messages URL, false for the one without, and an
// error for a file that does not exist.
func TestCacheFileHasRouteHint(t *testing.T) {
	tmp := t.TempDir()

	fixtures := []struct{ name, content string }{
		{"route", "https://discord.com/api/v9/channels/111111111111111121/messages?limit=50"},
		{"plain", "no discord route here"},
	}
	for _, f := range fixtures {
		require.NoError(t, os.WriteFile(filepath.Join(tmp, f.name), []byte(f.content), 0o600))
	}

	root, err := os.OpenRoot(tmp)
	require.NoError(t, err)
	// Close error is irrelevant to the assertions below.
	defer func() { _ = root.Close() }()

	expectations := []struct {
		name string
		want bool
	}{
		{"route", true},
		{"plain", false},
	}
	for _, tc := range expectations {
		got, err := cacheFileHasRouteHint(root, tc.name)
		require.NoError(t, err)
		require.Equal(t, tc.want, got)
	}

	_, err = cacheFileHasRouteHint(root, "missing")
	require.Error(t, err)
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/stretchr/testify/require"
@ -71,9 +72,9 @@ func TestImportCheckpointsCacheBatches(t *testing.T) {
for i := range checkpointEveryFiles + 1 {
channelID := "111111111111111121"
messageID := 333333333333333346 + i
body := []byte(fmt.Sprintf(`https://discord.com/channels/999999999999999996/%s
body := bytesf(`https://discord.com/channels/999999999999999996/%s
{"id":"%d","channel_id":"%s","content":"checkpoint cache %d","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
`, channelID, messageID, channelID, i))
`, channelID, messageID, channelID, i)
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), body, 0o600))
}
@ -101,18 +102,18 @@ func TestImportUsesLaterCacheMetadataBeforeCheckpointingEarlierBatch(t *testing.
channelID := "111111111111111121"
guildID := "999999999999999996"
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50
{"id":"333333333333333346","channel_id":"%s","content":"needs later channel metadata","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
`, channelID, channelID)), 0o600))
`, channelID, channelID), 0o600))
for i := 1; i < checkpointEveryFiles; i++ {
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), []byte(fmt.Sprintf(
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), bytesf(
"https://discord.com/api/v9/channels/%s/messages?limit=50\n",
channelID,
)), 0o600))
), 0o600))
}
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50
{"id":"%s","guild_id":"%s","type":0,"name":"later-metadata"}
`, channelID, channelID, guildID)), 0o600))
`, channelID, channelID, guildID), 0o600))
st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
require.NoError(t, err)
@ -147,20 +148,20 @@ func TestImportCheckpointsPartiallyResolvedRetryBatch(t *testing.T) {
resolvedChannelID := "111111111111111121"
unresolvedChannelID := "111111111111111122"
guildID := "999999999999999996"
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/api/v10/channels/%s/messages?limit=50
https://discord.com/api/v9/channels/%s/messages?limit=50
{"id":"333333333333333346","channel_id":"%s","content":"partially resolved retry message","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
{"id":"333333333333333347","channel_id":"%s","content":"still unresolved retry message","timestamp":"2026-04-23T18:20:44Z","author":{"id":"222222222222222232","username":"alice"}}
`, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID)), 0o600))
`, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID), 0o600))
for i := 1; i < checkpointEveryFiles; i++ {
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), []byte(fmt.Sprintf(
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", i)), bytesf(
"https://discord.com/api/v9/channels/%s/messages?limit=50\n",
resolvedChannelID,
)), 0o600))
), 0o600))
}
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50
require.NoError(t, os.WriteFile(filepath.Join(cachePath, fmt.Sprintf("entry_%03d", checkpointEveryFiles)), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50
{"id":"%s","guild_id":"%s","type":0,"name":"partially-resolved"}
`, resolvedChannelID, resolvedChannelID, guildID)), 0o600))
`, resolvedChannelID, resolvedChannelID, guildID), 0o600))
st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
require.NoError(t, err)
@ -196,9 +197,9 @@ func TestImportCheckpointsUnresolvableRouteBearingCacheMisses(t *testing.T) {
require.NoError(t, os.MkdirAll(cachePath, 0o755))
channelID := "111111111111111121"
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/api/v9/channels/%s/messages?limit=50
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/api/v9/channels/%s/messages?limit=50
{"id":"333333333333333346","channel_id":"%s","content":"permanent unresolved cache miss","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
`, channelID, channelID)), 0o600))
`, channelID, channelID), 0o600))
st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
require.NoError(t, err)
@ -229,11 +230,11 @@ func TestImportDoesNotAppendEventsForSkippedMixedBatch(t *testing.T) {
guildID := "999999999999999996"
resolvedChannelID := "111111111111111121"
unresolvedChannelID := "111111111111111122"
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/channels/%s/%s
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/channels/%s/%s
https://discord.com/api/v9/channels/%s/messages?limit=50
{"id":"333333333333333346","channel_id":"%s","content":"mixed resolved message","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
{"id":"333333333333333347","channel_id":"%s","content":"mixed unresolved message","timestamp":"2026-04-23T18:20:44Z","author":{"id":"222222222222222232","username":"alice"}}
`, guildID, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID)), 0o600))
`, guildID, resolvedChannelID, unresolvedChannelID, resolvedChannelID, unresolvedChannelID), 0o600))
st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
require.NoError(t, err)
@ -267,10 +268,10 @@ func TestImportDoesNotDuplicateEventsWhenSwitchingFullCacheModes(t *testing.T) {
channelID := "111111111111111121"
guildID := "999999999999999996"
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), []byte(fmt.Sprintf(`https://discord.com/channels/%s/%s
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_000"), bytesf(`https://discord.com/channels/%s/%s
{"id":"%s","guild_id":"%s","type":0,"name":"mode-switch"}
{"id":"333333333333333346","channel_id":"%s","content":"mode switch event once","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
`, guildID, channelID, channelID, guildID, channelID)), 0o600))
`, guildID, channelID, channelID, guildID, channelID), 0o600))
t.Run("full then default", func(t *testing.T) {
st, err := store.Open(ctx, filepath.Join(dir, "full-first.db"))
@ -319,14 +320,14 @@ func TestImportFastCachePreservesKnownChannelMetadataAcrossBatches(t *testing.T)
channelID := "111111111111111121"
guildID := "999999999999999996"
require.NoError(t, os.WriteFile(filepath.Join(leveldbPath, "000001.log"), []byte(fmt.Sprintf(
require.NoError(t, os.WriteFile(filepath.Join(leveldbPath, "000001.log"), bytesf(
`{"id":"%s","guild_id":"%s","type":11,"name":"known-thread","thread_metadata":{"archived":false}}`,
channelID,
guildID,
)), 0o600))
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_0"), []byte(fmt.Sprintf(`https://discord.com/channels/%s/%s
), 0o600))
require.NoError(t, os.WriteFile(filepath.Join(cachePath, "entry_0"), bytesf(`https://discord.com/channels/%s/%s
{"id":"333333333333333346","channel_id":"%s","content":"thread metadata cache","timestamp":"2026-04-23T18:20:43Z","author":{"id":"222222222222222232","username":"alice"}}
`, guildID, channelID, channelID)), 0o600))
`, guildID, channelID, channelID), 0o600))
st, err := store.Open(ctx, filepath.Join(dir, "discrawl.db"))
require.NoError(t, err)
@ -374,9 +375,13 @@ func TestImportFastCacheRouteFiltersServiceWorkerCacheStorage(t *testing.T) {
// requireMessageCount fails the test unless `select count(*) from <table>`
// returns exactly one row with exactly one column whose value equals the
// decimal string form of expected.
//
// table is concatenated into the SQL text verbatim, so only pass trusted,
// hard-coded identifiers.
func requireMessageCount(t *testing.T, ctx context.Context, st *store.Store, table string, expected int) {
	t.Helper()
	_, rows, err := st.ReadOnlyQuery(ctx, "select count(*) from "+table)
	require.NoError(t, err)
	require.Len(t, rows, 1)
	require.Len(t, rows[0], 1)
	// The query result is compared as a string, so format expected the same way.
	require.Equal(t, strconv.Itoa(expected), rows[0][0])
}
// bytesf formats args according to format (same verbs as fmt.Sprintf) and
// returns the result as a byte slice.
func bytesf(format string, args ...any) []byte {
	var dst []byte // nil destination: Appendf allocates the result
	return fmt.Appendf(dst, format, args...)
}

View File

@ -49,10 +49,7 @@ func (r *importRun) scanContext(candidates []fileCandidate) error {
func (r *importRun) scanCacheBatches(candidates []fileCandidate) error {
for start := 0; start < len(candidates); start += checkpointEveryFiles {
end := start + checkpointEveryFiles
if end > len(candidates) {
end = len(candidates)
}
end := min(start+checkpointEveryFiles, len(candidates))
batchCandidates := candidates[start:end]
batch := newSnapshotWithContext(r.base)
if err := scanCandidates(r.ctx, r.rootFS, r.opts, batchCandidates, batch, r.channelLookup, r.stats); err != nil {

View File

@ -0,0 +1,80 @@
package discorddesktop
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestPrimitiveValueHelpers exercises the loosely-typed field accessors
// (stringField, intField, int64Field, boolField), lenArray and firstNonEmpty
// against one map holding a value of each primitive kind.
func TestPrimitiveValueHelpers(t *testing.T) {
	fields := map[string]any{
		"string":      "value",
		"blank":       " ",
		"int":         3,
		"int64":       int64(4),
		"float":       float64(5),
		"json_number": json.Number("6"),
		"numeric":     "7",
		"bad_numeric": "nope",
		"truthy":      true,
		"array":       []any{"one", "two"},
	}

	// stringField: plain strings and json.Number yield text; blank,
	// non-string and missing values yield "".
	stringCases := []struct{ key, want string }{
		{"string", "value"},
		{"blank", ""},
		{"json_number", "6"},
		{"int", ""},
		{"missing", ""},
	}
	for _, tc := range stringCases {
		require.Equal(t, tc.want, stringField(fields, tc.key))
	}

	// intField: accepted numeric kinds.
	intCases := []struct {
		key  string
		want int
	}{
		{"int", 3},
		{"float", 5},
		{"json_number", 6},
	}
	for _, tc := range intCases {
		got, ok := intField(fields, tc.key)
		require.True(t, ok, tc.key)
		require.Equal(t, tc.want, got, tc.key)
	}

	// intField: kinds that must be rejected.
	for _, key := range []string{"bad_numeric", "int64", "numeric", "missing"} {
		_, ok := intField(fields, key)
		require.False(t, ok)
	}

	// int64Field: numeric kinds convert; strings (numeric or not) give 0.
	int64Cases := []struct {
		key  string
		want int64
	}{
		{"int", 3},
		{"int64", 4},
		{"float", 5},
		{"json_number", 6},
		{"numeric", 0},
		{"bad_numeric", 0},
	}
	for _, tc := range int64Cases {
		require.Equal(t, tc.want, int64Field(fields, tc.key))
	}

	require.True(t, boolField(fields, "truthy"))
	require.False(t, boolField(fields, "missing"))

	require.Equal(t, 2, lenArray(fields["array"]))
	require.Zero(t, lenArray(fields["string"]))

	require.Equal(t, "fallback", firstNonEmpty("", " ", "fallback", "later"))
	require.Empty(t, firstNonEmpty("", " "))
}
// TestDiscordValueFormatHelpers checks the display helpers: shortID, guildName
// and the channel-type-to-kind mapping of kindForChannelType.
func TestDiscordValueFormatHelpers(t *testing.T) {
	// shortID keeps the trailing six characters of a longer ID and returns
	// shorter input unchanged.
	require.Equal(t, "456789", shortID("123456789"))
	require.Equal(t, "short", shortID("short"))

	require.Equal(t, "Discord Direct Messages", guildName(DirectMessageGuildID))
	require.Equal(t, "Discord Desktop Guild 123456", guildName("123456"))

	kinds := []struct {
		want string
		got  any
	}{
		{"dm", kindForChannelType(1, true)},
		{"group_dm", kindForChannelType(3, true)},
		{"thread_public", kindForChannelType(11, false)},
		{"thread_private", kindForChannelType(12, false)},
		{"thread_announcement", kindForChannelType(10, false)},
		{"desktop", kindForChannelType(2, false)},
		{"desktop", kindForChannelType(4, false)},
		{"announcement", kindForChannelType(5, false)},
		{"forum", kindForChannelType(15, false)},
		{"desktop", kindForChannelType(16, false)},
		{"text", kindForChannelType(0, false)},
	}
	for _, tc := range kinds {
		require.Equal(t, tc.want, tc.got)
	}
}