feat(embed): add embedding job drain
* feat(embed): add embedding job drain * fix(embed): migrate legacy jobs and requeue rate limits safely --------- Co-authored-by: Vincent Koc <vincentkoc@ieee.org>
This commit is contained in:
parent
2f07416702
commit
ad4c897371
@ -8,6 +8,7 @@ All notable changes to `discrawl` will be documented in this file.
|
||||
- `messages` and `mentions` now use composite read-path indexes so larger archives spend less time sorting/filtering common guild, channel, and author queries
|
||||
- normalized message text is now sanitized before it reaches SQLite and FTS5, repairing malformed UTF-8 and stripping invisible/control-character noise that can poison search content
|
||||
- local embedding providers now support OpenAI-compatible endpoints, Ollama, and llama.cpp, and `doctor` can probe the configured provider before you queue vectors
|
||||
- `embed` now drains the queued embedding backlog in bounded batches, requeues safely on provider throttling, and drops stale stored vectors when messages no longer have embeddable content
|
||||
|
||||
## 0.3.0 - 2026-04-21
|
||||
|
||||
|
||||
@ -168,6 +168,61 @@ func (r *runtime) runStatus(args []string) error {
|
||||
return r.print(status)
|
||||
}
|
||||
|
||||
func (r *runtime) runEmbed(args []string) error {
|
||||
fs := flag.NewFlagSet("embed", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
limit := fs.Int("limit", store.DefaultEmbedLimit(), "")
|
||||
batchSize := fs.Int("batch-size", r.cfg.Search.Embeddings.BatchSize, "")
|
||||
rebuild := fs.Bool("rebuild", false, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return usageErr(err)
|
||||
}
|
||||
if fs.NArg() != 0 {
|
||||
return usageErr(fmt.Errorf("embed takes no positional arguments"))
|
||||
}
|
||||
if *limit <= 0 {
|
||||
return usageErr(fmt.Errorf("--limit must be positive"))
|
||||
}
|
||||
if *batchSize <= 0 {
|
||||
return usageErr(fmt.Errorf("--batch-size must be positive"))
|
||||
}
|
||||
if !r.cfg.Search.Embeddings.Enabled {
|
||||
return usageErr(fmt.Errorf("embeddings are disabled in config"))
|
||||
}
|
||||
providerFactory := r.newEmbed
|
||||
if providerFactory == nil {
|
||||
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
||||
return embed.NewProvider(cfg)
|
||||
}
|
||||
}
|
||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||
if err != nil {
|
||||
return configErr(err)
|
||||
}
|
||||
opts := store.EmbeddingDrainOptions{
|
||||
Provider: r.cfg.Search.Embeddings.Provider,
|
||||
Model: r.cfg.Search.Embeddings.Model,
|
||||
InputVersion: store.EmbeddingInputVersion,
|
||||
Limit: *limit,
|
||||
BatchSize: *batchSize,
|
||||
MaxInputChars: r.cfg.Search.Embeddings.MaxInputChars,
|
||||
Now: r.now,
|
||||
}
|
||||
requeued := 0
|
||||
if *rebuild {
|
||||
requeued, err = r.store.RequeueAllEmbeddingJobs(r.ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
stats, err := r.store.DrainEmbeddingJobs(r.ctx, provider, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.Requeued = requeued
|
||||
return r.print(stats)
|
||||
}
|
||||
|
||||
func (r *runtime) runDoctor(args []string) error {
|
||||
if len(args) != 0 {
|
||||
return usageErr(fmt.Errorf("doctor takes no arguments"))
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/steipete/discrawl/internal/config"
|
||||
"github.com/steipete/discrawl/internal/discord"
|
||||
"github.com/steipete/discrawl/internal/embed"
|
||||
"github.com/steipete/discrawl/internal/share"
|
||||
"github.com/steipete/discrawl/internal/store"
|
||||
"github.com/steipete/discrawl/internal/syncer"
|
||||
@ -96,6 +97,7 @@ type runtime struct {
|
||||
openStore func(context.Context, string) (*store.Store, error)
|
||||
newDiscord func(config.Config) (discordClient, error)
|
||||
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
|
||||
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
@ -130,6 +132,8 @@ func (r *runtime) dispatch(rest []string) error {
|
||||
return r.withServicesAuto(hasBoolFlag(rest[1:], "--sync"), true, func() error { return r.runMessages(rest[1:]) })
|
||||
case "mentions":
|
||||
return r.withServices(false, func() error { return r.runMentions(rest[1:]) })
|
||||
case "embed":
|
||||
return r.withServices(false, func() error { return r.runEmbed(rest[1:]) })
|
||||
case "sql":
|
||||
return r.withServices(false, func() error { return r.runSQL(rest[1:]) })
|
||||
case "members":
|
||||
|
||||
@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -489,6 +490,70 @@ func runGit(t *testing.T, dir string, args ...string) {
|
||||
require.NoError(t, err, string(out))
|
||||
}
|
||||
|
||||
func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
dbPath := filepath.Join(dir, "discrawl.db")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/embeddings", r.URL.Path)
|
||||
var req struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
require.Len(t, req.Input, 1)
|
||||
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1,2]}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = dbPath
|
||||
cfg.Search.Embeddings.Enabled = true
|
||||
cfg.Search.Embeddings.Provider = "openai_compatible"
|
||||
cfg.Search.Embeddings.Model = "local-model"
|
||||
cfg.Search.Embeddings.BaseURL = server.URL
|
||||
cfg.Search.Embeddings.APIKeyEnv = ""
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
for _, id := range []string{"m1", "m2"} {
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, store.MessageRecord{
|
||||
ID: id,
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}, store.WriteOptions{EnqueueEmbedding: true}))
|
||||
}
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--limit", "1"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "processed=1")
|
||||
require.Contains(t, out.String(), "succeeded=1")
|
||||
require.Contains(t, out.String(), "remaining_backlog=1")
|
||||
require.Contains(t, out.String(), "provider=openai_compatible")
|
||||
|
||||
s, err = store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1", rows[0][0])
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
out.Reset()
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--rebuild", "--limit", "1"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "processed=1")
|
||||
require.Contains(t, out.String(), "succeeded=1")
|
||||
require.Contains(t, out.String(), "remaining_backlog=1")
|
||||
require.Contains(t, out.String(), "requeued=2")
|
||||
}
|
||||
|
||||
type fakeDiscordClient struct {
|
||||
guilds []*discordgo.UserGuild
|
||||
self *discordgo.User
|
||||
|
||||
@ -80,6 +80,7 @@ Commands:
|
||||
search
|
||||
messages
|
||||
mentions
|
||||
embed
|
||||
sql
|
||||
members
|
||||
channels
|
||||
@ -108,6 +109,21 @@ func printHuman(w io.Writer, value any) error {
|
||||
v.DBPath, v.GuildCount, v.ChannelCount, v.ThreadCount, v.MessageCount, v.MemberCount, v.EmbeddingBacklog,
|
||||
formatTime(v.LastSyncAt), formatTime(v.LastTailEventAt))
|
||||
return err
|
||||
case store.EmbeddingDrainStats:
|
||||
_, err := fmt.Fprintf(w, "processed=%d\nsucceeded=%d\nfailed=%d\nskipped=%d\nremaining_backlog=%d\nprovider=%s\nmodel=%s\ninput_version=%s\n",
|
||||
v.Processed, v.Succeeded, v.Failed, v.Skipped, v.RemainingBacklog, v.Provider, v.Model, v.InputVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Requeued > 0 {
|
||||
if _, err := fmt.Fprintf(w, "requeued=%d\n", v.Requeued); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if v.RateLimited {
|
||||
_, err = fmt.Fprintln(w, "rate_limited=true")
|
||||
}
|
||||
return err
|
||||
case []store.SearchResult:
|
||||
for _, row := range v {
|
||||
if _, err := fmt.Fprintf(w, "[%s/%s] %s %s\n%s\n\n", row.GuildID, row.ChannelName, row.AuthorName, formatTime(row.CreatedAt), row.Content); err != nil {
|
||||
|
||||
@ -82,7 +82,7 @@ func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string,
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("embedding request failed with HTTP %d: %s", resp.StatusCode, string(msg))
|
||||
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
||||
return fmt.Errorf("decode embedding response: %w", err)
|
||||
|
||||
@ -41,6 +41,20 @@ type EmbeddingBatch struct {
|
||||
Vectors [][]float32
|
||||
}
|
||||
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *HTTPError) Error() string {
|
||||
return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
func IsRateLimitError(err error) bool {
|
||||
var httpErr *HTTPError
|
||||
return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
type CheckResult struct {
|
||||
Provider string
|
||||
Model string
|
||||
|
||||
@ -172,6 +172,27 @@ func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
|
||||
require.False(t, result.Probed)
|
||||
}
|
||||
|
||||
func TestProviderExposesRateLimitErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider, err := NewProvider(config.EmbeddingsConfig{
|
||||
Provider: ProviderOpenAICompatible,
|
||||
Model: "local-model",
|
||||
BaseURL: server.URL,
|
||||
RequestTimeout: "5s",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = provider.Embed(context.Background(), []string{"one"})
|
||||
require.ErrorContains(t, err, "HTTP 429")
|
||||
require.True(t, IsRateLimitError(err))
|
||||
}
|
||||
|
||||
func TestProviderRejectsInvalidResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
549
internal/store/embeddings.go
Normal file
549
internal/store/embeddings.go
Normal file
@ -0,0 +1,549 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/steipete/discrawl/internal/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
EmbeddingInputVersion = "message_normalized_v1"
|
||||
defaultEmbedLimit = 1000
|
||||
maxEmbeddingAttempts = 3
|
||||
maxStoredErrorChars = 500
|
||||
embeddingLockTimeout = 15 * time.Minute
|
||||
)
|
||||
|
||||
type EmbeddingDrainOptions struct {
|
||||
Provider string
|
||||
Model string
|
||||
InputVersion string
|
||||
Limit int
|
||||
BatchSize int
|
||||
MaxInputChars int
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
type EmbeddingDrainStats struct {
|
||||
Processed int `json:"processed"`
|
||||
Succeeded int `json:"succeeded"`
|
||||
Failed int `json:"failed"`
|
||||
Skipped int `json:"skipped"`
|
||||
Requeued int `json:"requeued,omitempty"`
|
||||
RemainingBacklog int `json:"remaining_backlog"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
InputVersion string `json:"input_version"`
|
||||
RateLimited bool `json:"rate_limited,omitempty"`
|
||||
}
|
||||
|
||||
type embeddingJob struct {
|
||||
MessageID string
|
||||
NormalizedContent string
|
||||
Attempts int
|
||||
Provider string
|
||||
Model string
|
||||
InputVersion string
|
||||
}
|
||||
|
||||
func DefaultEmbedLimit() int {
|
||||
return defaultEmbedLimit
|
||||
}
|
||||
|
||||
func (s *Store) DrainEmbeddingJobs(ctx context.Context, provider embed.Provider, opts EmbeddingDrainOptions) (EmbeddingDrainStats, error) {
|
||||
opts = normalizeEmbeddingDrainOptions(opts)
|
||||
stats := EmbeddingDrainStats{
|
||||
Provider: opts.Provider,
|
||||
Model: opts.Model,
|
||||
InputVersion: opts.InputVersion,
|
||||
}
|
||||
if provider == nil {
|
||||
return stats, errors.New("embedding provider is nil")
|
||||
}
|
||||
now := opts.Now()
|
||||
staleBefore := now.Add(-embeddingLockTimeout).Format(timeLayout)
|
||||
jobs, err := s.pendingEmbeddingJobs(ctx, opts.Limit, staleBefore)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
var batch []embeddingJob
|
||||
flush := func() error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
rateLimited, err := s.processEmbeddingBatch(ctx, provider, opts, batch, &stats)
|
||||
batch = batch[:0]
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rateLimited {
|
||||
stats.RateLimited = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
for _, job := range jobs {
|
||||
if !sameEmbeddingIdentity(job, opts) {
|
||||
resetAttempts := !emptyEmbeddingIdentity(job)
|
||||
if err := s.resetEmbeddingJobIdentity(ctx, job.MessageID, opts, resetAttempts); err != nil {
|
||||
return stats, err
|
||||
}
|
||||
job.Provider = opts.Provider
|
||||
job.Model = opts.Model
|
||||
job.InputVersion = opts.InputVersion
|
||||
if resetAttempts {
|
||||
job.Attempts = 0
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(job.NormalizedContent) == "" {
|
||||
if err := s.markEmbeddingJobsDone(ctx, opts, []embeddingJob{job}); err != nil {
|
||||
return stats, err
|
||||
}
|
||||
stats.Processed++
|
||||
stats.Skipped++
|
||||
continue
|
||||
}
|
||||
batch = append(batch, job)
|
||||
if len(batch) >= opts.BatchSize {
|
||||
if err := flush(); err != nil {
|
||||
return stats, err
|
||||
}
|
||||
if stats.RateLimited {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !stats.RateLimited {
|
||||
if err := flush(); err != nil {
|
||||
return stats, err
|
||||
}
|
||||
}
|
||||
stats.RemainingBacklog, err = s.EmbeddingBacklog(ctx)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func normalizeEmbeddingDrainOptions(opts EmbeddingDrainOptions) EmbeddingDrainOptions {
|
||||
opts.Provider = strings.ToLower(strings.TrimSpace(opts.Provider))
|
||||
opts.Model = strings.TrimSpace(opts.Model)
|
||||
opts.InputVersion = strings.TrimSpace(opts.InputVersion)
|
||||
if opts.InputVersion == "" {
|
||||
opts.InputVersion = EmbeddingInputVersion
|
||||
}
|
||||
if opts.Limit <= 0 {
|
||||
opts.Limit = defaultEmbedLimit
|
||||
}
|
||||
if opts.BatchSize <= 0 {
|
||||
opts.BatchSize = embed.DefaultBatchSize
|
||||
}
|
||||
if opts.BatchSize > opts.Limit {
|
||||
opts.BatchSize = opts.Limit
|
||||
}
|
||||
if opts.MaxInputChars <= 0 {
|
||||
opts.MaxInputChars = embed.DefaultMaxInputChars
|
||||
}
|
||||
if opts.Now == nil {
|
||||
opts.Now = func() time.Time { return time.Now().UTC() }
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func sameEmbeddingIdentity(job embeddingJob, opts EmbeddingDrainOptions) bool {
|
||||
return job.Provider == opts.Provider && job.Model == opts.Model && job.InputVersion == opts.InputVersion
|
||||
}
|
||||
|
||||
func emptyEmbeddingIdentity(job embeddingJob) bool {
|
||||
return job.Provider == "" && job.Model == "" && job.InputVersion == ""
|
||||
}
|
||||
|
||||
func (s *Store) pendingEmbeddingJobs(ctx context.Context, limit int, staleBefore string) ([]embeddingJob, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
select
|
||||
j.message_id,
|
||||
m.normalized_content,
|
||||
j.attempts,
|
||||
j.provider,
|
||||
j.model,
|
||||
j.input_version
|
||||
from embedding_jobs j
|
||||
join messages m on m.id = j.message_id
|
||||
where j.state = 'pending'
|
||||
and (j.locked_at is null or j.locked_at = '' or j.locked_at < ?)
|
||||
order by j.updated_at, j.message_id
|
||||
limit ?
|
||||
`, staleBefore, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
var jobs []embeddingJob
|
||||
for rows.Next() {
|
||||
var job embeddingJob
|
||||
if err := rows.Scan(&job.MessageID, &job.NormalizedContent, &job.Attempts, &job.Provider, &job.Model, &job.InputVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
return jobs, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) resetEmbeddingJobIdentity(ctx context.Context, messageID string, opts EmbeddingDrainOptions, resetAttempts bool) error {
|
||||
if resetAttempts {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
attempts = 0,
|
||||
last_error = '',
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id = ?
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, opts.Now().Format(timeLayout), messageID)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
last_error = '',
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id = ?
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, opts.Now().Format(timeLayout), messageID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) processEmbeddingBatch(ctx context.Context, provider embed.Provider, opts EmbeddingDrainOptions, jobs []embeddingJob, stats *EmbeddingDrainStats) (bool, error) {
|
||||
now := opts.Now()
|
||||
lockedAt := now.Format(timeLayout)
|
||||
staleBefore := now.Add(-embeddingLockTimeout).Format(timeLayout)
|
||||
claimed, err := s.lockEmbeddingJobs(ctx, jobs, lockedAt, staleBefore)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(claimed) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
jobs = claimed
|
||||
inputs := make([]string, 0, len(jobs))
|
||||
for _, job := range jobs {
|
||||
inputs = append(inputs, capRunes(job.NormalizedContent, opts.MaxInputChars))
|
||||
}
|
||||
batch, err := provider.Embed(ctx, inputs)
|
||||
if err != nil {
|
||||
if embed.IsRateLimitError(err) {
|
||||
if markErr := s.markEmbeddingJobsRateLimited(ctx, opts, jobs, err); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
stats.Requeued += len(jobs)
|
||||
return true, nil
|
||||
}
|
||||
if markErr := s.markEmbeddingJobsFailed(ctx, opts, jobs, err); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
stats.Processed += len(jobs)
|
||||
stats.Failed += len(jobs)
|
||||
return embed.IsRateLimitError(err), nil
|
||||
}
|
||||
dimensions, err := validateEmbeddingBatch(batch, len(jobs))
|
||||
if err != nil {
|
||||
if markErr := s.markEmbeddingJobsFailed(ctx, opts, jobs, err); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
stats.Processed += len(jobs)
|
||||
stats.Failed += len(jobs)
|
||||
return false, nil
|
||||
}
|
||||
if err := s.storeEmbeddingBatch(ctx, opts, jobs, batch.Vectors, dimensions); err != nil {
|
||||
return false, err
|
||||
}
|
||||
stats.Processed += len(jobs)
|
||||
stats.Succeeded += len(jobs)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *Store) lockEmbeddingJobs(ctx context.Context, jobs []embeddingJob, lockedAt, staleBefore string) ([]embeddingJob, error) {
|
||||
if len(jobs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rollback(tx)
|
||||
claimed := make([]embeddingJob, 0, len(jobs))
|
||||
for _, job := range jobs {
|
||||
result, err := tx.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set locked_at = ?, updated_at = ?
|
||||
where message_id = ?
|
||||
and state = 'pending'
|
||||
and (locked_at is null or locked_at = '' or locked_at < ?)
|
||||
`, lockedAt, lockedAt, job.MessageID, staleBefore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rows == 1 {
|
||||
claimed = append(claimed, job)
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claimed, nil
|
||||
}
|
||||
|
||||
func validateEmbeddingBatch(batch embed.EmbeddingBatch, expected int) (int, error) {
|
||||
if len(batch.Vectors) != expected {
|
||||
return 0, fmt.Errorf("embedding provider returned %d vectors for %d inputs", len(batch.Vectors), expected)
|
||||
}
|
||||
dimensions := batch.Dimensions
|
||||
for _, vector := range batch.Vectors {
|
||||
if len(vector) == 0 {
|
||||
return 0, errors.New("embedding provider returned an empty vector")
|
||||
}
|
||||
if dimensions == 0 {
|
||||
dimensions = len(vector)
|
||||
continue
|
||||
}
|
||||
if len(vector) != dimensions {
|
||||
return 0, fmt.Errorf("embedding provider dimensions mismatch: got %d want %d", len(vector), dimensions)
|
||||
}
|
||||
}
|
||||
return dimensions, nil
|
||||
}
|
||||
|
||||
func (s *Store) storeEmbeddingBatch(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, vectors [][]float32, dimensions int) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rollback(tx)
|
||||
embeddedAt := opts.Now().Format(timeLayout)
|
||||
for i, job := range jobs {
|
||||
blob, err := EncodeEmbeddingVector(vectors[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
insert into message_embeddings(
|
||||
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
|
||||
) values(?, ?, ?, ?, ?, ?, ?)
|
||||
on conflict(message_id, provider, model, input_version) do update set
|
||||
dimensions = excluded.dimensions,
|
||||
embedding_blob = excluded.embedding_blob,
|
||||
embedded_at = excluded.embedded_at
|
||||
`, job.MessageID, opts.Provider, opts.Model, opts.InputVersion, dimensions, blob, embeddedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set state = 'done',
|
||||
attempts = 0,
|
||||
provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
last_error = '',
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id = ?
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, embeddedAt, job.MessageID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) markEmbeddingJobsDone(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rollback(tx)
|
||||
now := opts.Now().Format(timeLayout)
|
||||
for _, job := range jobs {
|
||||
if _, err := tx.ExecContext(ctx, `delete from message_embeddings where message_id = ?`, job.MessageID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set state = 'done',
|
||||
provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
last_error = '',
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id = ?
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, now, job.MessageID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) markEmbeddingJobsRateLimited(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, cause error) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rollback(tx)
|
||||
now := opts.Now().Format(timeLayout)
|
||||
lastError := trimStoredError(cause)
|
||||
for _, job := range jobs {
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set state = 'pending',
|
||||
provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
last_error = ?,
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id = ?
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, lastError, now, job.MessageID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) markEmbeddingJobsFailed(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, cause error) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rollback(tx)
|
||||
now := opts.Now().Format(timeLayout)
|
||||
lastError := trimStoredError(cause)
|
||||
for _, job := range jobs {
|
||||
attempts := job.Attempts + 1
|
||||
state := "pending"
|
||||
if attempts >= maxEmbeddingAttempts {
|
||||
state = "failed"
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set state = ?,
|
||||
attempts = ?,
|
||||
provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
last_error = ?,
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id = ?
|
||||
`, state, attempts, opts.Provider, opts.Model, opts.InputVersion, lastError, now, job.MessageID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func trimStoredError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
msg := strings.TrimSpace(err.Error())
|
||||
runes := []rune(msg)
|
||||
if len(runes) > maxStoredErrorChars {
|
||||
msg = string(runes[:maxStoredErrorChars])
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func capRunes(value string, maxChars int) string {
|
||||
if maxChars <= 0 {
|
||||
return value
|
||||
}
|
||||
runes := []rune(value)
|
||||
if len(runes) <= maxChars {
|
||||
return value
|
||||
}
|
||||
return string(runes[:maxChars])
|
||||
}
|
||||
|
||||
func EncodeEmbeddingVector(vector []float32) ([]byte, error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(vector)*4))
|
||||
for _, value := range vector {
|
||||
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
|
||||
return nil, fmt.Errorf("encode embedding vector: %w", err)
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func DecodeEmbeddingVector(blob []byte) ([]float32, error) {
|
||||
if len(blob)%4 != 0 {
|
||||
return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob))
|
||||
}
|
||||
out := make([]float32, len(blob)/4)
|
||||
reader := bytes.NewReader(blob)
|
||||
for i := range out {
|
||||
if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil {
|
||||
return nil, fmt.Errorf("decode embedding vector: %w", err)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
err := s.db.QueryRowContext(ctx, `select count(*) from embedding_jobs where state = 'pending'`).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *Store) RequeueAllEmbeddingJobs(ctx context.Context, opts EmbeddingDrainOptions) (int, error) {
|
||||
opts = normalizeEmbeddingDrainOptions(opts)
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rollback(tx)
|
||||
now := opts.Now().Format(timeLayout)
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
insert or ignore into embedding_jobs(
|
||||
message_id, state, attempts, provider, model, input_version, last_error, locked_at, updated_at
|
||||
)
|
||||
select id, 'pending', 0, ?, ?, ?, '', null, ?
|
||||
from messages
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, now); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx, `
|
||||
update embedding_jobs
|
||||
set state = 'pending',
|
||||
attempts = 0,
|
||||
provider = ?,
|
||||
model = ?,
|
||||
input_version = ?,
|
||||
last_error = '',
|
||||
locked_at = null,
|
||||
updated_at = ?
|
||||
where message_id in (select id from messages)
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, now)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(affected), nil
|
||||
}
|
||||
@ -211,6 +211,7 @@ func (s *Store) migrate(ctx context.Context) error {
|
||||
if err := s.setSchemaVersion(ctx, storeSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
currentVersion = storeSchemaVersion
|
||||
}
|
||||
if version, err := s.schemaVersion(ctx); err != nil {
|
||||
return err
|
||||
@ -368,8 +369,23 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error {
|
||||
message_id text primary key,
|
||||
state text not null,
|
||||
attempts integer not null default 0,
|
||||
provider text not null default '',
|
||||
model text not null default '',
|
||||
input_version text not null default '',
|
||||
last_error text not null default '',
|
||||
locked_at text,
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create table if not exists message_embeddings (
|
||||
message_id text not null,
|
||||
provider text not null,
|
||||
model text not null,
|
||||
input_version text not null,
|
||||
dimensions integer not null,
|
||||
embedding_blob blob not null,
|
||||
embedded_at text not null,
|
||||
primary key (message_id, provider, model, input_version)
|
||||
);`,
|
||||
`create virtual table if not exists message_fts using fts5(
|
||||
message_id unindexed,
|
||||
guild_id unindexed,
|
||||
@ -402,6 +418,7 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error {
|
||||
`create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`,
|
||||
`create index if not exists idx_mentions_target on mention_events(target_type, target_id, event_at);`,
|
||||
`create index if not exists idx_mentions_author on mention_events(author_id, event_at);`,
|
||||
`create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := tx.ExecContext(ctx, stmt); err != nil {
|
||||
@ -417,12 +434,43 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
defer rollback(tx)
|
||||
for _, column := range []struct {
|
||||
name string
|
||||
sql string
|
||||
}{
|
||||
{"provider", `alter table embedding_jobs add column provider text not null default ''`},
|
||||
{"model", `alter table embedding_jobs add column model text not null default ''`},
|
||||
{"input_version", `alter table embedding_jobs add column input_version text not null default ''`},
|
||||
{"last_error", `alter table embedding_jobs add column last_error text not null default ''`},
|
||||
{"locked_at", `alter table embedding_jobs add column locked_at text`},
|
||||
} {
|
||||
ok, err := columnExists(ctx, tx, "embedding_jobs", column.name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
if _, err := tx.ExecContext(ctx, column.sql); err != nil {
|
||||
return fmt.Errorf("add embedding_jobs.%s: %w", column.name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
stmts := []string{
|
||||
`create table if not exists message_embeddings (
|
||||
message_id text not null,
|
||||
provider text not null,
|
||||
model text not null,
|
||||
input_version text not null,
|
||||
dimensions integer not null,
|
||||
embedding_blob blob not null,
|
||||
embedded_at text not null,
|
||||
primary key (message_id, provider, model, input_version)
|
||||
);`,
|
||||
`create index if not exists idx_messages_guild_created_id on messages(guild_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_channel_created_id on messages(channel_id, created_at, id);`,
|
||||
`create index if not exists idx_messages_author_created_id on messages(author_id, created_at, id);`,
|
||||
`create index if not exists idx_mentions_guild_event on mention_events(guild_id, event_at, event_id);`,
|
||||
`create index if not exists idx_mentions_channel_event on mention_events(channel_id, event_at, event_id);`,
|
||||
`create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := tx.ExecContext(ctx, stmt); err != nil {
|
||||
@ -432,6 +480,27 @@ func (s *Store) applyQueryIndexMigration(ctx context.Context) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func columnExists(ctx context.Context, tx *sql.Tx, table, column string) (bool, error) {
|
||||
rows, err := tx.QueryContext(ctx, `pragma table_info(`+table+`)`)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("inspect %s columns: %w", table, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, typ string
|
||||
var notNull int
|
||||
var defaultValue sql.NullString
|
||||
var pk int
|
||||
if err := rows.Scan(&cid, &name, &typ, ¬Null, &defaultValue, &pk); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if name == column {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, rows.Err()
|
||||
}
|
||||
func (s *Store) ensureFTSRowIDs(ctx context.Context) error {
|
||||
var version sql.NullString
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
|
||||
@ -325,6 +325,88 @@ func TestOpenBackfillsMissingSchemaVersion(t *testing.T) {
|
||||
require.Equal(t, storeSchemaVersion, version)
|
||||
}
|
||||
|
||||
func TestOpenMigratesSchemaV1ToV2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
dbPath := filepath.Join(t.TempDir(), "discrawl.db")
|
||||
require.NoError(t, createV1Schema(ctx, dbPath))
|
||||
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
require.NoError(t, err)
|
||||
_, err = db.ExecContext(ctx, `
|
||||
insert into messages(
|
||||
id, guild_id, channel_id, message_type, created_at, content,
|
||||
normalized_content, raw_json, updated_at
|
||||
) values('m1', 'g1', 'c1', 0, '2026-01-01T00:00:00Z', 'hello', 'hello', '{}', '2026-01-01T00:00:00Z')
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
_, err = db.ExecContext(ctx, `
|
||||
insert into embedding_jobs(message_id, state, attempts, updated_at)
|
||||
values('m1', 'pending', 1, '2026-01-01T00:00:00Z')
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
|
||||
s, err := Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
var version int
|
||||
require.NoError(t, s.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version))
|
||||
require.Equal(t, 2, version)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"", "", "", "", ""}}, rows)
|
||||
|
||||
_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "0", rows[0][0])
|
||||
}
|
||||
|
||||
func TestOpenMigratesUnversionedV1SchemaToV2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
dbPath := filepath.Join(t.TempDir(), "discrawl.db")
|
||||
require.NoError(t, createV1Schema(ctx, dbPath))
|
||||
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
require.NoError(t, err)
|
||||
_, err = db.ExecContext(ctx, `
|
||||
insert into messages(
|
||||
id, guild_id, channel_id, message_type, created_at, content,
|
||||
normalized_content, raw_json, updated_at
|
||||
) values('m1', 'g1', 'c1', 0, '2026-01-01T00:00:00Z', 'hello', 'hello', '{}', '2026-01-01T00:00:00Z')
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
_, err = db.ExecContext(ctx, `
|
||||
insert into embedding_jobs(message_id, state, attempts, updated_at)
|
||||
values('m1', 'pending', 1, '2026-01-01T00:00:00Z')
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
_, err = db.ExecContext(ctx, `pragma user_version = 0`)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
|
||||
s, err := Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
var version int
|
||||
require.NoError(t, s.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version))
|
||||
require.Equal(t, 2, version)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"", "", "", "", ""}}, rows)
|
||||
|
||||
_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "0", rows[0][0])
|
||||
}
|
||||
|
||||
func TestReadOnlyQueryGuards(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -342,6 +424,142 @@ func TestReadOnlyQueryGuards(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func createV1Schema(ctx context.Context, path string) error {
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = db.Close() }()
|
||||
stmts := []string{
|
||||
`create table guilds (
|
||||
id text primary key,
|
||||
name text not null,
|
||||
icon text,
|
||||
raw_json text not null,
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create table channels (
|
||||
id text primary key,
|
||||
guild_id text not null,
|
||||
parent_id text,
|
||||
kind text not null,
|
||||
name text not null,
|
||||
topic text,
|
||||
position integer,
|
||||
is_nsfw integer not null default 0,
|
||||
is_archived integer not null default 0,
|
||||
is_locked integer not null default 0,
|
||||
is_private_thread integer not null default 0,
|
||||
thread_parent_id text,
|
||||
archive_timestamp text,
|
||||
raw_json text not null,
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create table members (
|
||||
guild_id text not null,
|
||||
user_id text not null,
|
||||
username text not null,
|
||||
global_name text,
|
||||
display_name text,
|
||||
nick text,
|
||||
discriminator text,
|
||||
avatar text,
|
||||
bot integer not null default 0,
|
||||
joined_at text,
|
||||
role_ids_json text not null,
|
||||
raw_json text not null,
|
||||
updated_at text not null,
|
||||
primary key (guild_id, user_id)
|
||||
);`,
|
||||
`create table messages (
|
||||
id text primary key,
|
||||
guild_id text not null,
|
||||
channel_id text not null,
|
||||
author_id text,
|
||||
message_type integer not null,
|
||||
created_at text not null,
|
||||
edited_at text,
|
||||
deleted_at text,
|
||||
content text not null,
|
||||
normalized_content text not null,
|
||||
reply_to_message_id text,
|
||||
pinned integer not null default 0,
|
||||
has_attachments integer not null default 0,
|
||||
raw_json text not null,
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create table message_events (
|
||||
event_id integer primary key autoincrement,
|
||||
guild_id text not null,
|
||||
channel_id text not null,
|
||||
message_id text not null,
|
||||
event_type text not null,
|
||||
event_at text not null,
|
||||
payload_json text not null
|
||||
);`,
|
||||
`create table message_attachments (
|
||||
attachment_id text primary key,
|
||||
message_id text not null,
|
||||
guild_id text not null,
|
||||
channel_id text not null,
|
||||
author_id text,
|
||||
filename text not null,
|
||||
content_type text,
|
||||
size integer not null default 0,
|
||||
url text,
|
||||
proxy_url text,
|
||||
text_content text not null default '',
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create table mention_events (
|
||||
event_id integer primary key autoincrement,
|
||||
message_id text not null,
|
||||
guild_id text not null,
|
||||
channel_id text not null,
|
||||
author_id text,
|
||||
target_type text not null,
|
||||
target_id text not null,
|
||||
target_name text not null default '',
|
||||
event_at text not null
|
||||
);`,
|
||||
`create table sync_state (
|
||||
scope text primary key,
|
||||
cursor text,
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create table embedding_jobs (
|
||||
message_id text primary key,
|
||||
state text not null,
|
||||
attempts integer not null default 0,
|
||||
updated_at text not null
|
||||
);`,
|
||||
`create virtual table message_fts using fts5(
|
||||
message_id unindexed,
|
||||
guild_id unindexed,
|
||||
channel_id unindexed,
|
||||
author_id unindexed,
|
||||
author_name,
|
||||
channel_name,
|
||||
content
|
||||
);`,
|
||||
`create virtual table member_fts using fts5(
|
||||
member_key unindexed,
|
||||
guild_id unindexed,
|
||||
user_id unindexed,
|
||||
username,
|
||||
display_name,
|
||||
profile_text
|
||||
);`,
|
||||
`pragma user_version = 1;`,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestQueryAndExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -3,12 +3,15 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/steipete/discrawl/internal/embed"
|
||||
)
|
||||
|
||||
func TestUpsertMessagesBatch(t *testing.T) {
|
||||
@ -123,6 +126,387 @@ func TestUpsertMessageWithEmbeddingsQueuesJob(t *testing.T) {
|
||||
require.Equal(t, "1", rows[0][0])
|
||||
}
|
||||
|
||||
func TestUpsertMessageWithEmbeddingsQueuesExistingMessageWithoutJob(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
record := MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}
|
||||
require.NoError(t, s.UpsertMessage(ctx, record))
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"pending", "0"}}, rows)
|
||||
}
|
||||
|
||||
func TestDrainEmbeddingJobsStoresVectorsAndSkipsEmptyInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: now,
|
||||
Content: "abcdef世界",
|
||||
NormalizedContent: "abcdef世界",
|
||||
RawJSON: `{}`,
|
||||
}, WriteOptions{EnqueueEmbedding: true}))
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: now,
|
||||
Content: "",
|
||||
NormalizedContent: " ",
|
||||
RawJSON: `{}`,
|
||||
}, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
provider := &fakeEmbeddingProvider{
|
||||
batches: []embed.EmbeddingBatch{{
|
||||
Vectors: [][]float32{{1.25, 2.5}},
|
||||
}},
|
||||
}
|
||||
stats, err := s.DrainEmbeddingJobs(ctx, provider, EmbeddingDrainOptions{
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
Limit: 10,
|
||||
BatchSize: 2,
|
||||
MaxInputChars: 7,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, stats.Processed)
|
||||
require.Equal(t, 1, stats.Succeeded)
|
||||
require.Equal(t, 1, stats.Skipped)
|
||||
require.Equal(t, 0, stats.RemainingBacklog)
|
||||
require.Equal(t, [][]string{{"abcdef世"}}, provider.inputs)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select message_id, provider, model, input_version, dimensions from message_embeddings")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"m1", "ollama", "nomic-embed-text", EmbeddingInputVersion, "2"}}, rows)
|
||||
|
||||
var blob []byte
|
||||
require.NoError(t, s.DB().QueryRowContext(ctx, `select embedding_blob from message_embeddings where message_id = 'm1'`).Scan(&blob))
|
||||
vector, err := DecodeEmbeddingVector(blob)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []float32{1.25, 2.5}, vector)
|
||||
|
||||
_, rows, err = s.ReadOnlyQuery(ctx, "select message_id, state, provider, model, input_version from embedding_jobs order by message_id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{
|
||||
{"m1", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion},
|
||||
{"m2", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion},
|
||||
}, rows)
|
||||
}
|
||||
|
||||
func TestUpsertMessageWithEmbeddingsDoesNotRequeueUnchangedDoneJob(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
record := MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
|
||||
batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}},
|
||||
}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, stats.Succeeded)
|
||||
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"done", "0", ""}}, rows)
|
||||
|
||||
backlog, err := s.EmbeddingBacklog(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, backlog)
|
||||
|
||||
record.NormalizedContent = "hello updated"
|
||||
record.Content = "hello updated"
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
|
||||
_, rows, err = s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"pending", "0", ""}}, rows)
|
||||
}
|
||||
|
||||
func TestDrainEmbeddingJobsFailsWholeBatchOnDimensionMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
|
||||
batches: []embed.EmbeddingBatch{{
|
||||
Dimensions: 3,
|
||||
Vectors: [][]float32{{1, 2}},
|
||||
}},
|
||||
}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, stats.Failed)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pending", rows[0][0])
|
||||
require.Equal(t, "1", rows[0][1])
|
||||
require.Contains(t, rows[0][2], "dimensions mismatch")
|
||||
|
||||
_, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "0", rows[0][0])
|
||||
}
|
||||
|
||||
func TestDrainEmbeddingJobsMarksFailedAfterMaxAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}, WriteOptions{EnqueueEmbedding: true}))
|
||||
_, err = s.DB().ExecContext(ctx, `update embedding_jobs set attempts = 2 where message_id = 'm1'`)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{err: errors.New("provider down")}, EmbeddingDrainOptions{
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
Limit: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, stats.Failed)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"failed", "3", "provider down"}}, rows)
|
||||
}
|
||||
|
||||
func TestDrainEmbeddingJobsStopsOnRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
for _, id := range []string{"m1", "m2"} {
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
|
||||
ID: id,
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}, WriteOptions{EnqueueEmbedding: true}))
|
||||
}
|
||||
|
||||
provider := &fakeEmbeddingProvider{err: &embed.HTTPError{StatusCode: 429, Body: "slow down"}}
|
||||
stats, err := s.DrainEmbeddingJobs(ctx, provider, EmbeddingDrainOptions{
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
Limit: 10,
|
||||
BatchSize: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, stats.RateLimited)
|
||||
require.Equal(t, 0, stats.Processed)
|
||||
require.Equal(t, 0, stats.Failed)
|
||||
require.Equal(t, 1, stats.Requeued)
|
||||
require.Equal(t, 2, stats.RemainingBacklog)
|
||||
require.Len(t, provider.inputs, 1)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select message_id, state, attempts, last_error, coalesce(locked_at, '') from embedding_jobs order by message_id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{
|
||||
{"m1", "pending", "0", "embedding request failed with HTTP 429: slow down", ""},
|
||||
{"m2", "pending", "0", "", ""},
|
||||
}, rows)
|
||||
}
|
||||
|
||||
func TestDrainEmbeddingJobsDeletesStaleVectorsForEmptyContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
record := MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
_, err = s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{
|
||||
batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}},
|
||||
}, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1})
|
||||
require.NoError(t, err)
|
||||
|
||||
record.Content = ""
|
||||
record.NormalizedContent = " "
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{}, EmbeddingDrainOptions{
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
Limit: 10,
|
||||
BatchSize: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, stats.Processed)
|
||||
require.Equal(t, 1, stats.Skipped)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "0", rows[0][0])
|
||||
|
||||
_, rows, err = s.ReadOnlyQuery(ctx, "select state, provider, model from embedding_jobs where message_id = 'm1'")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{{"done", "ollama", "nomic-embed-text"}}, rows)
|
||||
}
|
||||
|
||||
func TestPendingEmbeddingJobsSkipsFreshLocks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}, WriteOptions{EnqueueEmbedding: true}))
|
||||
|
||||
now := time.Now().UTC()
|
||||
staleBefore := now.Add(-embeddingLockTimeout).Format(timeLayout)
|
||||
jobs, err := s.pendingEmbeddingJobs(ctx, 10, staleBefore)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, jobs, 1)
|
||||
|
||||
claimed, err := s.lockEmbeddingJobs(ctx, jobs, now.Format(timeLayout), staleBefore)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, claimed, 1)
|
||||
|
||||
jobs, err = s.pendingEmbeddingJobs(ctx, 10, staleBefore)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, jobs)
|
||||
|
||||
claimed, err = s.lockEmbeddingJobs(ctx, claimed, now.Add(time.Minute).Format(timeLayout), staleBefore)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, claimed)
|
||||
}
|
||||
|
||||
func TestRequeueAllEmbeddingJobsUsesCurrentIdentity(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
for _, id := range []string{"m1", "m2"} {
|
||||
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
|
||||
ID: id,
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}))
|
||||
}
|
||||
_, err = s.DB().ExecContext(ctx, `
|
||||
insert into embedding_jobs(message_id, state, attempts, provider, model, input_version, last_error, updated_at)
|
||||
values('m1', 'failed', 3, 'old', 'old-model', 'old-input', 'old error', ?)
|
||||
`, time.Now().UTC().Format(timeLayout))
|
||||
require.NoError(t, err)
|
||||
|
||||
requeued, err := s.RequeueAllEmbeddingJobs(ctx, EmbeddingDrainOptions{
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, requeued)
|
||||
|
||||
_, rows, err := s.ReadOnlyQuery(ctx, "select message_id, state, attempts, provider, model, input_version, last_error from embedding_jobs order by message_id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [][]string{
|
||||
{"m1", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""},
|
||||
{"m2", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""},
|
||||
}, rows)
|
||||
}
|
||||
|
||||
func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -160,6 +544,26 @@ func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) {
|
||||
require.Equal(t, "8", rows[0][0])
|
||||
}
|
||||
|
||||
type fakeEmbeddingProvider struct {
|
||||
batches []embed.EmbeddingBatch
|
||||
err error
|
||||
inputs [][]string
|
||||
}
|
||||
|
||||
func (f *fakeEmbeddingProvider) Embed(_ context.Context, inputs []string) (embed.EmbeddingBatch, error) {
|
||||
copied := append([]string(nil), inputs...)
|
||||
f.inputs = append(f.inputs, copied)
|
||||
if f.err != nil {
|
||||
return embed.EmbeddingBatch{}, f.err
|
||||
}
|
||||
if len(f.batches) == 0 {
|
||||
return embed.EmbeddingBatch{}, nil
|
||||
}
|
||||
batch := f.batches[0]
|
||||
f.batches = f.batches[1:]
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
func TestMessageFTSUsesSnowflakeRowID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -287,6 +287,30 @@ func (s *Store) UpsertMessages(ctx context.Context, messages []MessageMutation)
|
||||
|
||||
func upsertMessageTx(ctx context.Context, tx *sql.Tx, message MessageRecord, opts WriteOptions) error {
|
||||
now := time.Now().UTC().Format(timeLayout)
|
||||
var previousNormalized sql.NullString
|
||||
previousErr := sql.ErrNoRows
|
||||
jobExists := false
|
||||
if opts.EnqueueEmbedding {
|
||||
previousErr = tx.QueryRowContext(ctx, `
|
||||
select normalized_content
|
||||
from messages
|
||||
where id = ?
|
||||
`, message.ID).Scan(&previousNormalized)
|
||||
if previousErr != nil && previousErr != sql.ErrNoRows {
|
||||
return previousErr
|
||||
}
|
||||
if previousErr == nil {
|
||||
var existingJobs int
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
select count(*)
|
||||
from embedding_jobs
|
||||
where message_id = ?
|
||||
`, message.ID).Scan(&existingJobs); err != nil {
|
||||
return err
|
||||
}
|
||||
jobExists = existingJobs > 0
|
||||
}
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
insert into messages(
|
||||
id, guild_id, channel_id, author_id, message_type, created_at, edited_at, deleted_at,
|
||||
@ -323,11 +347,17 @@ func upsertMessageTx(ctx context.Context, tx *sql.Tx, message MessageRecord, opt
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.EnqueueEmbedding {
|
||||
queueEmbedding := opts.EnqueueEmbedding && (previousErr == sql.ErrNoRows || previousNormalized.String != message.NormalizedContent || !jobExists)
|
||||
if queueEmbedding {
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
insert into embedding_jobs(message_id, state, attempts, updated_at)
|
||||
values(?, 'pending', 0, ?)
|
||||
on conflict(message_id) do nothing
|
||||
on conflict(message_id) do update set
|
||||
state = 'pending',
|
||||
attempts = 0,
|
||||
last_error = '',
|
||||
locked_at = null,
|
||||
updated_at = excluded.updated_at
|
||||
`, message.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user