fix(tui): hydrate discord reply context

This commit is contained in:
Vincent Koc 2026-05-04 01:34:09 -07:00
parent c8118d9dcc
commit 18d4aba76a
No known key found for this signature in database
3 changed files with 212 additions and 1 deletions

View File

@ -68,7 +68,7 @@ func (r *runtime) runTUI(args []string) error {
})
}
loadRows := func() ([]tui.Row, error) {
rows, err := r.store.ListMessages(r.ctx, store.MessageListOptions{
rows, err := r.store.ListMessagesWithThreadContext(r.ctx, store.MessageListOptions{
GuildIDs: guildIDs,
Channel: *channel,
Author: *author,

View File

@ -183,6 +183,174 @@ func (s *Store) ListMessages(ctx context.Context, opts MessageListOptions) ([]Me
return out, s.resolveMessageDisplayMentions(ctx, out)
}
func (s *Store) ListMessagesWithThreadContext(ctx context.Context, opts MessageListOptions) ([]MessageRow, error) {
rows, err := s.ListMessages(ctx, opts)
if err != nil {
return nil, err
}
return s.hydrateMessageThreadContext(ctx, rows, opts.Limit+opts.Last)
}
func (s *Store) hydrateMessageThreadContext(ctx context.Context, rows []MessageRow, limit int) ([]MessageRow, error) {
if len(rows) == 0 {
return rows, nil
}
type threadRef struct {
guildID string
channelID string
rootID string
}
refs := make([]threadRef, 0, len(rows))
seenRefs := map[string]struct{}{}
for _, row := range rows {
rootID := strings.TrimSpace(row.ReplyToMessage)
if rootID == "" {
rootID = strings.TrimSpace(row.MessageID)
}
if rootID == "" || strings.TrimSpace(row.GuildID) == "" || strings.TrimSpace(row.ChannelID) == "" {
continue
}
key := row.GuildID + "\x00" + row.ChannelID + "\x00" + rootID
if _, ok := seenRefs[key]; ok {
continue
}
seenRefs[key] = struct{}{}
refs = append(refs, threadRef{guildID: row.GuildID, channelID: row.ChannelID, rootID: rootID})
}
if len(refs) == 0 {
return rows, nil
}
clauses := make([]string, 0, len(refs))
args := make([]any, 0, len(refs)*4+1)
for _, ref := range refs {
clauses = append(clauses, `(m.guild_id = ? and m.channel_id = ? and (m.id = ? or m.reply_to_message_id = ?))`)
args = append(args, ref.guildID, ref.channelID, ref.rootID, ref.rootID)
}
contextLimit := limit * 5
if contextLimit < len(rows) {
contextLimit = len(rows)
}
if contextLimit < 200 {
contextLimit = 200
}
if contextLimit > 2000 {
contextLimit = 2000
}
query := `
select
m.id,
m.guild_id,
coalesce(g.name, ''),
m.channel_id,
coalesce(c.name, ''),
coalesce(m.author_id, ''),
coalesce(
nullif(mem.display_name, ''),
nullif(mem.nick, ''),
nullif(mem.global_name, ''),
nullif(mem.username, ''),
nullif(json_extract(m.raw_json, '$.author.global_name'), ''),
nullif(json_extract(m.raw_json, '$.author.username'), ''),
''
),
case
when trim(coalesce(m.content, '')) <> '' then m.content
else m.normalized_content
end,
m.created_at,
coalesce(m.reply_to_message_id, ''),
coalesce(json_extract(m.raw_json, '$.source'), ''),
m.has_attachments,
coalesce((select group_concat(a.filename, ', ') from message_attachments a where a.message_id = m.id), ''),
coalesce((select group_concat(a.text_content, char(10)) from message_attachments a where a.message_id = m.id and trim(a.text_content) <> ''), ''),
m.pinned
from messages m
left join guilds g on g.id = m.guild_id
left join channels c on c.id = m.channel_id
left join members mem on mem.guild_id = m.guild_id and mem.user_id = m.author_id
where ` + strings.Join(clauses, " or ") + `
order by m.created_at asc, m.id asc
limit ?`
args = append(args, contextLimit)
contextRows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = contextRows.Close() }()
extra, err := scanMessageRows(contextRows)
if err != nil {
return nil, err
}
if err := s.resolveMessageDisplayMentions(ctx, extra); err != nil {
return nil, err
}
return mergeMessageRows(rows, extra), nil
}
func scanMessageRows(rows rowScanner) ([]MessageRow, error) {
var out []MessageRow
for rows.Next() {
var row MessageRow
var created string
var hasAttachments int
var pinned int
if err := rows.Scan(
&row.MessageID,
&row.GuildID,
&row.GuildName,
&row.ChannelID,
&row.ChannelName,
&row.AuthorID,
&row.AuthorName,
&row.Content,
&created,
&row.ReplyToMessage,
&row.Source,
&hasAttachments,
&row.AttachmentNames,
&row.AttachmentText,
&pinned,
); err != nil {
return nil, err
}
row.CreatedAt = parseTime(created)
row.HasAttachments = hasAttachments == 1
row.Pinned = pinned == 1
row.DisplayContent = row.Content
out = append(out, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
type rowScanner interface {
Next() bool
Scan(dest ...any) error
Err() error
}
func mergeMessageRows(primary, extra []MessageRow) []MessageRow {
out := make([]MessageRow, 0, len(primary)+len(extra))
seen := map[string]struct{}{}
appendRow := func(row MessageRow) {
key := row.GuildID + "\x00" + row.ChannelID + "\x00" + row.MessageID
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
out = append(out, row)
}
for _, row := range primary {
appendRow(row)
}
for _, row := range extra {
appendRow(row)
}
return out
}
func normalizeChannelFilter(raw string) string {
return strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(raw), "#"))
}

View File

@ -1669,6 +1669,49 @@ func TestListMessagesFiltersAndLimit(t *testing.T) {
require.Equal(t, "m4", rows[1].MessageID)
}
func TestListMessagesWithThreadContextHydratesReplyRoot(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
ID: "root",
GuildID: "g1",
ChannelID: "c1",
ChannelName: "general",
AuthorID: "u1",
MessageType: 0,
CreatedAt: "2026-03-01T10:00:00Z",
Content: "root message",
NormalizedContent: "root message",
RawJSON: `{}`,
}))
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
ID: "reply",
GuildID: "g1",
ChannelID: "c1",
ChannelName: "general",
AuthorID: "u2",
MessageType: 0,
CreatedAt: "2026-03-02T10:00:00Z",
Content: "reply message",
NormalizedContent: "reply message",
ReplyToMessageID: "root",
RawJSON: `{}`,
}))
rows, err := s.ListMessagesWithThreadContext(ctx, MessageListOptions{Last: 1})
require.NoError(t, err)
require.Len(t, rows, 2)
require.Equal(t, "reply", rows[0].MessageID)
require.Equal(t, "root", rows[1].MessageID)
}
func TestNormalizeFTSQueryEdgeCases(t *testing.T) {
t.Parallel()