feat(embed): add embedding job drain

* feat(embed): add embedding job drain

* fix(embed): migrate legacy jobs and requeue rate limits safely

---------

Co-authored-by: Vincent Koc <vincentkoc@ieee.org>
This commit is contained in:
MrBrain 2026-04-22 13:09:10 +08:00 committed by GitHub
parent 2f07416702
commit ad4c897371
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1449 additions and 3 deletions

View File

@ -8,6 +8,7 @@ All notable changes to `discrawl` will be documented in this file.
- `messages` and `mentions` now use composite read-path indexes so larger archives spend less time sorting/filtering common guild, channel, and author queries
- normalized message text is now sanitized before it reaches SQLite and FTS5, repairing malformed UTF-8 and stripping invisible/control-character noise that can poison search content
- local embedding providers now support OpenAI-compatible endpoints, Ollama, and llama.cpp, and `doctor` can probe the configured provider before you queue vectors
- `embed` now drains the queued embedding backlog in bounded batches, requeues safely on provider throttling, and drops stale stored vectors when messages no longer have embeddable content
## 0.3.0 - 2026-04-21

View File

@ -168,6 +168,61 @@ func (r *runtime) runStatus(args []string) error {
return r.print(status)
}
// runEmbed implements the `embed` command. It parses the command flags,
// builds the configured embedding provider, optionally requeues every message
// when --rebuild is set, and then drains the pending embedding backlog.
//
// --limit and --batch-size must both be positive and embeddings must be
// enabled in the configuration; violations are reported through usageErr.
// A provider that cannot be constructed is reported through configErr, and
// store errors are returned unchanged.
func (r *runtime) runEmbed(args []string) error {
	fs := flag.NewFlagSet("embed", flag.ContinueOnError)
	fs.SetOutput(io.Discard)
	limit := fs.Int("limit", store.DefaultEmbedLimit(), "")
	batchSize := fs.Int("batch-size", r.cfg.Search.Embeddings.BatchSize, "")
	rebuild := fs.Bool("rebuild", false, "")
	if err := fs.Parse(args); err != nil {
		return usageErr(err)
	}
	if fs.NArg() != 0 {
		return usageErr(fmt.Errorf("embed takes no positional arguments"))
	}
	if *limit <= 0 {
		return usageErr(fmt.Errorf("--limit must be positive"))
	}
	if *batchSize <= 0 {
		return usageErr(fmt.Errorf("--batch-size must be positive"))
	}
	if !r.cfg.Search.Embeddings.Enabled {
		return usageErr(fmt.Errorf("embeddings are disabled in config"))
	}

	// Use the injected factory when one is set; otherwise build the real provider.
	providerFactory := r.newEmbed
	if providerFactory == nil {
		providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
			return embed.NewProvider(cfg)
		}
	}
	provider, err := providerFactory(r.cfg.Search.Embeddings)
	if err != nil {
		return configErr(err)
	}

	opts := store.EmbeddingDrainOptions{
		Provider:      r.cfg.Search.Embeddings.Provider,
		Model:         r.cfg.Search.Embeddings.Model,
		InputVersion:  store.EmbeddingInputVersion,
		Limit:         *limit,
		BatchSize:     *batchSize,
		MaxInputChars: r.cfg.Search.Embeddings.MaxInputChars,
		Now:           r.now,
	}

	requeued := 0
	if *rebuild {
		requeued, err = r.store.RequeueAllEmbeddingJobs(r.ctx, opts)
		if err != nil {
			return err
		}
	}
	stats, err := r.store.DrainEmbeddingJobs(r.ctx, provider, opts)
	if err != nil {
		return err
	}
	// The drain already counts jobs it put back in the queue after a provider
	// rate limit. Add the rebuild count to it instead of overwriting it, so a
	// throttled run without --rebuild does not report zero requeued jobs.
	stats.Requeued += requeued
	return r.print(stats)
}
func (r *runtime) runDoctor(args []string) error {
if len(args) != 0 {
return usageErr(fmt.Errorf("doctor takes no arguments"))

View File

@ -13,6 +13,7 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/steipete/discrawl/internal/config"
"github.com/steipete/discrawl/internal/discord"
"github.com/steipete/discrawl/internal/embed"
"github.com/steipete/discrawl/internal/share"
"github.com/steipete/discrawl/internal/store"
"github.com/steipete/discrawl/internal/syncer"
@ -96,6 +97,7 @@ type runtime struct {
openStore func(context.Context, string) (*store.Store, error)
newDiscord func(config.Config) (discordClient, error)
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
now func() time.Time
}
@ -130,6 +132,8 @@ func (r *runtime) dispatch(rest []string) error {
return r.withServicesAuto(hasBoolFlag(rest[1:], "--sync"), true, func() error { return r.runMessages(rest[1:]) })
case "mentions":
return r.withServices(false, func() error { return r.runMentions(rest[1:]) })
case "embed":
return r.withServices(false, func() error { return r.runEmbed(rest[1:]) })
case "sql":
return r.withServices(false, func() error { return r.runSQL(rest[1:]) })
case "members":

View File

@ -3,6 +3,7 @@ package cli
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
@ -489,6 +490,70 @@ func runGit(t *testing.T, dir string, args ...string) {
require.NoError(t, err, string(out))
}
// TestEmbedCommandDrainsBoundedBacklog queues two messages, runs
// `embed --limit 1` against a stub embeddings endpoint, and checks that one
// vector is stored while one job stays pending. It then reruns the command
// with --rebuild and expects both messages to be reported as requeued.
func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
	ctx := context.Background()
	workDir := t.TempDir()
	configPath := filepath.Join(workDir, "config.toml")
	databasePath := filepath.Join(workDir, "discrawl.db")

	// Stub provider: insists on single-input requests and answers with one fixed vector.
	handler := func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, "/embeddings", r.URL.Path)
		var payload struct {
			Input []string `json:"input"`
		}
		require.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
		require.Len(t, payload.Input, 1)
		_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1,2]}]}`))
	}
	stub := httptest.NewServer(http.HandlerFunc(handler))
	defer stub.Close()

	cfg := config.Default()
	cfg.DBPath = databasePath
	embeddings := &cfg.Search.Embeddings
	embeddings.Enabled = true
	embeddings.Provider = "openai_compatible"
	embeddings.Model = "local-model"
	embeddings.BaseURL = stub.URL
	embeddings.APIKeyEnv = ""
	require.NoError(t, config.Write(configPath, cfg))

	// Seed two messages and enqueue an embedding job for each.
	seeded, err := store.Open(ctx, databasePath)
	require.NoError(t, err)
	for _, messageID := range []string{"m1", "m2"} {
		record := store.MessageRecord{
			ID:                messageID,
			GuildID:           "g1",
			ChannelID:         "c1",
			MessageType:       0,
			CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
			Content:           "hello",
			NormalizedContent: "hello",
			RawJSON:           `{}`,
		}
		require.NoError(t, seeded.UpsertMessageWithOptions(ctx, record, store.WriteOptions{EnqueueEmbedding: true}))
	}
	require.NoError(t, seeded.Close())

	var stdout bytes.Buffer
	runEmbedCommand := func(extra ...string) {
		stdout.Reset()
		argv := append([]string{"--config", configPath, "embed"}, extra...)
		require.NoError(t, Run(ctx, argv, &stdout, &bytes.Buffer{}))
	}

	// First run: the limit of one leaves the second job in the backlog.
	runEmbedCommand("--limit", "1")
	for _, want := range []string{"processed=1", "succeeded=1", "remaining_backlog=1", "provider=openai_compatible"} {
		require.Contains(t, stdout.String(), want)
	}

	reopened, err := store.Open(ctx, databasePath)
	require.NoError(t, err)
	_, countRows, err := reopened.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
	require.NoError(t, err)
	require.Equal(t, "1", countRows[0][0])
	require.NoError(t, reopened.Close())

	// Second run: --rebuild requeues both messages before draining one of them.
	runEmbedCommand("--rebuild", "--limit", "1")
	for _, want := range []string{"processed=1", "succeeded=1", "remaining_backlog=1", "requeued=2"} {
		require.Contains(t, stdout.String(), want)
	}
}
type fakeDiscordClient struct {
guilds []*discordgo.UserGuild
self *discordgo.User

View File

@ -80,6 +80,7 @@ Commands:
search
messages
mentions
embed
sql
members
channels
@ -108,6 +109,21 @@ func printHuman(w io.Writer, value any) error {
v.DBPath, v.GuildCount, v.ChannelCount, v.ThreadCount, v.MessageCount, v.MemberCount, v.EmbeddingBacklog,
formatTime(v.LastSyncAt), formatTime(v.LastTailEventAt))
return err
case store.EmbeddingDrainStats:
_, err := fmt.Fprintf(w, "processed=%d\nsucceeded=%d\nfailed=%d\nskipped=%d\nremaining_backlog=%d\nprovider=%s\nmodel=%s\ninput_version=%s\n",
v.Processed, v.Succeeded, v.Failed, v.Skipped, v.RemainingBacklog, v.Provider, v.Model, v.InputVersion)
if err != nil {
return err
}
if v.Requeued > 0 {
if _, err := fmt.Fprintf(w, "requeued=%d\n", v.Requeued); err != nil {
return err
}
}
if v.RateLimited {
_, err = fmt.Fprintln(w, "rate_limited=true")
}
return err
case []store.SearchResult:
for _, row := range v {
if _, err := fmt.Fprintf(w, "[%s/%s] %s %s\n%s\n\n", row.GuildID, row.ChannelName, row.AuthorName, formatTime(row.CreatedAt), row.Content); err != nil {

View File

@ -82,7 +82,7 @@ func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string,
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return fmt.Errorf("embedding request failed with HTTP %d: %s", resp.StatusCode, string(msg))
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
}
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("decode embedding response: %w", err)

View File

@ -41,6 +41,20 @@ type EmbeddingBatch struct {
Vectors [][]float32
}
// HTTPError carries the status code and response body of a failed embedding
// HTTP request so callers can inspect the failure with errors.As.
type HTTPError struct {
	// StatusCode is the HTTP status code of the response.
	StatusCode int
	// Body is the response body text included in the error message.
	Body string
}

// Error implements the error interface, rendering the status code followed by
// the response body.
func (e *HTTPError) Error() string {
	const format = "embedding request failed with HTTP %d: %s"
	status, body := e.StatusCode, e.Body
	return fmt.Sprintf(format, status, body)
}
// IsRateLimitError reports whether err, or any error it wraps, is an
// *HTTPError with status 429 Too Many Requests.
func IsRateLimitError(err error) bool {
	var target *HTTPError
	if !errors.As(err, &target) {
		return false
	}
	return target.StatusCode == http.StatusTooManyRequests
}
type CheckResult struct {
Provider string
Model string

View File

@ -172,6 +172,27 @@ func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
require.False(t, result.Probed)
}
func TestProviderExposesRateLimitErrors(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "rate limited", http.StatusTooManyRequests)
}))
defer server.Close()
provider, err := NewProvider(config.EmbeddingsConfig{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)
_, err = provider.Embed(context.Background(), []string{"one"})
require.ErrorContains(t, err, "HTTP 429")
require.True(t, IsRateLimitError(err))
}
func TestProviderRejectsInvalidResponses(t *testing.T) {
t.Parallel()

View File

@ -0,0 +1,549 @@
package store
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"strings"
"time"
"github.com/steipete/discrawl/internal/embed"
)
// Defaults and limits for the embedding job queue.
const (
// EmbeddingInputVersion is the input version tag recorded with jobs and
// stored vectors; drain options fall back to it when none is given.
EmbeddingInputVersion = "message_normalized_v1"
// defaultEmbedLimit is the number of pending jobs loaded per drain when the
// caller gives no positive limit.
defaultEmbedLimit = 1000
// maxEmbeddingAttempts is the attempt count at which a failing job is moved
// to the 'failed' state instead of staying 'pending'.
maxEmbeddingAttempts = 3
// maxStoredErrorChars caps, in runes, the error text saved in last_error.
maxStoredErrorChars = 500
// embeddingLockTimeout is how old a job lock must be before the job may be
// picked up and claimed again.
embeddingLockTimeout = 15 * time.Minute
)
// EmbeddingDrainOptions configures one drain (or requeue) of the embedding
// job queue. Values are normalized by normalizeEmbeddingDrainOptions before use.
type EmbeddingDrainOptions struct {
// Provider is the provider name written to jobs and stored vectors; it is
// trimmed and lowercased during normalization.
Provider string
// Model is the model name written to jobs and stored vectors (trimmed).
Model string
// InputVersion tags the input format; blank falls back to EmbeddingInputVersion.
InputVersion string
// Limit is the maximum number of pending jobs loaded in one drain;
// non-positive values fall back to the package default.
Limit int
// BatchSize is the number of jobs sent to the provider per request;
// non-positive values fall back to embed.DefaultBatchSize, and it is
// clamped so it never exceeds Limit.
BatchSize int
// MaxInputChars caps each input text in runes; non-positive values fall
// back to embed.DefaultMaxInputChars.
MaxInputChars int
// Now supplies the clock used for lock and update timestamps; nil falls
// back to the current UTC time.
Now func() time.Time
}
// EmbeddingDrainStats summarizes the outcome of one drain of the embedding
// job queue.
type EmbeddingDrainStats struct {
// Processed counts jobs that were stored, failed, or skipped in this run.
Processed int `json:"processed"`
// Succeeded counts jobs whose vectors were stored.
Succeeded int `json:"succeeded"`
// Failed counts jobs in batches whose provider call or response validation
// failed (other than rate limiting).
Failed int `json:"failed"`
// Skipped counts jobs marked done without embedding because the message
// has only blank normalized content.
Skipped int `json:"skipped"`
// Requeued counts jobs returned to the queue. The drain adds jobs it put
// back after a rate limit, and the CLI also reports the --rebuild requeue
// count through this field.
Requeued int `json:"requeued,omitempty"`
// RemainingBacklog is the number of pending jobs counted at the end of the run.
RemainingBacklog int `json:"remaining_backlog"`
// Provider, Model and InputVersion echo the normalized drain options.
Provider string `json:"provider"`
Model string `json:"model"`
InputVersion string `json:"input_version"`
// RateLimited is set when a batch hit a provider rate limit, which stops the drain.
RateLimited bool `json:"rate_limited,omitempty"`
}
// embeddingJob is one pending row of embedding_jobs joined with the text of
// the message it refers to.
type embeddingJob struct {
// MessageID is the job's message_id.
MessageID string
// NormalizedContent is the message's normalized_content, the text that gets embedded.
NormalizedContent string
// Attempts is the attempt counter stored on the job.
Attempts int
// Provider, Model and InputVersion are the identity recorded on the job;
// all three are empty for rows that have never been assigned one.
Provider string
Model string
InputVersion string
}
// DefaultEmbedLimit returns the default maximum number of queued jobs a drain
// loads when the caller does not choose a limit.
func DefaultEmbedLimit() int {
return defaultEmbedLimit
}
// DrainEmbeddingJobs loads up to opts.Limit pending embedding jobs and
// processes them in batches of opts.BatchSize through provider.
//
// Jobs recorded under a different provider/model/input version are re-tagged
// with the current identity first; their attempt counter is reset unless the
// job had no identity at all. Jobs whose message has only blank normalized
// content are marked done without calling the provider and counted as
// skipped. The drain stops at the first batch that reports a rate limit.
//
// The returned stats are filled in as far as the run got, including when an
// error is returned. A nil provider is an error.
func (s *Store) DrainEmbeddingJobs(ctx context.Context, provider embed.Provider, opts EmbeddingDrainOptions) (EmbeddingDrainStats, error) {
	opts = normalizeEmbeddingDrainOptions(opts)
	stats := EmbeddingDrainStats{
		Provider:     opts.Provider,
		Model:        opts.Model,
		InputVersion: opts.InputVersion,
	}
	if provider == nil {
		return stats, errors.New("embedding provider is nil")
	}

	// Locks older than this cutoff are treated as abandoned.
	cutoff := opts.Now().Add(-embeddingLockTimeout).Format(timeLayout)
	queued, err := s.pendingEmbeddingJobs(ctx, opts.Limit, cutoff)
	if err != nil {
		return stats, err
	}

	var pending []embeddingJob
	// sendPending hands the accumulated jobs to the provider and empties the
	// accumulator, keeping its backing array for the next batch.
	sendPending := func() error {
		if len(pending) == 0 {
			return nil
		}
		limited, batchErr := s.processEmbeddingBatch(ctx, provider, opts, pending, &stats)
		pending = pending[:0]
		if batchErr != nil {
			return batchErr
		}
		stats.RateLimited = stats.RateLimited || limited
		return nil
	}

	for i := range queued {
		job := queued[i]

		if !sameEmbeddingIdentity(job, opts) {
			// A job that never had an identity keeps its attempt counter.
			clearAttempts := !emptyEmbeddingIdentity(job)
			if err := s.resetEmbeddingJobIdentity(ctx, job.MessageID, opts, clearAttempts); err != nil {
				return stats, err
			}
			job.Provider, job.Model, job.InputVersion = opts.Provider, opts.Model, opts.InputVersion
			if clearAttempts {
				job.Attempts = 0
			}
		}

		if strings.TrimSpace(job.NormalizedContent) == "" {
			// Nothing to embed: close the job without a provider call.
			if err := s.markEmbeddingJobsDone(ctx, opts, []embeddingJob{job}); err != nil {
				return stats, err
			}
			stats.Processed++
			stats.Skipped++
			continue
		}

		pending = append(pending, job)
		if len(pending) < opts.BatchSize {
			continue
		}
		if err := sendPending(); err != nil {
			return stats, err
		}
		if stats.RateLimited {
			break
		}
	}

	// Send the final partial batch unless the provider is throttling us.
	if !stats.RateLimited {
		if err := sendPending(); err != nil {
			return stats, err
		}
	}

	backlog, err := s.EmbeddingBacklog(ctx)
	stats.RemainingBacklog = backlog
	if err != nil {
		return stats, err
	}
	return stats, nil
}
// normalizeEmbeddingDrainOptions returns a copy of opts with identity strings
// trimmed (and the provider lowercased), non-positive numeric settings
// replaced by their defaults, the batch size capped at the limit, and a UTC
// wall clock installed when no clock was supplied.
func normalizeEmbeddingDrainOptions(opts EmbeddingDrainOptions) EmbeddingDrainOptions {
	normalized := opts
	normalized.Provider = strings.ToLower(strings.TrimSpace(opts.Provider))
	normalized.Model = strings.TrimSpace(opts.Model)

	if version := strings.TrimSpace(opts.InputVersion); version != "" {
		normalized.InputVersion = version
	} else {
		normalized.InputVersion = EmbeddingInputVersion
	}

	if normalized.Limit <= 0 {
		normalized.Limit = defaultEmbedLimit
	}
	if normalized.BatchSize <= 0 {
		normalized.BatchSize = embed.DefaultBatchSize
	}
	// A batch never needs to be larger than the total number of jobs loaded.
	if normalized.BatchSize > normalized.Limit {
		normalized.BatchSize = normalized.Limit
	}
	if normalized.MaxInputChars <= 0 {
		normalized.MaxInputChars = embed.DefaultMaxInputChars
	}
	if normalized.Now == nil {
		normalized.Now = func() time.Time { return time.Now().UTC() }
	}
	return normalized
}
// sameEmbeddingIdentity reports whether the job's recorded provider, model,
// and input version all match the drain options.
func sameEmbeddingIdentity(job embeddingJob, opts EmbeddingDrainOptions) bool {
	switch {
	case job.Provider != opts.Provider:
		return false
	case job.Model != opts.Model:
		return false
	default:
		return job.InputVersion == opts.InputVersion
	}
}
// emptyEmbeddingIdentity reports whether the job has no provider, model, or
// input version recorded at all.
func emptyEmbeddingIdentity(job embeddingJob) bool {
	for _, field := range [...]string{job.Provider, job.Model, job.InputVersion} {
		if field != "" {
			return false
		}
	}
	return true
}
// pendingEmbeddingJobs returns up to limit pending jobs that are unlocked or
// whose lock is older than staleBefore, joined with their message text and
// ordered by updated_at and then message id.
func (s *Store) pendingEmbeddingJobs(ctx context.Context, limit int, staleBefore string) ([]embeddingJob, error) {
	const query = `
select
j.message_id,
m.normalized_content,
j.attempts,
j.provider,
j.model,
j.input_version
from embedding_jobs j
join messages m on m.id = j.message_id
where j.state = 'pending'
and (j.locked_at is null or j.locked_at = '' or j.locked_at < ?)
order by j.updated_at, j.message_id
limit ?
`
	cursor, err := s.db.QueryContext(ctx, query, staleBefore, limit)
	if err != nil {
		return nil, err
	}
	defer func() { _ = cursor.Close() }()

	var found []embeddingJob
	for cursor.Next() {
		var row embeddingJob
		scanErr := cursor.Scan(
			&row.MessageID,
			&row.NormalizedContent,
			&row.Attempts,
			&row.Provider,
			&row.Model,
			&row.InputVersion,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		found = append(found, row)
	}
	return found, cursor.Err()
}
// resetEmbeddingJobIdentity re-tags one job with the provider, model, and
// input version from opts, clearing its last error and lock. When
// resetAttempts is true the attempt counter is also set back to zero.
func (s *Store) resetEmbeddingJobIdentity(ctx context.Context, messageID string, opts EmbeddingDrainOptions, resetAttempts bool) error {
	const retagAndZeroAttempts = `
update embedding_jobs
set provider = ?,
model = ?,
input_version = ?,
attempts = 0,
last_error = '',
locked_at = null,
updated_at = ?
where message_id = ?
`
	const retagKeepAttempts = `
update embedding_jobs
set provider = ?,
model = ?,
input_version = ?,
last_error = '',
locked_at = null,
updated_at = ?
where message_id = ?
`
	// Both statements take the same arguments; only the attempts column differs.
	statement := retagKeepAttempts
	if resetAttempts {
		statement = retagAndZeroAttempts
	}
	_, err := s.db.ExecContext(ctx, statement, opts.Provider, opts.Model, opts.InputVersion, opts.Now().Format(timeLayout), messageID)
	return err
}
// processEmbeddingBatch claims the given jobs, sends their capped text to the
// provider, and records the outcome in the store and in stats.
//
// Only jobs that could be locked are processed; if none can be claimed the
// call is a no-op. Outcomes:
//   - rate limit from the provider: jobs go back to pending, stats.Requeued
//     grows, and the first return value is true;
//   - any other provider error, or a response that fails validation: jobs are
//     marked failed (or left pending for retry) and counted in
//     stats.Processed and stats.Failed;
//   - success: vectors are stored and jobs counted in stats.Processed and
//     stats.Succeeded.
//
// The returned error is non-nil only when the store itself fails.
func (s *Store) processEmbeddingBatch(ctx context.Context, provider embed.Provider, opts EmbeddingDrainOptions, jobs []embeddingJob, stats *EmbeddingDrainStats) (bool, error) {
	now := opts.Now()
	lockedAt := now.Format(timeLayout)
	staleBefore := now.Add(-embeddingLockTimeout).Format(timeLayout)
	claimed, err := s.lockEmbeddingJobs(ctx, jobs, lockedAt, staleBefore)
	if err != nil {
		return false, err
	}
	if len(claimed) == 0 {
		return false, nil
	}
	jobs = claimed

	inputs := make([]string, 0, len(jobs))
	for _, job := range jobs {
		inputs = append(inputs, capRunes(job.NormalizedContent, opts.MaxInputChars))
	}

	// recordFailure persists a non-rate-limit failure for the whole batch and
	// updates the counters once the store write succeeded.
	recordFailure := func(cause error) error {
		if markErr := s.markEmbeddingJobsFailed(ctx, opts, jobs, cause); markErr != nil {
			return markErr
		}
		stats.Processed += len(jobs)
		stats.Failed += len(jobs)
		return nil
	}

	batch, err := provider.Embed(ctx, inputs)
	if err != nil {
		if embed.IsRateLimitError(err) {
			if markErr := s.markEmbeddingJobsRateLimited(ctx, opts, jobs, err); markErr != nil {
				return false, markErr
			}
			stats.Requeued += len(jobs)
			return true, nil
		}
		// The rate-limit case returned above, so this path is never rate limited.
		return false, recordFailure(err)
	}

	dimensions, err := validateEmbeddingBatch(batch, len(jobs))
	if err != nil {
		return false, recordFailure(err)
	}
	if err := s.storeEmbeddingBatch(ctx, opts, jobs, batch.Vectors, dimensions); err != nil {
		return false, err
	}
	stats.Processed += len(jobs)
	stats.Succeeded += len(jobs)
	return false, nil
}
// lockEmbeddingJobs tries to claim each job by stamping locked_at, inside a
// single transaction. A job is claimed only if it is still pending and is
// either unlocked or locked before staleBefore. It returns the jobs that were
// actually claimed, in input order.
func (s *Store) lockEmbeddingJobs(ctx context.Context, jobs []embeddingJob, lockedAt, staleBefore string) ([]embeddingJob, error) {
	const claimJob = `
update embedding_jobs
set locked_at = ?, updated_at = ?
where message_id = ?
and state = 'pending'
and (locked_at is null or locked_at = '' or locked_at < ?)
`
	if len(jobs) == 0 {
		return nil, nil
	}
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer rollback(tx)

	won := make([]embeddingJob, 0, len(jobs))
	for i := range jobs {
		res, err := tx.ExecContext(ctx, claimJob, lockedAt, lockedAt, jobs[i].MessageID, staleBefore)
		if err != nil {
			return nil, err
		}
		changed, err := res.RowsAffected()
		if err != nil {
			return nil, err
		}
		// Zero rows means the job no longer matches the claim conditions.
		if changed != 1 {
			continue
		}
		won = append(won, jobs[i])
	}
	if err := tx.Commit(); err != nil {
		return nil, err
	}
	return won, nil
}
// validateEmbeddingBatch checks that the provider returned exactly expected
// vectors, that none is empty, and that all share one dimensionality. The
// dimensionality is batch.Dimensions when that is non-zero, otherwise the
// length of the first vector. It returns the agreed dimensionality.
func validateEmbeddingBatch(batch embed.EmbeddingBatch, expected int) (int, error) {
	if got := len(batch.Vectors); got != expected {
		return 0, fmt.Errorf("embedding provider returned %d vectors for %d inputs", got, expected)
	}
	want := batch.Dimensions
	for i := range batch.Vectors {
		size := len(batch.Vectors[i])
		switch {
		case size == 0:
			return 0, errors.New("embedding provider returned an empty vector")
		case want == 0:
			// No declared dimensionality: adopt the first vector's length.
			want = size
		case size != want:
			return 0, fmt.Errorf("embedding provider dimensions mismatch: got %d want %d", size, want)
		}
	}
	return want, nil
}
// storeEmbeddingBatch writes one vector per job and marks each job done, all
// in a single transaction. vectors[i] belongs to jobs[i]. An existing vector
// for the same message/provider/model/input version is overwritten.
func (s *Store) storeEmbeddingBatch(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, vectors [][]float32, dimensions int) error {
	const upsertVector = `
insert into message_embeddings(
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
) values(?, ?, ?, ?, ?, ?, ?)
on conflict(message_id, provider, model, input_version) do update set
dimensions = excluded.dimensions,
embedding_blob = excluded.embedding_blob,
embedded_at = excluded.embedded_at
`
	const finishJob = `
update embedding_jobs
set state = 'done',
attempts = 0,
provider = ?,
model = ?,
input_version = ?,
last_error = '',
locked_at = null,
updated_at = ?
where message_id = ?
`
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer rollback(tx)

	stamp := opts.Now().Format(timeLayout)
	for i := range jobs {
		messageID := jobs[i].MessageID
		encoded, err := EncodeEmbeddingVector(vectors[i])
		if err != nil {
			return err
		}
		_, err = tx.ExecContext(ctx, upsertVector, messageID, opts.Provider, opts.Model, opts.InputVersion, dimensions, encoded, stamp)
		if err != nil {
			return err
		}
		_, err = tx.ExecContext(ctx, finishJob, opts.Provider, opts.Model, opts.InputVersion, stamp, messageID)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// markEmbeddingJobsDone closes jobs without producing a vector. For each job
// it deletes every stored embedding of that message and sets the job to
// 'done' under the current identity, in one transaction.
func (s *Store) markEmbeddingJobsDone(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob) error {
	const dropVectors = `delete from message_embeddings where message_id = ?`
	const finishJob = `
update embedding_jobs
set state = 'done',
provider = ?,
model = ?,
input_version = ?,
last_error = '',
locked_at = null,
updated_at = ?
where message_id = ?
`
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer rollback(tx)

	stamp := opts.Now().Format(timeLayout)
	for i := range jobs {
		messageID := jobs[i].MessageID
		// Remove stale vectors first so the message no longer has stored embeddings.
		if _, err := tx.ExecContext(ctx, dropVectors, messageID); err != nil {
			return err
		}
		_, err := tx.ExecContext(ctx, finishJob, opts.Provider, opts.Model, opts.InputVersion, stamp, messageID)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// markEmbeddingJobsRateLimited returns jobs to the pending state after a
// provider rate limit. It records the (trimmed) cause and releases the lock
// but leaves the attempt counter untouched.
func (s *Store) markEmbeddingJobsRateLimited(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, cause error) error {
	const requeueJob = `
update embedding_jobs
set state = 'pending',
provider = ?,
model = ?,
input_version = ?,
last_error = ?,
locked_at = null,
updated_at = ?
where message_id = ?
`
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer rollback(tx)

	stamp := opts.Now().Format(timeLayout)
	reason := trimStoredError(cause)
	for i := range jobs {
		_, err := tx.ExecContext(ctx, requeueJob, opts.Provider, opts.Model, opts.InputVersion, reason, stamp, jobs[i].MessageID)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// markEmbeddingJobsFailed records a failed attempt for every job in one
// transaction. Each job's attempt counter is incremented; the job stays
// 'pending' for a retry until the counter reaches maxEmbeddingAttempts, at
// which point it becomes 'failed'. The trimmed cause is stored as last_error.
func (s *Store) markEmbeddingJobsFailed(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, cause error) error {
	const recordFailure = `
update embedding_jobs
set state = ?,
attempts = ?,
provider = ?,
model = ?,
input_version = ?,
last_error = ?,
locked_at = null,
updated_at = ?
where message_id = ?
`
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer rollback(tx)

	stamp := opts.Now().Format(timeLayout)
	reason := trimStoredError(cause)
	for i := range jobs {
		tries := jobs[i].Attempts + 1
		nextState := "failed"
		if tries < maxEmbeddingAttempts {
			nextState = "pending"
		}
		_, err := tx.ExecContext(ctx, recordFailure, nextState, tries, opts.Provider, opts.Model, opts.InputVersion, reason, stamp, jobs[i].MessageID)
		if err != nil {
			return err
		}
	}
	return tx.Commit()
}
// trimStoredError turns err into the text saved in last_error: the message
// with surrounding whitespace removed, cut to at most maxStoredErrorChars
// runes. A nil error yields the empty string.
func trimStoredError(err error) string {
	if err == nil {
		return ""
	}
	text := strings.TrimSpace(err.Error())
	if chars := []rune(text); len(chars) > maxStoredErrorChars {
		return string(chars[:maxStoredErrorChars])
	}
	return text
}
// capRunes returns value cut to at most maxChars runes (not bytes). A
// non-positive maxChars disables the cap and returns value unchanged.
func capRunes(value string, maxChars int) string {
	if maxChars > 0 {
		if chars := []rune(value); len(chars) > maxChars {
			return string(chars[:maxChars])
		}
	}
	return value
}
// EncodeEmbeddingVector serializes vector as consecutive little-endian
// float32 values, four bytes per element, in element order.
func EncodeEmbeddingVector(vector []float32) ([]byte, error) {
	out := bytes.NewBuffer(make([]byte, 0, 4*len(vector)))
	// binary.Write encodes a slice of fixed-size values element by element.
	if err := binary.Write(out, binary.LittleEndian, vector); err != nil {
		return nil, fmt.Errorf("encode embedding vector: %w", err)
	}
	return out.Bytes(), nil
}
// DecodeEmbeddingVector parses blob as consecutive little-endian float32
// values. It returns an error when the blob length is not a multiple of four
// bytes.
func DecodeEmbeddingVector(blob []byte) ([]float32, error) {
	const width = 4 // bytes per float32
	if rem := len(blob) % width; rem != 0 {
		return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob))
	}
	values := make([]float32, len(blob)/width)
	// binary.Read fills a slice of fixed-size values element by element.
	if err := binary.Read(bytes.NewReader(blob), binary.LittleEndian, values); err != nil {
		return nil, fmt.Errorf("decode embedding vector: %w", err)
	}
	return values, nil
}
// EmbeddingBacklog returns the number of embedding jobs currently in the
// 'pending' state.
func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) {
	const query = `select count(*) from embedding_jobs where state = 'pending'`
	pending := 0
	row := s.db.QueryRowContext(ctx, query)
	err := row.Scan(&pending)
	return pending, err
}
// RequeueAllEmbeddingJobs puts every message back on the embedding queue
// under the identity in opts. In one transaction it creates a job for each
// message that has none, then resets every job belonging to an existing
// message to 'pending' with zero attempts, no error, and no lock. It returns
// the number of job rows touched by that reset.
func (s *Store) RequeueAllEmbeddingJobs(ctx context.Context, opts EmbeddingDrainOptions) (int, error) {
	const createMissingJobs = `
insert or ignore into embedding_jobs(
message_id, state, attempts, provider, model, input_version, last_error, locked_at, updated_at
)
select id, 'pending', 0, ?, ?, ?, '', null, ?
from messages
`
	const resetJobs = `
update embedding_jobs
set state = 'pending',
attempts = 0,
provider = ?,
model = ?,
input_version = ?,
last_error = '',
locked_at = null,
updated_at = ?
where message_id in (select id from messages)
`
	opts = normalizeEmbeddingDrainOptions(opts)
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, err
	}
	defer rollback(tx)

	stamp := opts.Now().Format(timeLayout)
	_, err = tx.ExecContext(ctx, createMissingJobs, opts.Provider, opts.Model, opts.InputVersion, stamp)
	if err != nil {
		return 0, err
	}
	reset, err := tx.ExecContext(ctx, resetJobs, opts.Provider, opts.Model, opts.InputVersion, stamp)
	if err != nil {
		return 0, err
	}
	if err := tx.Commit(); err != nil {
		return 0, err
	}
	touched, err := reset.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(touched), nil
}

View File

@ -211,6 +211,7 @@ func (s *Store) migrate(ctx context.Context) error {
if err := s.setSchemaVersion(ctx, storeSchemaVersion); err != nil {
return err
}
currentVersion = storeSchemaVersion
}
if version, err := s.schemaVersion(ctx); err != nil {
return err
@ -368,8 +369,23 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error {
message_id text primary key,
state text not null,
attempts integer not null default 0,
provider text not null default '',
model text not null default '',
input_version text not null default '',
last_error text not null default '',
locked_at text,
updated_at text not null
);`,
`create table if not exists message_embeddings (
message_id text not null,
provider text not null,
model text not null,
input_version text not null,
dimensions integer not null,
embedding_blob blob not null,
embedded_at text not null,
primary key (message_id, provider, model, input_version)
);`,
`create virtual table if not exists message_fts using fts5(
message_id unindexed,
guild_id unindexed,
@ -402,6 +418,7 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error {
`create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`,
`create index if not exists idx_mentions_target on mention_events(target_type, target_id, event_at);`,
`create index if not exists idx_mentions_author on mention_events(author_id, event_at);`,
`create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`,
}
for _, stmt := range stmts {
if _, err := tx.ExecContext(ctx, stmt); err != nil {
@ -417,12 +434,43 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error {
return err
}
defer rollback(tx)
for _, column := range []struct {
name string
sql string
}{
{"provider", `alter table embedding_jobs add column provider text not null default ''`},
{"model", `alter table embedding_jobs add column model text not null default ''`},
{"input_version", `alter table embedding_jobs add column input_version text not null default ''`},
{"last_error", `alter table embedding_jobs add column last_error text not null default ''`},
{"locked_at", `alter table embedding_jobs add column locked_at text`},
} {
ok, err := columnExists(ctx, tx, "embedding_jobs", column.name)
if err != nil {
return err
}
if !ok {
if _, err := tx.ExecContext(ctx, column.sql); err != nil {
return fmt.Errorf("add embedding_jobs.%s: %w", column.name, err)
}
}
}
stmts := []string{
`create table if not exists message_embeddings (
message_id text not null,
provider text not null,
model text not null,
input_version text not null,
dimensions integer not null,
embedding_blob blob not null,
embedded_at text not null,
primary key (message_id, provider, model, input_version)
);`,
`create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`,
`create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`,
`create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`,
`create index if not exists idx_mentions_guild_event on mention_events(guild_id, event_at, event_id);`,
`create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`,
`create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`,
}
for _, stmt := range stmts {
if _, err := tx.ExecContext(ctx, stmt); err != nil {
@ -432,6 +480,27 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error {
return tx.Commit()
}
// columnExists reports whether table has a column named column, by scanning
// the rows of `pragma table_info`. The table name is concatenated into the
// statement, so callers must pass only trusted identifiers.
func columnExists(ctx context.Context, tx *sql.Tx, table, column string) (bool, error) {
	pragma := `pragma table_info(` + table + `)`
	info, err := tx.QueryContext(ctx, pragma)
	if err != nil {
		return false, fmt.Errorf("inspect %s columns: %w", table, err)
	}
	defer func() { _ = info.Close() }()

	for info.Next() {
		// One table_info row: cid, name, type, notnull, dflt_value, pk.
		var (
			cid, notNull, pk int
			name, typ        string
			defaultValue     sql.NullString
		)
		if err := info.Scan(&cid, &name, &typ, &notNull, &defaultValue, &pk); err != nil {
			return false, err
		}
		if name != column {
			continue
		}
		return true, nil
	}
	return false, info.Err()
}
func (s *Store) ensureFTSRowIDs(ctx context.Context) error {
var version sql.NullString
err := s.db.QueryRowContext(ctx, `

View File

@ -325,6 +325,88 @@ func TestOpenBackfillsMissingSchemaVersion(t *testing.T) {
require.Equal(t, storeSchemaVersion, version)
}
// TestOpenMigratesSchemaV1ToV2 opens a database built by createV1Schema that
// holds one message and one legacy embedding job, and checks that Open
// reports schema version 2, gives the job empty values in the new columns,
// and creates an empty message_embeddings table.
func TestOpenMigratesSchemaV1ToV2(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	path := filepath.Join(t.TempDir(), "discrawl.db")
	require.NoError(t, createV1Schema(ctx, path))

	// Seed the legacy database directly, bypassing the store.
	legacy, err := sql.Open("sqlite", path)
	require.NoError(t, err)
	seed := []string{
		`
insert into messages(
id, guild_id, channel_id, message_type, created_at, content,
normalized_content, raw_json, updated_at
) values('m1', 'g1', 'c1', 0, '2026-01-01T00:00:00Z', 'hello', 'hello', '{}', '2026-01-01T00:00:00Z')
`,
		`
insert into embedding_jobs(message_id, state, attempts, updated_at)
values('m1', 'pending', 1, '2026-01-01T00:00:00Z')
`,
	}
	for _, stmt := range seed {
		_, err = legacy.ExecContext(ctx, stmt)
		require.NoError(t, err)
	}
	require.NoError(t, legacy.Close())

	migrated, err := Open(ctx, path)
	require.NoError(t, err)
	defer func() { _ = migrated.Close() }()

	var userVersion int
	require.NoError(t, migrated.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&userVersion))
	require.Equal(t, 2, userVersion)

	_, jobRows, err := migrated.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'")
	require.NoError(t, err)
	require.Equal(t, [][]string{{"", "", "", "", ""}}, jobRows)

	_, countRows, err := migrated.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
	require.NoError(t, err)
	require.Equal(t, "0", countRows[0][0])
}
// TestOpenMigratesUnversionedV1SchemaToV2 repeats the v1-to-v2 migration
// check for a database whose user_version has been reset to 0 before Open is
// called.
func TestOpenMigratesUnversionedV1SchemaToV2(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	path := filepath.Join(t.TempDir(), "discrawl.db")
	require.NoError(t, createV1Schema(ctx, path))

	// Seed the legacy database directly, then erase its version marker.
	legacy, err := sql.Open("sqlite", path)
	require.NoError(t, err)
	seed := []string{
		`
insert into messages(
id, guild_id, channel_id, message_type, created_at, content,
normalized_content, raw_json, updated_at
) values('m1', 'g1', 'c1', 0, '2026-01-01T00:00:00Z', 'hello', 'hello', '{}', '2026-01-01T00:00:00Z')
`,
		`
insert into embedding_jobs(message_id, state, attempts, updated_at)
values('m1', 'pending', 1, '2026-01-01T00:00:00Z')
`,
		`pragma user_version = 0`,
	}
	for _, stmt := range seed {
		_, err = legacy.ExecContext(ctx, stmt)
		require.NoError(t, err)
	}
	require.NoError(t, legacy.Close())

	migrated, err := Open(ctx, path)
	require.NoError(t, err)
	defer func() { _ = migrated.Close() }()

	var userVersion int
	require.NoError(t, migrated.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&userVersion))
	require.Equal(t, 2, userVersion)

	_, jobRows, err := migrated.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'")
	require.NoError(t, err)
	require.Equal(t, [][]string{{"", "", "", "", ""}}, jobRows)

	_, countRows, err := migrated.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
	require.NoError(t, err)
	require.Equal(t, "0", countRows[0][0])
}
func TestReadOnlyQueryGuards(t *testing.T) {
t.Parallel()
@ -342,6 +424,142 @@ func TestReadOnlyQueryGuards(t *testing.T) {
require.Error(t, err)
}
// createV1Schema opens the SQLite file at path and executes, in order, the
// statements that create the version 1 tables and FTS5 virtual tables,
// finishing with `pragma user_version = 1`. It returns the first error from
// opening the database or executing a statement, and nil otherwise. The
// connection is closed before returning.
func createV1Schema(ctx context.Context, path string) error {
	conn, err := sql.Open("sqlite", path)
	if err != nil {
		return err
	}
	defer func() { _ = conn.Close() }()

	// Statement order matters only for the trailing pragma, which stamps the
	// schema version after every table exists.
	ddl := []string{
		`create table guilds (
id text primary key,
name text not null,
icon text,
raw_json text not null,
updated_at text not null
);`,
		`create table channels (
id text primary key,
guild_id text not null,
parent_id text,
kind text not null,
name text not null,
topic text,
position integer,
is_nsfw integer not null default 0,
is_archived integer not null default 0,
is_locked integer not null default 0,
is_private_thread integer not null default 0,
thread_parent_id text,
archive_timestamp text,
raw_json text not null,
updated_at text not null
);`,
		`create table members (
guild_id text not null,
user_id text not null,
username text not null,
global_name text,
display_name text,
nick text,
discriminator text,
avatar text,
bot integer not null default 0,
joined_at text,
role_ids_json text not null,
raw_json text not null,
updated_at text not null,
primary key (guild_id, user_id)
);`,
		`create table messages (
id text primary key,
guild_id text not null,
channel_id text not null,
author_id text,
message_type integer not null,
created_at text not null,
edited_at text,
deleted_at text,
content text not null,
normalized_content text not null,
reply_to_message_id text,
pinned integer not null default 0,
has_attachments integer not null default 0,
raw_json text not null,
updated_at text not null
);`,
		`create table message_events (
event_id integer primary key autoincrement,
guild_id text not null,
channel_id text not null,
message_id text not null,
event_type text not null,
event_at text not null,
payload_json text not null
);`,
		`create table message_attachments (
attachment_id text primary key,
message_id text not null,
guild_id text not null,
channel_id text not null,
author_id text,
filename text not null,
content_type text,
size integer not null default 0,
url text,
proxy_url text,
text_content text not null default '',
updated_at text not null
);`,
		`create table mention_events (
event_id integer primary key autoincrement,
message_id text not null,
guild_id text not null,
channel_id text not null,
author_id text,
target_type text not null,
target_id text not null,
target_name text not null default '',
event_at text not null
);`,
		`create table sync_state (
scope text primary key,
cursor text,
updated_at text not null
);`,
		`create table embedding_jobs (
message_id text primary key,
state text not null,
attempts integer not null default 0,
updated_at text not null
);`,
		`create virtual table message_fts using fts5(
message_id unindexed,
guild_id unindexed,
channel_id unindexed,
author_id unindexed,
author_name,
channel_name,
content
);`,
		`create virtual table member_fts using fts5(
member_key unindexed,
guild_id unindexed,
user_id unindexed,
username,
display_name,
profile_text
);`,
		`pragma user_version = 1;`,
	}
	for i := range ddl {
		_, execErr := conn.ExecContext(ctx, ddl[i])
		if execErr != nil {
			return execErr
		}
	}
	return nil
}
func TestQueryAndExec(t *testing.T) {
t.Parallel()

View File

@ -3,12 +3,15 @@ package store
import (
"context"
"database/sql"
"errors"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/steipete/discrawl/internal/embed"
)
func TestUpsertMessagesBatch(t *testing.T) {
@ -123,6 +126,387 @@ func TestUpsertMessageWithEmbeddingsQueuesJob(t *testing.T) {
require.Equal(t, "1", rows[0][0])
}
// TestUpsertMessageWithEmbeddingsQueuesExistingMessageWithoutJob stores a
// message with a plain upsert, then upserts the identical record again with
// EnqueueEmbedding set, and expects a pending embedding job with zero
// attempts to exist for it.
func TestUpsertMessageWithEmbeddingsQueuesExistingMessageWithoutJob(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	msg := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}

	// First write creates the message only; the second asks for embedding.
	err = st.UpsertMessage(ctx, msg)
	require.NoError(t, err)
	err = st.UpsertMessageWithOptions(ctx, msg, WriteOptions{EnqueueEmbedding: true})
	require.NoError(t, err)

	_, got, err := st.ReadOnlyQuery(ctx, "select state, attempts from embedding_jobs where message_id = 'm1'")
	require.NoError(t, err)
	want := [][]string{{"pending", "0"}}
	require.Equal(t, want, got)
}
// TestDrainEmbeddingJobsStoresVectorsAndSkipsEmptyInput queues one message
// with text and one whose normalized content is a single space, drains both
// with MaxInputChars set to 7, and expects: one success and one skip, the
// provider called once with the text cut to "abcdef世", the returned vector
// stored and decodable for m1, and both jobs marked done with the drain's
// provider, model, and input version.
func TestDrainEmbeddingJobsStoresVectorsAndSkipsEmptyInput(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	createdAt := time.Now().UTC().Format(time.RFC3339Nano)
	enqueue := WriteOptions{EnqueueEmbedding: true}
	withText := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         createdAt,
		Content:           "abcdef世界",
		NormalizedContent: "abcdef世界",
		RawJSON:           `{}`,
	}
	blank := MessageRecord{
		ID:                "m2",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         createdAt,
		Content:           "",
		NormalizedContent: " ",
		RawJSON:           `{}`,
	}
	require.NoError(t, st.UpsertMessageWithOptions(ctx, withText, enqueue))
	require.NoError(t, st.UpsertMessageWithOptions(ctx, blank, enqueue))

	// The fake is scripted with a single one-vector batch.
	fake := &fakeEmbeddingProvider{
		batches: []embed.EmbeddingBatch{{
			Vectors: [][]float32{{1.25, 2.5}},
		}},
	}
	drainOpts := EmbeddingDrainOptions{
		Provider:      "ollama",
		Model:         "nomic-embed-text",
		Limit:         10,
		BatchSize:     2,
		MaxInputChars: 7,
	}
	stats, err := st.DrainEmbeddingJobs(ctx, fake, drainOpts)
	require.NoError(t, err)
	require.Equal(t, 2, stats.Processed)
	require.Equal(t, 1, stats.Succeeded)
	require.Equal(t, 1, stats.Skipped)
	require.Equal(t, 0, stats.RemainingBacklog)
	require.Equal(t, [][]string{{"abcdef世"}}, fake.inputs)

	_, stored, err := st.ReadOnlyQuery(ctx, "select message_id, provider, model, input_version, dimensions from message_embeddings")
	require.NoError(t, err)
	wantStored := [][]string{{"m1", "ollama", "nomic-embed-text", EmbeddingInputVersion, "2"}}
	require.Equal(t, wantStored, stored)

	var raw []byte
	blobRow := st.DB().QueryRowContext(ctx, `select embedding_blob from message_embeddings where message_id = 'm1'`)
	require.NoError(t, blobRow.Scan(&raw))
	decoded, err := DecodeEmbeddingVector(raw)
	require.NoError(t, err)
	require.Equal(t, []float32{1.25, 2.5}, decoded)

	_, jobs, err := st.ReadOnlyQuery(ctx, "select message_id, state, provider, model, input_version from embedding_jobs order by message_id")
	require.NoError(t, err)
	wantJobs := [][]string{
		{"m1", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion},
		{"m2", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion},
	}
	require.Equal(t, wantJobs, jobs)
}
// TestUpsertMessageWithEmbeddingsDoesNotRequeueUnchangedDoneJob embeds a
// message, re-upserts it unchanged and expects the job to stay done with an
// empty backlog, then changes the content and expects the job to return to
// pending with zero attempts and no error text.
func TestUpsertMessageWithEmbeddingsDoesNotRequeueUnchangedDoneJob(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	const jobQuery = "select state, attempts, last_error from embedding_jobs where message_id = 'm1'"
	enqueue := WriteOptions{EnqueueEmbedding: true}
	msg := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, enqueue))

	fake := &fakeEmbeddingProvider{
		batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}},
	}
	stats, err := st.DrainEmbeddingJobs(ctx, fake, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
	require.NoError(t, err)
	require.Equal(t, 1, stats.Succeeded)

	// Same content again: the finished job must not be reset.
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, enqueue))
	_, job, err := st.ReadOnlyQuery(ctx, jobQuery)
	require.NoError(t, err)
	require.Equal(t, [][]string{{"done", "0", ""}}, job)
	pending, err := st.EmbeddingBacklog(ctx)
	require.NoError(t, err)
	require.Equal(t, 0, pending)

	// Edited content: the job goes back to the queue.
	msg.NormalizedContent = "hello updated"
	msg.Content = "hello updated"
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, enqueue))
	_, job, err = st.ReadOnlyQuery(ctx, jobQuery)
	require.NoError(t, err)
	require.Equal(t, [][]string{{"pending", "0", ""}}, job)
}
// TestDrainEmbeddingJobsFailsWholeBatchOnDimensionMismatch drains a single
// queued message against a provider whose batch declares 3 dimensions but
// returns a 2-element vector. It expects the drain itself to succeed while
// counting one failure, the job to remain pending with one attempt and a
// "dimensions mismatch" error recorded, and no embedding row to be stored.
func TestDrainEmbeddingJobsFailsWholeBatchOnDimensionMismatch(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = s.Close() }()

	require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}, WriteOptions{EnqueueEmbedding: true}))

	stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
		batches: []embed.EmbeddingBatch{{
			Dimensions: 3,
			Vectors:    [][]float32{{1, 2}},
		}},
	}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
	require.NoError(t, err)
	require.Equal(t, 1, stats.Failed)

	_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
	require.NoError(t, err)
	// Assert the row count before indexing so a missing job row fails this
	// test cleanly instead of panicking with an index-out-of-range.
	require.Len(t, rows, 1)
	require.Equal(t, "pending", rows[0][0])
	require.Equal(t, "1", rows[0][1])
	require.Contains(t, rows[0][2], "dimensions mismatch")

	_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
	require.NoError(t, err)
	require.Len(t, rows, 1)
	require.Equal(t, "0", rows[0][0])
}
// TestDrainEmbeddingJobsMarksFailedAfterMaxAttempts pre-sets a queued job to
// two attempts, drains it against a provider that always errors with
// "provider down", and expects the job to end in state failed with three
// attempts and that error text recorded.
func TestDrainEmbeddingJobsMarksFailedAfterMaxAttempts(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	msg := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, WriteOptions{EnqueueEmbedding: true}))

	// Simulate two earlier failed attempts directly in the jobs table.
	_, err = st.DB().ExecContext(ctx, `update embedding_jobs set attempts = 2 where message_id = 'm1'`)
	require.NoError(t, err)

	failing := &fakeEmbeddingProvider{err: errors.New("provider down")}
	drainOpts := EmbeddingDrainOptions{
		Provider: "ollama",
		Model:    "nomic-embed-text",
		Limit:    10,
	}
	stats, err := st.DrainEmbeddingJobs(ctx, failing, drainOpts)
	require.NoError(t, err)
	require.Equal(t, 1, stats.Failed)

	_, job, err := st.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
	require.NoError(t, err)
	want := [][]string{{"failed", "3", "provider down"}}
	require.Equal(t, want, job)
}
// TestDrainEmbeddingJobsStopsOnRateLimit queues two messages and drains them
// one per batch against a provider that returns HTTP 429. It expects the
// drain to report RateLimited after a single provider call, with nothing
// processed or failed, one job requeued, a backlog of two, and both jobs
// still pending with zero attempts and no lock; only m1 carries the error
// text.
func TestDrainEmbeddingJobsStopsOnRateLimit(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	ids := []string{"m1", "m2"}
	for i := range ids {
		msg := MessageRecord{
			ID:                ids[i],
			GuildID:           "g1",
			ChannelID:         "c1",
			MessageType:       0,
			CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
			Content:           "hello",
			NormalizedContent: "hello",
			RawJSON:           `{}`,
		}
		require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, WriteOptions{EnqueueEmbedding: true}))
	}

	throttled := &fakeEmbeddingProvider{err: &embed.HTTPError{StatusCode: 429, Body: "slow down"}}
	drainOpts := EmbeddingDrainOptions{
		Provider:  "ollama",
		Model:     "nomic-embed-text",
		Limit:     10,
		BatchSize: 1,
	}
	stats, err := st.DrainEmbeddingJobs(ctx, throttled, drainOpts)
	require.NoError(t, err)
	require.True(t, stats.RateLimited)
	require.Equal(t, 0, stats.Processed)
	require.Equal(t, 0, stats.Failed)
	require.Equal(t, 1, stats.Requeued)
	require.Equal(t, 2, stats.RemainingBacklog)
	require.Len(t, throttled.inputs, 1)

	_, jobs, err := st.ReadOnlyQuery(ctx, "select message_id, state, attempts, last_error, coalesce(locked_at, '') from embedding_jobs order by message_id")
	require.NoError(t, err)
	want := [][]string{
		{"m1", "pending", "0", "embedding request failed with HTTP 429: slow down", ""},
		{"m2", "pending", "0", "", ""},
	}
	require.Equal(t, want, jobs)
}
// TestDrainEmbeddingJobsDeletesStaleVectorsForEmptyContent embeds a message,
// rewrites it with empty content and a single-space normalized content, and
// drains again. The second drain is expected to count the job as processed
// and skipped, leave no rows in message_embeddings, and mark the job done
// under the drain's provider and model.
func TestDrainEmbeddingJobsDeletesStaleVectorsForEmptyContent(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	enqueue := WriteOptions{EnqueueEmbedding: true}
	msg := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, enqueue))

	// First drain stores a vector for the original text.
	first := &fakeEmbeddingProvider{
		batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}},
	}
	_, err = st.DrainEmbeddingJobs(ctx, first, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
	require.NoError(t, err)

	// Blank out the message so there is nothing left to embed.
	msg.Content = ""
	msg.NormalizedContent = " "
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, enqueue))

	secondOpts := EmbeddingDrainOptions{
		Provider:  "ollama",
		Model:     "nomic-embed-text",
		Limit:     10,
		BatchSize: 1,
	}
	stats, err := st.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{}, secondOpts)
	require.NoError(t, err)
	require.Equal(t, 1, stats.Processed)
	require.Equal(t, 1, stats.Skipped)

	_, counts, err := st.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
	require.NoError(t, err)
	require.Equal(t, "0", counts[0][0])

	_, job, err := st.ReadOnlyQuery(ctx, "select state, provider, model from embedding_jobs where message_id = 'm1'")
	require.NoError(t, err)
	want := [][]string{{"done", "ollama", "nomic-embed-text"}}
	require.Equal(t, want, job)
}
// TestPendingEmbeddingJobsSkipsFreshLocks queues one job, lists and locks it,
// and then expects that the freshly locked job is neither returned by a
// second pendingEmbeddingJobs call nor claimable by a second
// lockEmbeddingJobs call made with a timestamp one minute later.
func TestPendingEmbeddingJobsSkipsFreshLocks(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	msg := MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}
	require.NoError(t, st.UpsertMessageWithOptions(ctx, msg, WriteOptions{EnqueueEmbedding: true}))

	lockTime := time.Now().UTC()
	cutoff := lockTime.Add(-embeddingLockTimeout).Format(timeLayout)
	firstStamp := lockTime.Format(timeLayout)
	secondStamp := lockTime.Add(time.Minute).Format(timeLayout)

	// The unlocked job is visible and can be claimed once.
	pending, err := st.pendingEmbeddingJobs(ctx, 10, cutoff)
	require.NoError(t, err)
	require.Len(t, pending, 1)
	locked, err := st.lockEmbeddingJobs(ctx, pending, firstStamp, cutoff)
	require.NoError(t, err)
	require.Len(t, locked, 1)

	// With a fresh lock in place, it is hidden and cannot be re-claimed.
	pending, err = st.pendingEmbeddingJobs(ctx, 10, cutoff)
	require.NoError(t, err)
	require.Empty(t, pending)
	locked, err = st.lockEmbeddingJobs(ctx, locked, secondStamp, cutoff)
	require.NoError(t, err)
	require.Empty(t, locked)
}
// TestRequeueAllEmbeddingJobsUsesCurrentIdentity stores two messages without
// jobs, inserts a failed job for m1 under an old provider/model/input
// version, and calls RequeueAllEmbeddingJobs. It expects a count of two and
// both messages to end up with pending jobs carrying zero attempts, no error
// text, and the provider, model, and input version passed to the call.
func TestRequeueAllEmbeddingJobsUsesCurrentIdentity(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	st, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = st.Close() }()

	ids := []string{"m1", "m2"}
	for i := range ids {
		msg := MessageRecord{
			ID:                ids[i],
			GuildID:           "g1",
			ChannelID:         "c1",
			MessageType:       0,
			CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
			Content:           "hello",
			NormalizedContent: "hello",
			RawJSON:           `{}`,
		}
		require.NoError(t, st.UpsertMessage(ctx, msg))
	}

	// Only m1 gets a pre-existing job, and it is stale in every respect.
	staleStamp := time.Now().UTC().Format(timeLayout)
	_, err = st.DB().ExecContext(ctx, `
insert into embedding_jobs(message_id, state, attempts, provider, model, input_version, last_error, updated_at)
values('m1', 'failed', 3, 'old', 'old-model', 'old-input', 'old error', ?)
`, staleStamp)
	require.NoError(t, err)

	identity := EmbeddingDrainOptions{
		Provider:     "ollama",
		Model:        "nomic-embed-text",
		InputVersion: EmbeddingInputVersion,
	}
	count, err := st.RequeueAllEmbeddingJobs(ctx, identity)
	require.NoError(t, err)
	require.Equal(t, 2, count)

	_, jobs, err := st.ReadOnlyQuery(ctx, "select message_id, state, attempts, provider, model, input_version, last_error from embedding_jobs order by message_id")
	require.NoError(t, err)
	want := [][]string{
		{"m1", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""},
		{"m2", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""},
	}
	require.Equal(t, want, jobs)
}
func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) {
t.Parallel()
@ -160,6 +544,26 @@ func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) {
require.Equal(t, "8", rows[0][0])
}
// fakeEmbeddingProvider is a scripted test double whose Embed method replays
// canned results and records every input slice it is called with.
type fakeEmbeddingProvider struct {
// batches holds the responses Embed hands out, one per call, front first.
batches []embed.EmbeddingBatch
// err, when non-nil, is returned by every Embed call instead of a batch.
err error
// inputs accumulates a copy of the inputs passed to each Embed call.
inputs [][]string
}
// Embed appends a copy of inputs to f.inputs and then answers from the
// script: the configured error if one is set, a zero batch with a nil error
// when no batches remain, or otherwise the first remaining batch, which is
// removed from the queue. The context is ignored.
func (f *fakeEmbeddingProvider) Embed(_ context.Context, inputs []string) (embed.EmbeddingBatch, error) {
	// Record the call before deciding the outcome so failed calls are counted too.
	f.inputs = append(f.inputs, append([]string(nil), inputs...))

	switch {
	case f.err != nil:
		return embed.EmbeddingBatch{}, f.err
	case len(f.batches) == 0:
		return embed.EmbeddingBatch{}, nil
	}

	next := f.batches[0]
	f.batches = f.batches[1:]
	return next, nil
}
func TestMessageFTSUsesSnowflakeRowID(t *testing.T) {
t.Parallel()

View File

@ -287,6 +287,30 @@ func (s *Store) UpsertMessages(ctx context.Context, messages []MessageMutation)
func upsertMessageTx(ctx context.Context, tx *sql.Tx, message MessageRecord, opts WriteOptions) error {
now := time.Now().UTC().Format(timeLayout)
var previousNormalized sql.NullString
previousErr := sql.ErrNoRows
jobExists := false
if opts.EnqueueEmbedding {
previousErr = tx.QueryRowContext(ctx, `
select normalized_content
from messages
where id = ?
`, message.ID).Scan(&previousNormalized)
if previousErr != nil && previousErr != sql.ErrNoRows {
return previousErr
}
if previousErr == nil {
var existingJobs int
if err := tx.QueryRowContext(ctx, `
select count(*)
from embedding_jobs
where message_id = ?
`, message.ID).Scan(&existingJobs); err != nil {
return err
}
jobExists = existingJobs > 0
}
}
if _, err := tx.ExecContext(ctx, `
insert into messages(
id, guild_id, channel_id, author_id, message_type, created_at, edited_at, deleted_at,
@ -323,11 +347,17 @@ func upsertMessageTx(ctx context.Context, tx *sql.Tx, message MessageRecord, opt
return err
}
}
if opts.EnqueueEmbedding {
queueEmbedding := opts.EnqueueEmbedding && (previousErr == sql.ErrNoRows || previousNormalized.String != message.NormalizedContent || !jobExists)
if queueEmbedding {
if _, err := tx.ExecContext(ctx, `
insert into embedding_jobs(message_id, state, attempts, updated_at)
values(?, 'pending', 0, ?)
on conflict(message_id) do nothing
on conflict(message_id) do update set
state = 'pending',
attempts = 0,
last_error = '',
locked_at = null,
updated_at = excluded.updated_at
`, message.ID, now); err != nil {
return err
}