fix(share): harden git snapshot imports (#51)
* fix(share): harden git snapshot imports * docs: update changelog for share import hardening
This commit is contained in:
parent
59f42cb0ab
commit
b387ed2d6f
@ -6,6 +6,8 @@ All notable changes to `discrawl` will be documented in this file.
|
||||
|
||||
### Fixes
|
||||
|
||||
- Git snapshot imports now recover from corrupt local FTS tables by dropping and rebuilding search indexes, and repair missing guild IDs from channel metadata so shared archive reports stay fresh.
|
||||
- Channel-history sync now falls back to the channel guild when Discord omits `message.guild_id`, keeping messages, attachments, mentions, and FTS rows correctly scoped.
|
||||
- Repeated `sync --source wiretap` runs now skip unchanged Discord Desktop cache files and report unchanged file counts, making steady-state local-cache refreshes much faster.
|
||||
- `sync --full --skip-members` now also skips member crawls when resuming incomplete stored channels, so backfills do not unexpectedly refresh the full guild member list.
|
||||
|
||||
|
||||
@ -241,8 +241,8 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
}
|
||||
}()
|
||||
for _, table := range []string{"message_fts", "member_fts"} {
|
||||
if _, err := tx.ExecContext(ctx, "delete from "+table); err != nil {
|
||||
return Manifest{}, fmt.Errorf("clear %s: %w", table, err)
|
||||
if _, err := tx.ExecContext(ctx, "drop table if exists "+table); err != nil {
|
||||
return Manifest{}, fmt.Errorf("drop %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
for i := len(SnapshotTables) - 1; i >= 0; i-- {
|
||||
@ -257,6 +257,9 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error)
|
||||
return Manifest{}, err
|
||||
}
|
||||
}
|
||||
if err := repairImportedGuildIDs(ctx, tx); err != nil {
|
||||
return Manifest{}, err
|
||||
}
|
||||
if opts.IncludeEmbeddings {
|
||||
if err := importEmbeddings(ctx, tx, opts, manifest.Embeddings); err != nil {
|
||||
return Manifest{}, err
|
||||
@ -618,6 +621,67 @@ func importTableFile(ctx context.Context, stmt *sql.Stmt, repoPath string, table
|
||||
return nil
|
||||
}
|
||||
|
||||
// repairImportedGuildIDs backfills blank guild_id values after a snapshot
// import, running every statement on the supplied transaction.
//
// Messages are repaired first, from their channel's guild_id. The dependent
// tables (message_attachments, message_events, mention_events) are repaired
// afterwards so they can prefer the owning message's guild_id and only then
// fall back to the channel's guild_id. Rows for which no non-blank guild can be
// found are left untouched.
//
// The first failing statement aborts the repair; its error is wrapped with the
// name of the table that was being repaired.
func repairImportedGuildIDs(ctx context.Context, tx *sql.Tx) error {
	const messagesQuery = `
update messages
set guild_id = (
	select c.guild_id
	from channels c
	where c.id = messages.channel_id
)
where coalesce(guild_id, '') = ''
  and exists (
	select 1
	from channels c
	where c.id = messages.channel_id
	  and coalesce(c.guild_id, '') != ''
  )`

	type repair struct {
		table string
		query string
	}
	repairs := []repair{{table: "messages", query: messagesQuery}}
	for _, table := range []string{"message_attachments", "message_events", "mention_events"} {
		repairs = append(repairs, repair{table: table, query: dependentGuildIDRepairQuery(table)})
	}

	for _, r := range repairs {
		if _, err := tx.ExecContext(ctx, r.query); err != nil {
			return fmt.Errorf("repair imported %s guild ids: %w", r.table, err)
		}
	}
	return nil
}

// dependentGuildIDRepairQuery builds the guild_id repair statement for a table
// that has message_id and channel_id columns. The table name is spliced into
// the SQL text, so it must be a trusted, hard-coded identifier.
func dependentGuildIDRepairQuery(table string) string {
	// Both candidates go through nullif so that a blank guild on the message
	// or the channel never counts as a repair value.
	const tmpl = `
update %[1]s
set guild_id = coalesce(
	nullif((select m.guild_id from messages m where m.id = %[1]s.message_id), ''),
	nullif((select c.guild_id from channels c where c.id = %[1]s.channel_id), '')
)
where coalesce(guild_id, '') = ''
  and coalesce(
	nullif((select m.guild_id from messages m where m.id = %[1]s.message_id), ''),
	nullif((select c.guild_id from channels c where c.id = %[1]s.channel_id), '')
  ) is not null`
	return fmt.Sprintf(tmpl, table)
}
|
||||
|
||||
func importColumns(table TableManifest) []string {
|
||||
if table.Name != "message_events" && table.Name != "mention_events" {
|
||||
return table.Columns
|
||||
|
||||
@ -63,6 +63,41 @@ func TestExportImportRoundTrip(t *testing.T) {
|
||||
require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt)
|
||||
}
|
||||
|
||||
func TestImportRepairsBlankMessageGuildIDs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
|
||||
defer func() { _ = src.Close() }()
|
||||
_, err := src.DB().ExecContext(ctx, `update messages set guild_id = '' where id = 'm1'`)
|
||||
require.NoError(t, err)
|
||||
_, err = src.DB().ExecContext(ctx, `update message_events set guild_id = '' where message_id = 'm1'`)
|
||||
require.NoError(t, err)
|
||||
_, err = src.DB().ExecContext(ctx, `update mention_events set guild_id = '' where message_id = 'm1'`)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := filepath.Join(t.TempDir(), "share")
|
||||
_, err = Export(ctx, src, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, snapshotTableText(t, repo, tableEntry(t, mustReadManifest(t, repo), "messages")), `"guild_id":""`)
|
||||
|
||||
dst, err := store.Open(ctx, filepath.Join(t.TempDir(), "dst.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = dst.Close() }()
|
||||
_, err = Import(ctx, dst, Options{RepoPath: repo, Branch: "main"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var guildID string
|
||||
require.NoError(t, dst.DB().QueryRowContext(ctx, `select guild_id from messages where id = 'm1'`).Scan(&guildID))
|
||||
require.Equal(t, "g1", guildID)
|
||||
require.NoError(t, dst.DB().QueryRowContext(ctx, `select guild_id from message_events where message_id = 'm1'`).Scan(&guildID))
|
||||
require.Equal(t, "g1", guildID)
|
||||
require.NoError(t, dst.DB().QueryRowContext(ctx, `select guild_id from mention_events where message_id = 'm1'`).Scan(&guildID))
|
||||
require.Equal(t, "g1", guildID)
|
||||
results, err := dst.SearchMessages(ctx, store.SearchOptions{Query: "launch", GuildIDs: []string{"g1"}, Limit: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, "g1", results[0].GuildID)
|
||||
}
|
||||
|
||||
func TestSnapshotExcludesLocalEmbeddingState(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := seedStore(t, filepath.Join(t.TempDir(), "src.db"))
|
||||
@ -860,6 +895,13 @@ func configureGitUser(t *testing.T, repo string) {
|
||||
require.NoError(t, exec.Command("git", "-C", repo, "config", "user.email", "discrawl@example.com").Run())
|
||||
}
|
||||
|
||||
func mustReadManifest(t *testing.T, repo string) Manifest {
|
||||
t.Helper()
|
||||
manifest, err := ReadManifest(repo)
|
||||
require.NoError(t, err)
|
||||
return manifest
|
||||
}
|
||||
|
||||
func tableEntry(t *testing.T, manifest Manifest, name string) TableManifest {
|
||||
t.Helper()
|
||||
for _, table := range manifest.Tables {
|
||||
|
||||
@ -27,15 +27,17 @@ func buildMessageMutation(
|
||||
ctx context.Context,
|
||||
message *discordgo.Message,
|
||||
channelName string,
|
||||
fallbackGuildID string,
|
||||
embeddings bool,
|
||||
attachmentText bool,
|
||||
) (store.MessageMutation, error) {
|
||||
attachments, attachmentParts, err := extractAttachments(ctx, message, attachmentText)
|
||||
guildID := effectiveMessageGuildID(message, fallbackGuildID)
|
||||
attachments, attachmentParts, err := extractAttachments(ctx, message, guildID, attachmentText)
|
||||
if err != nil {
|
||||
return store.MessageMutation{}, err
|
||||
}
|
||||
normalized := normalizeMessageParts(message, attachmentParts)
|
||||
record := toMessageRecord(message, channelName, normalized)
|
||||
record := toMessageRecord(message, channelName, guildID, normalized)
|
||||
return store.MessageMutation{
|
||||
Record: record,
|
||||
EventType: "upsert",
|
||||
@ -44,11 +46,11 @@ func buildMessageMutation(
|
||||
EnqueueEmbedding: embeddings,
|
||||
},
|
||||
Attachments: attachments,
|
||||
Mentions: extractMentions(message),
|
||||
Mentions: extractMentions(message, guildID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractAttachments(ctx context.Context, message *discordgo.Message, attachmentText bool) ([]store.AttachmentRecord, []string, error) {
|
||||
func extractAttachments(ctx context.Context, message *discordgo.Message, guildID string, attachmentText bool) ([]store.AttachmentRecord, []string, error) {
|
||||
if message == nil || len(message.Attachments) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
@ -61,7 +63,7 @@ func extractAttachments(ctx context.Context, message *discordgo.Message, attachm
|
||||
record := store.AttachmentRecord{
|
||||
AttachmentID: attachment.ID,
|
||||
MessageID: message.ID,
|
||||
GuildID: message.GuildID,
|
||||
GuildID: guildID,
|
||||
ChannelID: message.ChannelID,
|
||||
Filename: attachment.Filename,
|
||||
ContentType: attachment.ContentType,
|
||||
@ -90,7 +92,7 @@ func extractAttachments(ctx context.Context, message *discordgo.Message, attachm
|
||||
return records, parts, nil
|
||||
}
|
||||
|
||||
func extractMentions(message *discordgo.Message) []store.MentionEventRecord {
|
||||
func extractMentions(message *discordgo.Message, guildID string) []store.MentionEventRecord {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
@ -116,7 +118,7 @@ func extractMentions(message *discordgo.Message) []store.MentionEventRecord {
|
||||
}
|
||||
mentions = append(mentions, store.MentionEventRecord{
|
||||
MessageID: message.ID,
|
||||
GuildID: message.GuildID,
|
||||
GuildID: guildID,
|
||||
ChannelID: message.ChannelID,
|
||||
AuthorID: authorID,
|
||||
TargetType: "user",
|
||||
@ -137,7 +139,7 @@ func extractMentions(message *discordgo.Message) []store.MentionEventRecord {
|
||||
seen[key] = struct{}{}
|
||||
mentions = append(mentions, store.MentionEventRecord{
|
||||
MessageID: message.ID,
|
||||
GuildID: message.GuildID,
|
||||
GuildID: guildID,
|
||||
ChannelID: message.ChannelID,
|
||||
AuthorID: authorID,
|
||||
TargetType: "role",
|
||||
|
||||
@ -44,7 +44,7 @@ func TestBuildMessageMutationIncludesAttachmentTextAndMentions(t *testing.T) {
|
||||
GlobalName: "Shadow",
|
||||
}},
|
||||
MentionRoles: []string{"r1"},
|
||||
}, "maintainers", false, true)
|
||||
}, "maintainers", "", false, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, mutation.Attachments, 1)
|
||||
require.Equal(t, "trace.txt", mutation.Attachments[0].Filename)
|
||||
@ -59,6 +59,30 @@ func TestBuildMessageMutationIncludesAttachmentTextAndMentions(t *testing.T) {
|
||||
require.Equal(t, "r1", mutation.Mentions[1].TargetID)
|
||||
}
|
||||
|
||||
func TestBuildMessageMutationFallsBackToChannelGuildID(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
mutation, err := buildMessageMutation(context.Background(), &discordgo.Message{
|
||||
ID: "m1",
|
||||
ChannelID: "c1",
|
||||
Content: "missing guild id from channel history",
|
||||
Timestamp: now,
|
||||
Author: &discordgo.User{ID: "u1", Username: "peter"},
|
||||
Attachments: []*discordgo.MessageAttachment{{
|
||||
ID: "a1",
|
||||
Filename: "trace.txt",
|
||||
}},
|
||||
Mentions: []*discordgo.User{{ID: "u2", Username: "shadow"}},
|
||||
MentionRoles: []string{
|
||||
"r1",
|
||||
},
|
||||
}, "maintainers", "g1", false, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "g1", mutation.Record.GuildID)
|
||||
require.Equal(t, "g1", mutation.Attachments[0].GuildID)
|
||||
require.Equal(t, "g1", mutation.Mentions[0].GuildID)
|
||||
require.Equal(t, "g1", mutation.Mentions[1].GuildID)
|
||||
}
|
||||
|
||||
func TestShouldFetchAttachmentText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -93,7 +117,7 @@ func TestBuildMessageMutationSkipsBinaryResponseEvenWhenAttachmentLooksTextual(t
|
||||
Filename: "trace.txt",
|
||||
URL: server.URL,
|
||||
}},
|
||||
}, "maintainers", false, true)
|
||||
}, "maintainers", "", false, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, mutation.Attachments, 1)
|
||||
require.Empty(t, mutation.Attachments[0].TextContent)
|
||||
@ -127,7 +151,7 @@ func TestBuildMessageMutationSkipsOversizedAttachmentResponses(t *testing.T) {
|
||||
ContentType: "text/plain",
|
||||
URL: server.URL,
|
||||
}},
|
||||
}, "maintainers", false, true)
|
||||
}, "maintainers", "", false, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, mutation.Attachments, 1)
|
||||
require.Empty(t, mutation.Attachments[0].TextContent)
|
||||
@ -159,7 +183,7 @@ func TestBuildMessageMutationRespectsAttachmentTextOptOut(t *testing.T) {
|
||||
ContentType: "text/plain",
|
||||
URL: server.URL,
|
||||
}},
|
||||
}, "maintainers", false, false)
|
||||
}, "maintainers", "", false, false)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, mutation.Attachments, 1)
|
||||
require.Empty(t, mutation.Attachments[0].TextContent)
|
||||
|
||||
@ -33,7 +33,7 @@ func TestBuildMessageMutationsTracksNewest(t *testing.T) {
|
||||
Timestamp: now,
|
||||
Author: &discordgo.User{ID: "u1", Username: "user"},
|
||||
},
|
||||
}, "general", true, true)
|
||||
}, "general", "", true, true)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, mutations, 2)
|
||||
|
||||
@ -298,7 +298,7 @@ func (s *Syncer) syncFullChannelHistory(ctx context.Context, channel *discordgo.
|
||||
messageCount := 0
|
||||
newest := state.Latest
|
||||
if state.Latest != "" {
|
||||
count, latest, err := s.syncForwardPages(ctx, channel, state.Latest, channel.Name, embeddings, progress)
|
||||
count, latest, err := s.syncForwardPages(ctx, channel, state.Latest, embeddings, progress)
|
||||
messageCount += count
|
||||
if err != nil {
|
||||
return messageCount, err
|
||||
@ -332,7 +332,7 @@ func (s *Syncer) syncIncrementalChannelHistory(ctx context.Context, channel *dis
|
||||
if state.Latest == "" {
|
||||
return s.bootstrapChannelHistory(ctx, channel, embeddings, since, progress)
|
||||
}
|
||||
count, newest, err := s.syncForwardPages(ctx, channel, state.Latest, channel.Name, embeddings, progress)
|
||||
count, newest, err := s.syncForwardPages(ctx, channel, state.Latest, embeddings, progress)
|
||||
if err != nil {
|
||||
return count, err
|
||||
}
|
||||
@ -358,7 +358,7 @@ func (s *Syncer) bootstrapChannelHistory(ctx context.Context, channel *discordgo
|
||||
break
|
||||
}
|
||||
eligible, reachedSince := filterMessagesSince(page, since)
|
||||
pageNewest, err := s.persistMessagePage(ctx, eligible, channel.Name, embeddings)
|
||||
pageNewest, err := s.persistMessagePage(ctx, eligible, channel.Name, channel.GuildID, embeddings)
|
||||
if err != nil {
|
||||
return messageCount, err
|
||||
}
|
||||
@ -392,7 +392,7 @@ func (s *Syncer) bootstrapChannelHistory(ctx context.Context, channel *discordgo
|
||||
return messageCount, nil
|
||||
}
|
||||
|
||||
func (s *Syncer) syncForwardPages(ctx context.Context, channel *discordgo.Channel, after, channelName string, embeddings bool, progress *messageSyncProgress) (int, string, error) {
|
||||
func (s *Syncer) syncForwardPages(ctx context.Context, channel *discordgo.Channel, after string, embeddings bool, progress *messageSyncProgress) (int, string, error) {
|
||||
messageCount := 0
|
||||
newest := after
|
||||
for {
|
||||
@ -403,7 +403,7 @@ func (s *Syncer) syncForwardPages(ctx context.Context, channel *discordgo.Channe
|
||||
if len(page) == 0 {
|
||||
break
|
||||
}
|
||||
pageNewest, err := s.persistMessagePage(ctx, page, channelName, embeddings)
|
||||
pageNewest, err := s.persistMessagePage(ctx, page, channel.Name, channel.GuildID, embeddings)
|
||||
if err != nil {
|
||||
return messageCount, newest, err
|
||||
}
|
||||
@ -436,7 +436,7 @@ func (s *Syncer) syncBackfillPages(ctx context.Context, channel *discordgo.Chann
|
||||
break
|
||||
}
|
||||
eligible, reachedSince := filterMessagesSince(page, since)
|
||||
pageNewest, err := s.persistMessagePage(ctx, eligible, channelName, embeddings)
|
||||
pageNewest, err := s.persistMessagePage(ctx, eligible, channelName, channel.GuildID, embeddings)
|
||||
if err != nil {
|
||||
return messageCount, newest, err
|
||||
}
|
||||
@ -471,11 +471,11 @@ func (s *Syncer) syncBackfillPages(ctx context.Context, channel *discordgo.Chann
|
||||
return messageCount, newest, nil
|
||||
}
|
||||
|
||||
func (s *Syncer) persistMessagePage(ctx context.Context, messages []*discordgo.Message, channelName string, embeddings bool) (string, error) {
|
||||
func (s *Syncer) persistMessagePage(ctx context.Context, messages []*discordgo.Message, channelName string, fallbackGuildID string, embeddings bool) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
mutations, newest, err := buildMessageMutations(ctx, messages, channelName, embeddings, s.attachmentTextEnabled)
|
||||
mutations, newest, err := buildMessageMutations(ctx, messages, channelName, fallbackGuildID, embeddings, s.attachmentTextEnabled)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -485,11 +485,11 @@ func (s *Syncer) persistMessagePage(ctx context.Context, messages []*discordgo.M
|
||||
return newest, nil
|
||||
}
|
||||
|
||||
func buildMessageMutations(ctx context.Context, messages []*discordgo.Message, channelName string, embeddings bool, attachmentText bool) ([]store.MessageMutation, string, error) {
|
||||
func buildMessageMutations(ctx context.Context, messages []*discordgo.Message, channelName string, fallbackGuildID string, embeddings bool, attachmentText bool) ([]store.MessageMutation, string, error) {
|
||||
mutations := make([]store.MessageMutation, 0, len(messages))
|
||||
newest := ""
|
||||
for _, message := range messages {
|
||||
mutation, err := buildMessageMutation(ctx, message, channelName, embeddings, attachmentText)
|
||||
mutation, err := buildMessageMutation(ctx, message, channelName, fallbackGuildID, embeddings, attachmentText)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@ -31,7 +31,14 @@ func toMemberRecord(guildID string, member *discordgo.Member) store.MemberRecord
|
||||
}
|
||||
}
|
||||
|
||||
func toMessageRecord(message *discordgo.Message, channelName, normalizedContent string) store.MessageRecord {
|
||||
func effectiveMessageGuildID(message *discordgo.Message, fallbackGuildID string) string {
|
||||
if message != nil && strings.TrimSpace(message.GuildID) != "" {
|
||||
return message.GuildID
|
||||
}
|
||||
return strings.TrimSpace(fallbackGuildID)
|
||||
}
|
||||
|
||||
func toMessageRecord(message *discordgo.Message, channelName, guildID, normalizedContent string) store.MessageRecord {
|
||||
raw := marshalJSONString(message, "{}")
|
||||
authorID := ""
|
||||
authorName := ""
|
||||
@ -52,7 +59,7 @@ func toMessageRecord(message *discordgo.Message, channelName, normalizedContent
|
||||
}
|
||||
return store.MessageRecord{
|
||||
ID: message.ID,
|
||||
GuildID: message.GuildID,
|
||||
GuildID: guildID,
|
||||
ChannelID: message.ChannelID,
|
||||
ChannelName: channelName,
|
||||
AuthorID: authorID,
|
||||
|
||||
@ -49,7 +49,7 @@ func (t *tailHandler) OnMessageCreate(ctx context.Context, msg *discordgo.Messag
|
||||
if !t.allowGuild(msg.GuildID) {
|
||||
return nil
|
||||
}
|
||||
mutation, err := buildMessageMutation(ctx, msg, "", false, t.attachmentTextEnabled)
|
||||
mutation, err := buildMessageMutation(ctx, msg, "", "", false, t.attachmentTextEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -69,7 +69,7 @@ func (t *tailHandler) OnMessageUpdate(ctx context.Context, msg *discordgo.Messag
|
||||
if !t.allowGuild(msg.GuildID) {
|
||||
return nil
|
||||
}
|
||||
mutation, err := buildMessageMutation(ctx, msg, "", false, t.attachmentTextEnabled)
|
||||
mutation, err := buildMessageMutation(ctx, msg, "", "", false, t.attachmentTextEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user