discrawl/internal/store/store_write_test.go
2026-05-08 09:58:34 +01:00

775 lines
25 KiB
Go

package store
import (
"context"
"database/sql"
"errors"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/openclaw/crawlkit/embed"
"github.com/stretchr/testify/require"
)
func TestUpsertMessagesBatch(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
now := time.Now().UTC().Format(time.RFC3339Nano)
require.NoError(t, s.UpsertMessages(ctx, []MessageMutation{
{
Record: MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: now,
Content: "one",
NormalizedContent: "one",
RawJSON: `{"id":"m1"}`,
},
EventType: "upsert",
PayloadJSON: `{"id":"m1"}`,
Options: WriteOptions{
AppendEvent: true,
},
},
{
Record: MessageRecord{
ID: "m2",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: now,
Content: "two",
NormalizedContent: "two",
RawJSON: `{"id":"m2"}`,
},
EventType: "upsert",
PayloadJSON: `{"id":"m2"}`,
Options: WriteOptions{
AppendEvent: true,
},
},
}))
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from messages")
require.NoError(t, err)
require.Equal(t, "2", rows[0][0])
_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_events")
require.NoError(t, err)
require.Equal(t, "2", rows[0][0])
}
func TestUpsertMessagesHonorsCanceledContext(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
canceled, cancel := context.WithCancel(ctx)
cancel()
err = s.UpsertMessages(canceled, []MessageMutation{{
Record: MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "one",
NormalizedContent: "one",
RawJSON: `{"id":"m1"}`,
},
}})
require.ErrorIs(t, err, context.Canceled)
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from messages")
require.NoError(t, err)
require.Equal(t, "0", rows[0][0])
}
func TestUpsertMessagesSkipsEventsAndEmbeddingsByDefault(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
now := time.Now().UTC().Format(time.RFC3339Nano)
require.NoError(t, s.UpsertMessages(ctx, []MessageMutation{{
Record: MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: now,
Content: "one",
NormalizedContent: "one",
RawJSON: `{"id":"m1"}`,
},
EventType: "upsert",
PayloadJSON: `{"id":"m1"}`,
}}))
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_events")
require.NoError(t, err)
require.Equal(t, "0", rows[0][0])
_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from embedding_jobs")
require.NoError(t, err)
require.Equal(t, "0", rows[0][0])
}
func TestUpsertMessageWithEmbeddingsQueuesJob(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from embedding_jobs")
require.NoError(t, err)
require.Equal(t, "1", rows[0][0])
}
func TestUpsertMessageWithEmbeddingsQueuesExistingMessageWithoutJob(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
record := MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}
require.NoError(t, s.UpsertMessage(ctx, record))
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"pending", "0"}}, rows)
}
func TestDrainEmbeddingJobsStoresVectorsAndSkipsEmptyInput(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
now := time.Now().UTC().Format(time.RFC3339Nano)
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: now,
Content: "abcdef世界",
NormalizedContent: "abcdef世界",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m2",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: now,
Content: "",
NormalizedContent: " ",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
provider := &fakeEmbeddingProvider{
batches: []embed.EmbeddingBatch{{
Vectors: [][]float32{{1.25, 2.5}},
}},
}
stats, err := s.DrainEmbeddingJobs(ctx, provider, EmbeddingDrainOptions{
Provider: "ollama",
Model: "nomic-embed-text",
Limit: 10,
BatchSize: 2,
MaxInputChars: 7,
})
require.NoError(t, err)
require.Equal(t, 2, stats.Processed)
require.Equal(t, 1, stats.Succeeded)
require.Equal(t, 1, stats.Skipped)
require.Equal(t, 0, stats.RemainingBacklog)
require.Equal(t, [][]string{{"abcdef世"}}, provider.inputs)
_, rows, err := s.ReadOnlyQuery(ctx, "select message_id, provider, model, input_version, dimensions from message_embeddings")
require.NoError(t, err)
require.Equal(t, [][]string{{"m1", "ollama", "nomic-embed-text", EmbeddingInputVersion, "2"}}, rows)
var blob []byte
require.NoError(t, s.DB().QueryRowContext(ctx, `select embedding_blob from message_embeddings where message_id = 'm1'`).Scan(&blob))
vector, err := DecodeEmbeddingVector(blob)
require.NoError(t, err)
require.Equal(t, []float32{1.25, 2.5}, vector)
_, rows, err = s.ReadOnlyQuery(ctx, "select message_id, state, provider, model, input_version from embedding_jobs order by message_id")
require.NoError(t, err)
require.Equal(t, [][]string{
{"m1", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion},
{"m2", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion},
}, rows)
}
func TestUpsertMessageWithEmbeddingsDoesNotRequeueUnchangedDoneJob(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
record := MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}},
}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
require.NoError(t, err)
require.Equal(t, 1, stats.Succeeded)
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"done", "0", ""}}, rows)
backlog, err := s.EmbeddingBacklog(ctx)
require.NoError(t, err)
require.Equal(t, 0, backlog)
record.NormalizedContent = "hello updated"
record.Content = "hello updated"
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
_, rows, err = s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"pending", "0", ""}}, rows)
}
func TestDrainEmbeddingJobsFailsWholeBatchOnDimensionMismatch(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
batches: []embed.EmbeddingBatch{{
Dimensions: 3,
Vectors: [][]float32{{1, 2}},
}},
}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
require.NoError(t, err)
require.Equal(t, 1, stats.Failed)
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, "pending", rows[0][0])
require.Equal(t, "1", rows[0][1])
require.Contains(t, rows[0][2], "dimensions mismatch")
_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
require.NoError(t, err)
require.Equal(t, "0", rows[0][0])
}
func TestDrainEmbeddingJobsMarksFailedAfterMaxAttempts(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
_, err = s.DB().ExecContext(ctx, `update embedding_jobs set attempts = 2 where message_id = 'm1'`)
require.NoError(t, err)
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{err: errors.New("provider down")}, EmbeddingDrainOptions{
Provider: "ollama",
Model: "nomic-embed-text",
Limit: 10,
})
require.NoError(t, err)
require.Equal(t, 1, stats.Failed)
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"failed", "3", "provider down"}}, rows)
}
func TestDrainEmbeddingJobsStopsOnRateLimit(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
for _, id := range []string{"m1", "m2"} {
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: id,
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
}
provider := &fakeEmbeddingProvider{err: &embed.HTTPError{StatusCode: 429, Body: "slow down"}}
stats, err := s.DrainEmbeddingJobs(ctx, provider, EmbeddingDrainOptions{
Provider: "ollama",
Model: "nomic-embed-text",
Limit: 10,
BatchSize: 1,
})
require.NoError(t, err)
require.True(t, stats.RateLimited)
require.Equal(t, 0, stats.Processed)
require.Equal(t, 0, stats.Failed)
require.Equal(t, 1, stats.Requeued)
require.Equal(t, 2, stats.RemainingBacklog)
require.Len(t, provider.inputs, 1)
_, rows, err := s.ReadOnlyQuery(ctx, "select message_id, state, attempts, last_error, coalesce(locked_at, '') from embedding_jobs order by message_id")
require.NoError(t, err)
require.Equal(t, [][]string{
{"m1", "pending", "0", "embedding request failed with HTTP 429: slow down", ""},
{"m2", "pending", "0", "", ""},
}, rows)
}
func TestDrainEmbeddingJobsDeletesStaleVectorsForEmptyContent(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
record := MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
_, err = s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}},
}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
require.NoError(t, err)
record.Content = ""
record.NormalizedContent = " "
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{}, EmbeddingDrainOptions{
Provider: "ollama",
Model: "nomic-embed-text",
Limit: 10,
BatchSize: 1,
})
require.NoError(t, err)
require.Equal(t, 1, stats.Processed)
require.Equal(t, 1, stats.Skipped)
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
require.NoError(t, err)
require.Equal(t, "0", rows[0][0])
_, rows, err = s.ReadOnlyQuery(ctx, "select state, provider, model from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"done", "ollama", "nomic-embed-text"}}, rows)
}
func TestPendingEmbeddingJobsSkipsFreshLocks(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
now := time.Now().UTC()
staleBefore := now.Add(-embeddingLockTimeout).Format(timeLayout)
jobs, err := s.pendingEmbeddingJobs(ctx, 10, staleBefore)
require.NoError(t, err)
require.Len(t, jobs, 1)
claimed, err := s.lockEmbeddingJobs(ctx, jobs, now.Format(timeLayout), staleBefore)
require.NoError(t, err)
require.Len(t, claimed, 1)
jobs, err = s.pendingEmbeddingJobs(ctx, 10, staleBefore)
require.NoError(t, err)
require.Empty(t, jobs)
claimed, err = s.lockEmbeddingJobs(ctx, claimed, now.Add(time.Minute).Format(timeLayout), staleBefore)
require.NoError(t, err)
require.Empty(t, claimed)
}
func TestRequeueAllEmbeddingJobsUsesCurrentIdentity(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
for _, id := range []string{"m1", "m2"} {
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
ID: id,
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}))
}
_, err = s.DB().ExecContext(ctx, `
insert into embedding_jobs(message_id, state, attempts, provider, model, input_version, last_error, updated_at)
values('m1', 'failed', 3, 'old', 'old-model', 'old-input', 'old error', ?)
`, time.Now().UTC().Format(timeLayout))
require.NoError(t, err)
requeued, err := s.RequeueAllEmbeddingJobs(ctx, EmbeddingDrainOptions{
Provider: "ollama",
Model: "nomic-embed-text",
InputVersion: EmbeddingInputVersion,
})
require.NoError(t, err)
require.Equal(t, 2, requeued)
_, rows, err := s.ReadOnlyQuery(ctx, "select message_id, state, attempts, provider, model, input_version, last_error from embedding_jobs order by message_id")
require.NoError(t, err)
require.Equal(t, [][]string{
{"m1", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""},
{"m2", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""},
}, rows)
}
func TestEmbeddingHelpersAndIdentityResetBranches(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
now := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
opts := normalizeEmbeddingDrainOptions(EmbeddingDrainOptions{
Provider: " OLLAMA ",
Model: " model ",
Limit: 1,
BatchSize: 5,
MaxInputChars: 3,
Now: func() time.Time { return now },
})
require.Equal(t, "ollama", opts.Provider)
require.Equal(t, "model", opts.Model)
require.Equal(t, 1, opts.BatchSize)
require.Equal(t, EmbeddingInputVersion, opts.InputVersion)
stats, err := s.DrainEmbeddingJobs(ctx, nil, opts)
require.ErrorContains(t, err, "embedding provider is nil")
require.Equal(t, "ollama", stats.Provider)
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: now.Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, WriteOptions{EnqueueEmbedding: true}))
_, err = s.DB().ExecContext(ctx, `
update embedding_jobs
set provider = 'old', model = 'old-model', input_version = 'old-version', attempts = 2, last_error = 'bad', locked_at = 'locked'
where message_id = 'm1'
`)
require.NoError(t, err)
require.NoError(t, s.resetEmbeddingJobIdentity(ctx, "m1", opts, true))
_, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, attempts, last_error, coalesce(locked_at, '') from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"ollama", "model", EmbeddingInputVersion, "0", "", ""}}, rows)
_, err = s.DB().ExecContext(ctx, `update embedding_jobs set attempts = 2, locked_at = 'locked' where message_id = 'm1'`)
require.NoError(t, err)
require.NoError(t, s.resetEmbeddingJobIdentity(ctx, "m1", opts, false))
_, rows, err = s.ReadOnlyQuery(ctx, "select attempts, coalesce(locked_at, '') from embedding_jobs where message_id = 'm1'")
require.NoError(t, err)
require.Equal(t, [][]string{{"2", ""}}, rows)
require.True(t, sameEmbeddingIdentity(embeddingJob{Provider: "ollama", Model: "model", InputVersion: EmbeddingInputVersion}, opts))
require.True(t, emptyEmbeddingIdentity(embeddingJob{}))
_, err = validateEmbeddingBatch(embed.EmbeddingBatch{Vectors: [][]float32{{1}, {2}}}, 1)
require.ErrorContains(t, err, "returned 2 vectors")
_, err = validateEmbeddingBatch(embed.EmbeddingBatch{Vectors: [][]float32{{}}}, 1)
require.ErrorContains(t, err, "empty vector")
require.Empty(t, trimStoredError(nil))
require.Len(t, []rune(trimStoredError(errors.New(strings.Repeat("x", maxStoredErrorChars+10)))), maxStoredErrorChars)
require.Equal(t, "abcdef", capRunes("abcdef", 0))
require.Equal(t, "abc", capRunes("abcdef", 3))
_, err = DecodeEmbeddingVector([]byte{1, 2, 3})
require.ErrorContains(t, err, "not a float32 multiple")
}
func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
var wg sync.WaitGroup
errCh := make(chan error, 8)
for i := range 8 {
wg.Add(1)
go func(i int) {
defer wg.Done()
errCh <- s.UpsertMessage(ctx, MessageRecord{
ID: stringify(i),
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
})
}(i)
}
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from messages")
require.NoError(t, err)
require.Equal(t, "8", rows[0][0])
}
type fakeEmbeddingProvider struct {
batches []embed.EmbeddingBatch
err error
inputs [][]string
}
func (f *fakeEmbeddingProvider) Embed(_ context.Context, inputs []string) (embed.EmbeddingBatch, error) {
copied := append([]string(nil), inputs...)
f.inputs = append(f.inputs, copied)
if f.err != nil {
return embed.EmbeddingBatch{}, f.err
}
if len(f.batches) == 0 {
return embed.EmbeddingBatch{}, nil
}
batch := f.batches[0]
f.batches = f.batches[1:]
return batch, nil
}
func TestMessageFTSUsesSnowflakeRowID(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
record := MessageRecord{
ID: "1469950701764350208",
GuildID: "g1",
ChannelID: "c1",
ChannelName: "general",
AuthorID: "u1",
AuthorName: "Peter",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "first body",
NormalizedContent: "first body",
RawJSON: `{"author":{"username":"Peter"}}`,
}
require.NoError(t, s.UpsertMessage(ctx, record))
record.Content = "second body"
record.NormalizedContent = "second body"
require.NoError(t, s.UpsertMessage(ctx, record))
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*), min(rowid), max(rowid), min(content) from message_fts")
require.NoError(t, err)
require.Equal(t, []string{"1", "1469950701764350208", "1469950701764350208", "second body"}, rows[0])
}
func TestMemberFTSUpdatesOnUpsert(t *testing.T) {
t.Parallel()
ctx := context.Background()
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
require.NoError(t, err)
defer func() { _ = s.Close() }()
record := MemberRecord{
GuildID: "g1",
UserID: "u1",
Username: "peter",
DisplayName: "Peter",
RoleIDsJSON: `[]`,
RawJSON: `{"bio":"Maintainer","github":"steipete"}`,
Discriminator: "0",
}
require.NoError(t, s.UpsertMember(ctx, record))
record.RawJSON = `{"bio":"Updated bio","github":"steipete","website":"https://steipete.me"}`
require.NoError(t, s.UpsertMember(ctx, record))
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*), min(profile_text) from member_fts")
require.NoError(t, err)
require.Equal(t, []string{"1", "Updated bio steipete https://steipete.me"}, rows[0])
}
func TestOpenRebuildsLegacyFTSRowIDs(t *testing.T) {
t.Parallel()
ctx := context.Background()
dbPath := filepath.Join(t.TempDir(), "discrawl.db")
s, err := Open(ctx, dbPath)
require.NoError(t, err)
messageID := "1469950701764350208"
channelID := "c1"
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: channelID, GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
ID: messageID,
GuildID: "g1",
ChannelID: channelID,
ChannelName: "general",
AuthorID: "u1",
AuthorName: "Peter",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "panic database is locked",
NormalizedContent: "panic database is locked",
RawJSON: `{"author":{"username":"Peter"}}`,
}))
require.NoError(t, s.Close())
sqlDB, err := sql.Open("sqlite", dbPath)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `delete from message_fts`)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
insert into message_fts(message_id, guild_id, channel_id, author_id, author_name, channel_name, content)
values(?, ?, ?, ?, ?, ?, ?)
`, messageID, "g1", channelID, "u1", "Peter", "general", "panic database is locked")
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `delete from sync_state where scope = 'schema:message_fts_rowid_version'`)
require.NoError(t, err)
require.NoError(t, sqlDB.Close())
s, err = Open(ctx, dbPath)
require.NoError(t, err)
defer func() { _ = s.Close() }()
_, rows, err := s.ReadOnlyQuery(ctx, "select rowid, message_id from message_fts")
require.NoError(t, err)
require.Equal(t, []string{messageID, messageID}, rows[0])
results, err := s.SearchMessages(ctx, SearchOptions{Query: "panic", Limit: 10})
require.NoError(t, err)
require.Len(t, results, 1)
require.Equal(t, messageID, results[0].MessageID)
}