fix(share): harden git snapshot imports (#51)

* fix(share): harden git snapshot imports

* docs: update changelog for share import hardening
This commit is contained in:
Peter Steinberger 2026-04-29 15:19:44 +01:00 committed by GitHub
parent 59f42cb0ab
commit b387ed2d6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 170 additions and 29 deletions

View File

@ -6,6 +6,8 @@ All notable changes to `discrawl` will be documented in this file.
### Fixes
- Git snapshot imports now recover from corrupt local FTS tables by dropping and rebuilding search indexes, and repair missing guild IDs from channel metadata so shared archive reports stay fresh.
- Channel-history sync now falls back to the channel guild when Discord omits `message.guild_id`, keeping messages, attachments, mentions, and FTS rows correctly scoped.
- Repeated `sync --source wiretap` runs now skip unchanged Discord Desktop cache files and report unchanged file counts, making steady-state local-cache refreshes much faster.
- `sync --full --skip-members` now also skips member crawls when resuming incomplete stored channels, so backfills do not unexpectedly refresh the full guild member list.

View File

@ -241,8 +241,8 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
}
}()
for _, table := range []string{"message_fts", "member_fts"} {
if _, err := tx.ExecContext(ctx, "delete from "+table); err != nil {
return Manifest{}, fmt.Errorf("clear %s: %w", table, err)
if _, err := tx.ExecContext(ctx, "drop table if exists "+table); err != nil {
return Manifest{}, fmt.Errorf("drop %s: %w", table, err)
}
}
for i := len(SnapshotTables) - 1; i >= 0; i-- {
@ -257,6 +257,9 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
return Manifest{}, err
}
}
if err := repairImportedGuildIDs(ctx, tx); err != nil {
return Manifest{}, err
}
if opts.IncludeEmbeddings {
if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil {
return Manifest{}, err
@ -618,6 +621,67 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
return nil
}
func repairImportedGuildIDs(ctx context.Context, tx *sql.Tx) error {
repairs := []struct {
table string
query string
}{
{"messages", `
update messages
set guild_id = (
select c.guild_id
from channels c
where c.id = messages.channel_id
)
where coalesce(guild_id, '') = ''
and exists (
select 1
from channels c
where c.id = messages.channel_id
and coalesce(c.guild_id, '') != ''
)`},
{"message_attachments", `
update message_attachments
set guild_id = coalesce(
nullif((select m.guild_id from messages m where m.id = message_attachments.message_id), ''),
(select c.guild_id from channels c where c.id = message_attachments.channel_id)
)
where coalesce(guild_id, '') = ''
and coalesce(
nullif((select m.guild_id from messages m where m.id = message_attachments.message_id), ''),
(select c.guild_id from channels c where c.id = message_attachments.channel_id)
) is not null`},
{"message_events", `
update message_events
set guild_id = coalesce(
nullif((select m.guild_id from messages m where m.id = message_events.message_id), ''),
(select c.guild_id from channels c where c.id = message_events.channel_id)
)
where coalesce(guild_id, '') = ''
and coalesce(
nullif((select m.guild_id from messages m where m.id = message_events.message_id), ''),
(select c.guild_id from channels c where c.id = message_events.channel_id)
) is not null`},
{"mention_events", `
update mention_events
set guild_id = coalesce(
nullif((select m.guild_id from messages m where m.id = mention_events.message_id), ''),
(select c.guild_id from channels c where c.id = mention_events.channel_id)
)
where coalesce(guild_id, '') = ''
and coalesce(
nullif((select m.guild_id from messages m where m.id = mention_events.message_id), ''),
(select c.guild_id from channels c where c.id = mention_events.channel_id)
) is not null`},
}
for _, repair := range repairs {
if _, err := tx.ExecContext(ctx, repair.query); err != nil {
return fmt.Errorf("repair imported %s guild ids: %w", repair.table, err)
}
}
return nil
}
func importColumns(table TableManifest) []string {
if table.Name != "message_events" && table.Name != "mention_events" {
return table.Columns

View File

@ -63,6 +63,41 @@ func TestExportImportRoundTrip(t *testing.T) {
require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt)
}
func TestImportRepairsBlankMessageGuildIDs(t *testing.T) {
ctx := context.Background()
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
defer func() { _ = src.Close() }()
_, err := src.DB().ExecContext(ctx, `update messages set guild_id = '' where id = 'm1'`)
require.NoError(t, err)
_, err = src.DB().ExecContext(ctx, `update message_events set guild_id = '' where message_id = 'm1'`)
require.NoError(t, err)
_, err = src.DB().ExecContext(ctx, `update mention_events set guild_id = '' where message_id = 'm1'`)
require.NoError(t, err)
repo := filepath.Join(t.TempDir(), "share")
_, err = Export(ctx, src, Options{RepoPath: repo, Branch: "main"})
require.NoError(t, err)
require.Contains(t, snapshotTableText(t, repo, tableEntry(t, mustReadManifest(t, repo), "messages")), `"guild_id":""`)
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
require.NoError(t, err)
defer func() { _ = dst.Close() }()
_, err = Import(ctx, dst, Options{RepoPath: repo, Branch: "main"})
require.NoError(t, err)
var guildID string
require.NoError(t, dst.DB().QueryRowContext(ctx, `select guild_id from messages where id = 'm1'`).Scan(&guildID))
require.Equal(t, "g1", guildID)
require.NoError(t, dst.DB().QueryRowContext(ctx, `select guild_id from message_events where message_id = 'm1'`).Scan(&guildID))
require.Equal(t, "g1", guildID)
require.NoError(t, dst.DB().QueryRowContext(ctx, `select guild_id from mention_events where message_id = 'm1'`).Scan(&guildID))
require.Equal(t, "g1", guildID)
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "launch", GuildIDs: []string{"g1"}, Limit: 10})
require.NoError(t, err)
require.Len(t, results, 1)
require.Equal(t, "g1", results[0].GuildID)
}
func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) {
ctx := context.Background()
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
@ -860,6 +895,13 @@ func configureGitUser(t *testing.T, repo string) {
require.NoError(t, exec.Command("git", "-C", repo, "config", "user.email", "discrawl@example.com").Run())
}
func mustReadManifest(t *testing.T, repo string) Manifest {
t.Helper()
manifest, err := ReadManifest(repo)
require.NoError(t, err)
return manifest
}
func tableEntry(t *testing.T, manifest Manifest, name string) TableManifest {
t.Helper()
for _, table := range manifest.Tables {

View File

@ -27,15 +27,17 @@ func buildMessageMutation(
ctx context.Context,
message *discordgo.Message,
channelName string,
fallbackGuildID string,
embeddings bool,
attachmentText bool,
) (store.MessageMutation, error) {
attachments, attachmentParts, err := extractAttachments(ctx, message, attachmentText)
guildID := effectiveMessageGuildID(message, fallbackGuildID)
attachments, attachmentParts, err := extractAttachments(ctx, message, guildID, attachmentText)
if err != nil {
return store.MessageMutation{}, err
}
normalized := normalizeMessageParts(message, attachmentParts)
record := toMessageRecord(message, channelName, normalized)
record := toMessageRecord(message, channelName, guildID, normalized)
return store.MessageMutation{
Record: record,
EventType: "upsert",
@ -44,11 +46,11 @@ func buildMessageMutation(
EnqueueEmbedding: embeddings,
},
Attachments: attachments,
Mentions: extractMentions(message),
Mentions: extractMentions(message, guildID),
}, nil
}
func extractAttachments(ctx context.Context, message *discordgo.Message, attachmentText bool) ([]store.AttachmentRecord, []string, error) {
func extractAttachments(ctx context.Context, message *discordgo.Message, guildID string, attachmentText bool) ([]store.AttachmentRecord, []string, error) {
if message == nil || len(message.Attachments) == 0 {
return nil, nil, nil
}
@ -61,7 +63,7 @@ func extractAttachments(ctx context.Context, message *discordgo.Message, attachm
record := store.AttachmentRecord{
AttachmentID: attachment.ID,
MessageID: message.ID,
GuildID: message.GuildID,
GuildID: guildID,
ChannelID: message.ChannelID,
Filename: attachment.Filename,
ContentType: attachment.ContentType,
@ -90,7 +92,7 @@ func extractAttachments(ctx context.Context, message *discordgo.Message, attachm
return records, parts, nil
}
func extractMentions(message *discordgo.Message) []store.MentionEventRecord {
func extractMentions(message *discordgo.Message, guildID string) []store.MentionEventRecord {
if message == nil {
return nil
}
@ -116,7 +118,7 @@ func extractMentions(message *discordgo.Message) []store.MentionEventRecord {
}
mentions = append(mentions, store.MentionEventRecord{
MessageID: message.ID,
GuildID: message.GuildID,
GuildID: guildID,
ChannelID: message.ChannelID,
AuthorID: authorID,
TargetType: "user",
@ -137,7 +139,7 @@ func extractMentions(message *discordgo.Message) []store.MentionEventRecord {
seen[key] = struct{}{}
mentions = append(mentions, store.MentionEventRecord{
MessageID: message.ID,
GuildID: message.GuildID,
GuildID: guildID,
ChannelID: message.ChannelID,
AuthorID: authorID,
TargetType: "role",

View File

@ -44,7 +44,7 @@ func TestBuildMessageMutationIncludesAttachmentTextAndMentions(t *testing.T) {
GlobalName: "Shadow",
}},
MentionRoles: []string{"r1"},
}, "maintainers", false, true)
}, "maintainers", "", false, true)
require.NoError(t, err)
require.Len(t, mutation.Attachments, 1)
require.Equal(t, "trace.txt", mutation.Attachments[0].Filename)
@ -59,6 +59,30 @@ func TestBuildMessageMutationIncludesAttachmentTextAndMentions(t *testing.T) {
require.Equal(t, "r1", mutation.Mentions[1].TargetID)
}
func TestBuildMessageMutationFallsBackToChannelGuildID(t *testing.T) {
now := time.Now().UTC()
mutation, err := buildMessageMutation(context.Background(), &discordgo.Message{
ID: "m1",
ChannelID: "c1",
Content: "missing guild id from channel history",
Timestamp: now,
Author: &discordgo.User{ID: "u1", Username: "peter"},
Attachments: []*discordgo.MessageAttachment{{
ID: "a1",
Filename: "trace.txt",
}},
Mentions: []*discordgo.User{{ID: "u2", Username: "shadow"}},
MentionRoles: []string{
"r1",
},
}, "maintainers", "g1", false, false)
require.NoError(t, err)
require.Equal(t, "g1", mutation.Record.GuildID)
require.Equal(t, "g1", mutation.Attachments[0].GuildID)
require.Equal(t, "g1", mutation.Mentions[0].GuildID)
require.Equal(t, "g1", mutation.Mentions[1].GuildID)
}
func TestShouldFetchAttachmentText(t *testing.T) {
t.Parallel()
@ -93,7 +117,7 @@ func TestBuildMessageMutationSkipsBinaryResponseEvenWhenAttachmentLooksTextual(t
Filename: "trace.txt",
URL: server.URL,
}},
}, "maintainers", false, true)
}, "maintainers", "", false, true)
require.NoError(t, err)
require.Len(t, mutation.Attachments, 1)
require.Empty(t, mutation.Attachments[0].TextContent)
@ -127,7 +151,7 @@ func TestBuildMessageMutationSkipsOversizedAttachmentResponses(t *testing.T) {
ContentType: "text/plain",
URL: server.URL,
}},
}, "maintainers", false, true)
}, "maintainers", "", false, true)
require.NoError(t, err)
require.Len(t, mutation.Attachments, 1)
require.Empty(t, mutation.Attachments[0].TextContent)
@ -159,7 +183,7 @@ func TestBuildMessageMutationRespectsAttachmentTextOptOut(t *testing.T) {
ContentType: "text/plain",
URL: server.URL,
}},
}, "maintainers", false, false)
}, "maintainers", "", false, false)
require.NoError(t, err)
require.Len(t, mutation.Attachments, 1)
require.Empty(t, mutation.Attachments[0].TextContent)

View File

@ -33,7 +33,7 @@ func TestBuildMessageMutationsTracksNewest(t *testing.T) {
Timestamp: now,
Author: &discordgo.User{ID: "u1", Username: "user"},
},
}, "general", true, true)
}, "general", "", true, true)
require.NoError(t, err)
require.Len(t, mutations, 2)

View File

@ -298,7 +298,7 @@ func (s *Syncer) syncFullChannelHistory(ctx context.Context, channel *discordgo.
messageCount := 0
newest := state.Latest
if state.Latest != "" {
count, latest, err := s.syncForwardPages(ctx, channel, state.Latest, channel.Name, embeddings, progress)
count, latest, err := s.syncForwardPages(ctx, channel, state.Latest, embeddings, progress)
messageCount += count
if err != nil {
return messageCount, err
@ -332,7 +332,7 @@ func (s *Syncer) syncIncrementalChannelHistory(ctx context.Context, channel *dis
if state.Latest == "" {
return s.bootstrapChannelHistory(ctx, channel, embeddings, since, progress)
}
count, newest, err := s.syncForwardPages(ctx, channel, state.Latest, channel.Name, embeddings, progress)
count, newest, err := s.syncForwardPages(ctx, channel, state.Latest, embeddings, progress)
if err != nil {
return count, err
}
@ -358,7 +358,7 @@ func (s *Syncer) bootstrapChannelHistory(ctx context.Context, channel *discordgo
break
}
eligible, reachedSince := filterMessagesSince(page, since)
pageNewest, err := s.persistMessagePage(ctx, eligible, channel.Name, embeddings)
pageNewest, err := s.persistMessagePage(ctx, eligible, channel.Name, channel.GuildID, embeddings)
if err != nil {
return messageCount, err
}
@ -392,7 +392,7 @@ func (s *Syncer) bootstrapChannelHistory(ctx context.Context, channel *discordgo
return messageCount, nil
}
func (s *Syncer) syncForwardPages(ctx context.Context, channel *discordgo.Channel, after, channelName string, embeddings bool, progress *messageSyncProgress) (int, string, error) {
func (s *Syncer) syncForwardPages(ctx context.Context, channel *discordgo.Channel, after string, embeddings bool, progress *messageSyncProgress) (int, string, error) {
messageCount := 0
newest := after
for {
@ -403,7 +403,7 @@ func (s *Syncer) syncForwardPages(ctx context.Context, channel *discordgo.Channe
if len(page) == 0 {
break
}
pageNewest, err := s.persistMessagePage(ctx, page, channelName, embeddings)
pageNewest, err := s.persistMessagePage(ctx, page, channel.Name, channel.GuildID, embeddings)
if err != nil {
return messageCount, newest, err
}
@ -436,7 +436,7 @@ func (s *Syncer) syncBackfillPages(ctx context.Context, channel *discordgo.Chann
break
}
eligible, reachedSince := filterMessagesSince(page, since)
pageNewest, err := s.persistMessagePage(ctx, eligible, channelName, embeddings)
pageNewest, err := s.persistMessagePage(ctx, eligible, channelName, channel.GuildID, embeddings)
if err != nil {
return messageCount, newest, err
}
@ -471,11 +471,11 @@ func (s *Syncer) syncBackfillPages(ctx context.Context, channel *discordgo.Chann
return messageCount, newest, nil
}
func (s *Syncer) persistMessagePage(ctx context.Context, messages []*discordgo.Message, channelName string, embeddings bool) (string, error) {
func (s *Syncer) persistMessagePage(ctx context.Context, messages []*discordgo.Message, channelName string, fallbackGuildID string, embeddings bool) (string, error) {
if len(messages) == 0 {
return "", nil
}
mutations, newest, err := buildMessageMutations(ctx, messages, channelName, embeddings, s.attachmentTextEnabled)
mutations, newest, err := buildMessageMutations(ctx, messages, channelName, fallbackGuildID, embeddings, s.attachmentTextEnabled)
if err != nil {
return "", err
}
@ -485,11 +485,11 @@ func (s *Syncer) persistMessagePage(ctx context.Context, messages []*discordgo.M
return newest, nil
}
func buildMessageMutations(ctx context.Context, messages []*discordgo.Message, channelName string, embeddings bool, attachmentText bool) ([]store.MessageMutation, string, error) {
func buildMessageMutations(ctx context.Context, messages []*discordgo.Message, channelName string, fallbackGuildID string, embeddings bool, attachmentText bool) ([]store.MessageMutation, string, error) {
mutations := make([]store.MessageMutation, 0, len(messages))
newest := ""
for _, message := range messages {
mutation, err := buildMessageMutation(ctx, message, channelName, embeddings, attachmentText)
mutation, err := buildMessageMutation(ctx, message, channelName, fallbackGuildID, embeddings, attachmentText)
if err != nil {
return nil, "", err
}

View File

@ -31,7 +31,14 @@ func toMemberRecord(guildID string, member *discordgo.Member) store.MemberRecord
}
}
func toMessageRecord(message *discordgo.Message, channelName, normalizedContent string) store.MessageRecord {
func effectiveMessageGuildID(message *discordgo.Message, fallbackGuildID string) string {
if message != nil && strings.TrimSpace(message.GuildID) != "" {
return message.GuildID
}
return strings.TrimSpace(fallbackGuildID)
}
func toMessageRecord(message *discordgo.Message, channelName, guildID, normalizedContent string) store.MessageRecord {
raw := marshalJSONString(message, "{}")
authorID := ""
authorName := ""
@ -52,7 +59,7 @@ func toMessageRecord(message *discordgo.Message, channelName, normalizedContent
}
return store.MessageRecord{
ID: message.ID,
GuildID: message.GuildID,
GuildID: guildID,
ChannelID: message.ChannelID,
ChannelName: channelName,
AuthorID: authorID,

View File

@ -49,7 +49,7 @@ func (t *tailHandler) OnMessageCreate(ctx context.Context, msg *discordgo.Messag
if !t.allowGuild(msg.GuildID) {
return nil
}
mutation, err := buildMessageMutation(ctx, msg, "", false, t.attachmentTextEnabled)
mutation, err := buildMessageMutation(ctx, msg, "", "", false, t.attachmentTextEnabled)
if err != nil {
return err
}
@ -69,7 +69,7 @@ func (t *tailHandler) OnMessageUpdate(ctx context.Context, msg *discordgo.Messag
if !t.allowGuild(msg.GuildID) {
return nil
}
mutation, err := buildMessageMutation(ctx, msg, "", false, t.attachmentTextEnabled)
mutation, err := buildMessageMutation(ctx, msg, "", "", false, t.attachmentTextEnabled)
if err != nil {
return err
}