diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go
index 6449e38..050af47 100644
--- a/internal/cli/cli_test.go
+++ b/internal/cli/cli_test.go
@@ -555,6 +555,163 @@ func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
 	require.Contains(t, out.String(), "requeued=2")
 }
 
+// TestSearchSemanticCommandUsesStoredEmbeddings runs the search command in
+// semantic mode against a fake embeddings endpoint. It checks that results are
+// ranked by the stored vectors and that each run embeds the query exactly once.
+func TestSearchSemanticCommandUsesStoredEmbeddings(t *testing.T) {
+	ctx := context.Background()
+	dir := t.TempDir()
+	cfgPath := filepath.Join(dir, "config.toml")
+	dbPath := filepath.Join(dir, "discrawl.db")
+
+	var requests int
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// The handler runs on the server's goroutine, where require's FailNow
+		// must not be called; report mismatches with t.Errorf instead.
+		requests++
+		if r.URL.Path != "/embeddings" {
+			t.Errorf("unexpected request path %q", r.URL.Path)
+		}
+		var req struct {
+			Model string   `json:"model"`
+			Input []string `json:"input"`
+		}
+		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+			t.Errorf("decode embedding request: %v", err)
+			http.Error(w, "bad request", http.StatusBadRequest)
+			return
+		}
+		if req.Model != "local-model" {
+			t.Errorf("unexpected embedding model %q", req.Model)
+		}
+		if len(req.Input) != 1 || req.Input[0] != "cats" {
+			t.Errorf("unexpected embedding input %q", req.Input)
+		}
+		_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
+	}))
+	defer server.Close()
+
+	cfg := config.Default()
+	cfg.DBPath = dbPath
+	cfg.Search.DefaultMode = "semantic"
+	cfg.Search.Embeddings.Enabled = true
+	cfg.Search.Embeddings.Provider = "openai_compatible"
+	cfg.Search.Embeddings.Model = "local-model"
+	cfg.Search.Embeddings.BaseURL = server.URL
+	cfg.Search.Embeddings.APIKeyEnv = ""
+	require.NoError(t, config.Write(cfgPath, cfg))
+
+	s, err := store.Open(ctx, dbPath)
+	require.NoError(t, err)
+	base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
+	require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
+	require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
+	require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
+		ID:                "m1",
+		GuildID:           "g1",
+		ChannelID:         "c1",
+		ChannelName:       "general",
+		AuthorID:          "u1",
+		AuthorName:        "Alice",
+		MessageType:       0,
+		CreatedAt:         base.Format(time.RFC3339Nano),
+		Content:           "database migration discussion",
+		NormalizedContent: "database migration discussion",
+		RawJSON:           `{"author":{"username":"Alice"}}`,
+	}))
+	require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
+		ID:                "m2",
+		GuildID:           "g1",
+		ChannelID:         "c1",
+		ChannelName:       "general",
+		AuthorID:          "u2",
+		AuthorName:        "Bob",
+		MessageType:       0,
+		CreatedAt:         base.Add(time.Minute).Format(time.RFC3339Nano),
+		Content:           "cats in semantic search",
+		NormalizedContent: "cats in semantic search",
+		RawJSON:           `{"author":{"username":"Bob"}}`,
+	}))
+	require.NoError(t, insertCLIEmbedding(ctx, s, "m1", "openai_compatible", "local-model", []float32{1, 0}))
+	require.NoError(t, insertCLIEmbedding(ctx, s, "m2", "openai_compatible", "local-model", []float32{0.8, 0.2}))
+	require.NoError(t, s.Close())
+
+	var out bytes.Buffer
+	// No --mode flag: this run relies on DefaultMode = "semantic" from the config.
+	// m1's stored vector equals the query vector, so it outranks m2 even though
+	// only m2 mentions "cats".
+	require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--limit", "1", "cats"}, &out, &bytes.Buffer{}))
+	require.Contains(t, out.String(), "database migration discussion")
+	require.NotContains(t, out.String(), "cats in semantic search")
+	require.Equal(t, 1, requests)
+
+	out.Reset()
+	require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "--channel", "general", "--author", "Alice", "cats"}, &out, &bytes.Buffer{}))
+	require.Contains(t, out.String(), "database migration discussion")
+	require.NotContains(t, out.String(), "cats in semantic search")
+	require.Equal(t, 2, requests)
+}
+
+// TestSearchSemanticCommandErrors covers the failure paths of "search --mode":
+// an unknown mode (usage error, exit 2), the unimplemented hybrid mode,
+// disabled embeddings, and a failing embeddings endpoint (each exit 1).
+func TestSearchSemanticCommandErrors(t *testing.T) {
+	ctx := context.Background()
+	dir := t.TempDir()
+	cfgPath := filepath.Join(dir, "config.toml")
+	dbPath := filepath.Join(dir, "discrawl.db")
+
+	cfg := config.Default()
+	cfg.DBPath = dbPath
+	require.NoError(t, config.Write(cfgPath, cfg))
+	s, err := store.Open(ctx, dbPath)
+	require.NoError(t, err)
+	require.NoError(t, s.Close())
+
+	err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "bogus", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
+	require.Equal(t, 2, ExitCode(err))
+	require.ErrorContains(t, err, `unsupported search mode "bogus"`)
+
+	err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
+	require.Equal(t, 1, ExitCode(err))
+	require.ErrorContains(t, err, "hybrid search is not implemented yet")
+
+	err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
+	require.Equal(t, 1, ExitCode(err))
+	require.ErrorContains(t, err, "embeddings are disabled")
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "nope", http.StatusInternalServerError)
+	}))
+	defer server.Close()
+	cfg.Search.Embeddings.Enabled = true
+	cfg.Search.Embeddings.Provider = "openai_compatible"
+	cfg.Search.Embeddings.Model = "local-model"
+	cfg.Search.Embeddings.BaseURL = server.URL
+	cfg.Search.Embeddings.APIKeyEnv = ""
+	require.NoError(t, config.Write(cfgPath, cfg))
+
+	err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
+	require.Equal(t, 1, ExitCode(err))
+	require.ErrorContains(t, err, "embedding query failed")
+}
+
+// insertCLIEmbedding writes one message_embeddings row for messageID directly
+// through the store's database handle, encoding vector with
+// store.EncodeEmbeddingVector and stamping it with the current UTC time.
+func insertCLIEmbedding(ctx context.Context, s *store.Store, messageID, provider, model string, vector []float32) error {
+	blob, err := store.EncodeEmbeddingVector(vector)
+	if err != nil {
+		return err
+	}
+	_, err = s.DB().ExecContext(ctx, `
+		insert into message_embeddings(
+			message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
+		) values(?, ?, ?, ?, ?, ?, ?)
+	`, messageID, provider, model, store.EmbeddingInputVersion, len(vector), blob, time.Now().UTC().Format(time.RFC3339Nano))
+	return err
+}
+
 type fakeDiscordClient struct {
 	guilds []*discordgo.UserGuild
 	self   *discordgo.User
diff --git a/internal/cli/query_commands.go b/internal/cli/query_commands.go
index 4f0d637..6e3b6e8 100644
--- a/internal/cli/query_commands.go
+++ b/internal/cli/query_commands.go
@@ -8,6 +8,8 @@ import (
 	"os"
 	"strings"
 
+	"github.com/steipete/discrawl/internal/config"
+	"github.com/steipete/discrawl/internal/embed"
 	"github.com/steipete/discrawl/internal/store"
 )
 
@@ -27,19 +29,80 @@ func (r *runtime) runSearch(args []string) error {
 	if fs.NArg() != 1 {
 		return usageErr(fmt.Errorf("search requires a query"))
 	}
-	_ = mode
-	results, err := r.store.SearchMessages(r.ctx, store.SearchOptions{
+	opts := store.SearchOptions{
 		Query:        fs.Arg(0),
 		GuildIDs:     r.resolveSearchGuilds(*guildFlag, *guildsFlag),
 		Channel:      *channel,
 		Author:       *author,
 		Limit:        *limit,
 		IncludeEmpty: *includeEmpty,
-	})
-	if err != nil {
-		return err
 	}
-	return r.print(results)
+	switch strings.ToLower(strings.TrimSpace(*mode)) {
+	// An empty mode selects full-text search.
+	case "", "fts":
+		results, err := r.store.SearchMessages(r.ctx, opts)
+		if err != nil {
+			return err
+		}
+		return r.print(results)
+	case "semantic":
+		results, err := r.searchMessagesSemantic(opts)
+		if err != nil {
+			return err
+		}
+		return r.print(results)
+	case "hybrid":
+		// Recognized but unavailable, so a plain error rather than a usage error.
+		return fmt.Errorf("hybrid search is not implemented yet")
+	default:
+		return usageErr(fmt.Errorf("unsupported search mode %q", *mode))
+	}
+}
+
+// searchMessagesSemantic embeds opts.Query with the configured embedding
+// provider and ranks stored message embeddings against the resulting vector.
+// It fails when embeddings are disabled in the config or when the provider
+// does not return exactly one vector for the single query input.
+func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.SearchResult, error) {
+	if !r.cfg.Search.Embeddings.Enabled {
+		return nil, fmt.Errorf("embeddings are disabled; enable [search.embeddings] first")
+	}
+	// r.newEmbed, when set, replaces the default provider constructor.
+	providerFactory := r.newEmbed
+	if providerFactory == nil {
+		providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
+			return embed.NewProvider(cfg)
+		}
+	}
+	provider, err := providerFactory(r.cfg.Search.Embeddings)
+	if err != nil {
+		return nil, fmt.Errorf("create embedding provider: %w", err)
+	}
+	batch, err := provider.Embed(r.ctx, []string{opts.Query})
+	if err != nil {
+		return nil, fmt.Errorf("embedding query failed: %w", err)
+	}
+	if len(batch.Vectors) != 1 {
+		return nil, fmt.Errorf("embedding query returned %d vectors for 1 input", len(batch.Vectors))
+	}
+	queryVector := batch.Vectors[0]
+	// Fall back to the vector length when the provider reports no dimensions.
+	dimensions := batch.Dimensions
+	if dimensions == 0 {
+		dimensions = len(queryVector)
+	}
+	return r.store.SearchMessagesSemantic(r.ctx, store.SemanticSearchOptions{
+		QueryVector:  queryVector,
+		Provider:     r.cfg.Search.Embeddings.Provider,
+		Model:        r.cfg.Search.Embeddings.Model,
+		InputVersion: store.EmbeddingInputVersion,
+		Dimensions:   dimensions,
+		GuildIDs:     opts.GuildIDs,
+		Channel:      opts.Channel,
+		Author:       opts.Author,
+		Limit:        opts.Limit,
+		IncludeEmpty: opts.IncludeEmpty,
+	})
 }
 
 func (r *runtime) runSQL(args []string) error {
diff --git a/internal/store/query.go b/internal/store/query.go
index cdbf65a..0881f3e 100644
--- a/internal/store/query.go
+++ b/internal/store/query.go
@@ -5,7 +5,9 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
+	"math"
 	"os"
+	"sort"
 	"strings"
 	"time"
 )
@@ -19,6 +21,27 @@ const (
 	messageFTSHealthProbe = "__discrawl_probe__"
 )
 
+// ErrNoCompatibleEmbeddings is returned by SearchMessagesSemantic when no stored
+// embedding matches the requested provider, model, input version, and dimensions.
+var ErrNoCompatibleEmbeddings = errors.New("no compatible message embeddings for provider/model/input version; run discrawl embed --rebuild")
+
+// SemanticSearchOptions configures SearchMessagesSemantic. QueryVector is the
+// embedded query; Provider, Model, InputVersion, and Dimensions select which
+// stored embeddings are comparable with it; the remaining fields filter and
+// bound the results.
+type SemanticSearchOptions struct {
+	QueryVector  []float32
+	Provider     string
+	Model        string
+	InputVersion string
+	Dimensions   int
+	GuildIDs     []string
+	Channel      string
+	Author       string
+	Limit        int
+	IncludeEmpty bool
+}
+
 func (s *Store) GetSyncState(ctx context.Context, scope string) (string, error) {
 	var cursor sql.NullString
 	err := s.db.QueryRowContext(ctx, `select cursor from sync_state where scope = ?`, scope).Scan(&cursor)
@@ -120,6 +143,191 @@ func (s *Store) SearchMessages(ctx context.Context, opts SearchOptions) ([]Searc
 	return out, rows.Err()
 }
 
+// SearchMessagesSemantic ranks stored message embeddings by cosine similarity
+// to opts.QueryVector and returns at most opts.Limit results (default 20),
+// best match first, with ties going to the newer message. Only embeddings that
+// match the provider, model, input version, and dimensions are considered, and
+// only the searchCandidateLimit(opts.Limit) most recent filtered rows are
+// scored. When nothing is scored it returns ErrNoCompatibleEmbeddings if no
+// compatible embedding exists at all, or an empty slice if the filters simply
+// excluded every candidate.
+func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchOptions) ([]SearchResult, error) {
+	opts.Provider = strings.ToLower(strings.TrimSpace(opts.Provider))
+	opts.Model = strings.TrimSpace(opts.Model)
+	opts.InputVersion = strings.TrimSpace(opts.InputVersion)
+	if opts.InputVersion == "" {
+		opts.InputVersion = EmbeddingInputVersion
+	}
+	if opts.Limit <= 0 {
+		opts.Limit = 20
+	}
+	if len(opts.QueryVector) == 0 {
+		return nil, errors.New("semantic query embedding returned an empty vector")
+	}
+	if opts.Dimensions <= 0 {
+		opts.Dimensions = len(opts.QueryVector)
+	}
+	if len(opts.QueryVector) != opts.Dimensions {
+		return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions)
+	}
+	queryNorm := vectorNorm(opts.QueryVector)
+	if queryNorm == 0 {
+		return nil, errors.New("semantic query embedding returned a zero vector")
+	}
+	// NaN or Inf components would produce NaN scores and an inconsistent ranking.
+	if math.IsNaN(queryNorm) || math.IsInf(queryNorm, 0) {
+		return nil, errors.New("semantic query embedding returned a non-finite vector")
+	}
+
+	clauses := []string{
+		"e.provider = ?",
+		"e.model = ?",
+		"e.input_version = ?",
+		"e.dimensions = ?",
+	}
+	args := []any{opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions}
+	if len(opts.GuildIDs) > 0 {
+		clauses = append(clauses, "m.guild_id in ("+placeholders(len(opts.GuildIDs))+")")
+		for _, guildID := range opts.GuildIDs {
+			args = append(args, guildID)
+		}
+	}
+	if strings.TrimSpace(opts.Channel) != "" {
+		clauses = append(clauses, "(m.channel_id = ? or c.name like ?)")
+		args = append(args, opts.Channel, "%"+opts.Channel+"%")
+	}
+	authorExpr := `coalesce(
+		json_extract(m.raw_json, '$.member.nick'),
+		json_extract(m.raw_json, '$.author.global_name'),
+		json_extract(m.raw_json, '$.author.username'),
+		''
+	)`
+	if strings.TrimSpace(opts.Author) != "" {
+		clauses = append(clauses, "(m.author_id = ? or "+authorExpr+" like ?)")
+		args = append(args, opts.Author, "%"+opts.Author+"%")
+	}
+	if !opts.IncludeEmpty {
+		clauses = append(clauses, "trim(coalesce(m.normalized_content, '')) <> ''")
+	}
+	// The final placeholder caps how many candidate rows are fetched and scored.
+	args = append(args, searchCandidateLimit(opts.Limit))
+
+	queryCtx, cancel := withQueryTimeout(ctx)
+	defer cancel()
+	rows, err := s.db.QueryContext(queryCtx, `
+		select
+			m.id,
+			m.guild_id,
+			m.channel_id,
+			coalesce(c.name, ''),
+			coalesce(m.author_id, ''),
+			`+authorExpr+`,
+			case
+				when trim(coalesce(m.content, '')) <> '' then m.content
+				else m.normalized_content
+			end,
+			m.created_at,
+			e.dimensions,
+			e.embedding_blob
+		from message_embeddings e
+		join messages m on m.id = e.message_id
+		left join channels c on c.id = m.channel_id
+		where `+strings.Join(clauses, " and ")+`
+		order by m.created_at desc, m.id desc
+		limit ?
+	`, args...)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	type scoredResult struct {
+		result SearchResult
+		score  float64
+	}
+	var scored []scoredResult
+	for rows.Next() {
+		var (
+			row        SearchResult
+			created    string
+			dimensions int
+			blob       []byte
+		)
+		if err := rows.Scan(&row.MessageID, &row.GuildID, &row.ChannelID, &row.ChannelName, &row.AuthorID, &row.AuthorName, &row.Content, &created, &dimensions, &blob); err != nil {
+			return nil, err
+		}
+		if dimensions != opts.Dimensions {
+			return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions)
+		}
+		vector, err := DecodeEmbeddingVector(blob)
+		if err != nil {
+			return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err)
+		}
+		if len(vector) != dimensions {
+			return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(vector), dimensions)
+		}
+		score, err := cosineSimilarity(opts.QueryVector, queryNorm, vector)
+		if err != nil {
+			return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err)
+		}
+		row.CreatedAt = parseTime(created)
+		scored = append(scored, scoredResult{result: row, score: score})
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	// Tell "nothing embedded for this configuration" apart from "the filters
+	// excluded everything".
+	if len(scored) == 0 {
+		compatible, err := s.hasCompatibleMessageEmbeddings(ctx, opts)
+		if err != nil {
+			return nil, err
+		}
+		if !compatible {
+			return nil, ErrNoCompatibleEmbeddings
+		}
+		return []SearchResult{}, nil
+	}
+	// Highest score first; ties go to the newer message, then the larger ID.
+	sort.SliceStable(scored, func(i, j int) bool {
+		if scored[i].score != scored[j].score {
+			return scored[i].score > scored[j].score
+		}
+		if !scored[i].result.CreatedAt.Equal(scored[j].result.CreatedAt) {
+			return scored[i].result.CreatedAt.After(scored[j].result.CreatedAt)
+		}
+		return scored[i].result.MessageID > scored[j].result.MessageID
+	})
+	if len(scored) > opts.Limit {
+		scored = scored[:opts.Limit]
+	}
+	out := make([]SearchResult, 0, len(scored))
+	for _, item := range scored {
+		out = append(out, item.result)
+	}
+	return out, nil
+}
+
+// hasCompatibleMessageEmbeddings reports whether any stored embedding matches
+// the provider, model, input version, and dimensions in opts, ignoring the
+// guild, channel, author, and content filters.
+func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts SemanticSearchOptions) (bool, error) {
+	queryCtx, cancel := withQueryTimeout(ctx)
+	defer cancel()
+	var exists int
+	err := s.db.QueryRowContext(queryCtx, `
+		select exists(
+			select 1
+			from message_embeddings
+			where provider = ?
+			  and model = ?
+			  and input_version = ?
+			  and dimensions = ?
+		)
+	`, opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions).Scan(&exists)
+	return exists == 1, err
+}
+
 func (s *Store) CheckMessageFTS(ctx context.Context) error {
 	db, cleanup, err := s.openReadOnlyDB()
 	if err != nil {
@@ -200,6 +408,42 @@ func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]Searc
 	return out, rows.Err()
 }
 
+// cosineSimilarity returns the cosine of the angle between query and vector.
+// queryNorm is the precomputed Euclidean norm of query, so callers scoring
+// many vectors against one query compute it only once. It returns an error
+// when the lengths differ, when vector has zero norm, or when the score is
+// not a finite number (NaN or Inf components in either input).
+func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) {
+	if len(vector) != len(query) {
+		return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(vector), len(query))
+	}
+	// Named norm so the local does not shadow the vectorNorm helper.
+	norm := vectorNorm(vector)
+	if norm == 0 {
+		return 0, errors.New("stored embedding vector is zero")
+	}
+	var dot float64
+	for i := range query {
+		dot += float64(query[i]) * float64(vector[i])
+	}
+	score := dot / (queryNorm * norm)
+	// A NaN score compares false against everything and would make the
+	// ranking comparator inconsistent, so reject it here.
+	if math.IsNaN(score) || math.IsInf(score, 0) {
+		return 0, errors.New("embedding similarity is not finite")
+	}
+	return score, nil
+}
+
+// vectorNorm returns the Euclidean (L2) norm of vector, accumulated in float64.
+func vectorNorm(vector []float32) float64 {
+	var sum float64
+	for _, value := range vector {
+		sum += float64(value) * float64(value)
+	}
+	return math.Sqrt(sum)
+}
+
 func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
 	if strings.TrimSpace(query) != "" {
 		return s.searchMembers(ctx, guildID, query, limit)
diff --git a/internal/store/store_test.go b/internal/store/store_test.go
index b38103a..bdf1216 100644
--- a/internal/store/store_test.go
+++ b/internal/store/store_test.go
@@ -185,6 +185,252 @@ func TestSearchMessagesPrefersRecentMessageIDs(t *testing.T) {
 	require.Contains(t, results[0].Content, "newest")
 }
 
+// TestSearchMessagesSemanticRanksAndFilters checks that semantic search orders
+// results by similarity to the query vector and honors the guild, channel,
+// author, and IncludeEmpty filters.
+func TestSearchMessagesSemanticRanksAndFilters(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
+	require.NoError(t, err)
+	defer func() { _ = s.Close() }()
+
+	base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
+	require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
+	require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g2", Name: "Other", RawJSON: `{}`}))
+	require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
+	require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c2", GuildID: "g1", Kind: "text", Name: "random", RawJSON: `{}`}))
+	require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c3", GuildID: "g2", Kind: "text", Name: "other", RawJSON: `{}`}))
+
+	semanticMessages := []MessageRecord{
+		{
+			ID:                "m1",
+			GuildID:           "g1",
+			ChannelID:         "c1",
+			ChannelName:       "general",
+			AuthorID:          "u1",
+			MessageType:       0,
+			CreatedAt:         base.Format(time.RFC3339Nano),
+			Content:           "cats and databases",
+			NormalizedContent: "cats and databases",
+			RawJSON:           `{"author":{"username":"Alice"}}`,
+		},
+		{
+			ID:                "m2",
+			GuildID:           "g1",
+			ChannelID:         "c2",
+			ChannelName:       "random",
+			AuthorID:          "u2",
+			MessageType:       0,
+			CreatedAt:         base.Add(time.Minute).Format(time.RFC3339Nano),
+			Content:           "cats but weaker",
+			NormalizedContent: "cats but weaker",
+			RawJSON:           `{"author":{"username":"Bob"}}`,
+		},
+		{
+			ID:                "m3",
+			GuildID:           "g1",
+			ChannelID:         "c1",
+			ChannelName:       "general",
+			AuthorID:          "u1",
+			MessageType:       0,
+			CreatedAt:         base.Add(2 * time.Minute).Format(time.RFC3339Nano),
+			Content:           "dogs",
+			NormalizedContent: "dogs",
+			RawJSON:           `{"author":{"username":"Alice"}}`,
+		},
+		{
+			ID:                "m4",
+			GuildID:           "g2",
+			ChannelID:         "c3",
+			ChannelName:       "other",
+			AuthorID:          "u3",
+			MessageType:       0,
+			CreatedAt:         base.Add(3 * time.Minute).Format(time.RFC3339Nano),
+			Content:           "other guild cats",
+			NormalizedContent: "other guild cats",
+			RawJSON:           `{"author":{"username":"Carol"}}`,
+		},
+		{
+			ID:                "m5",
+			GuildID:           "g1",
+			ChannelID:         "c1",
+			ChannelName:       "general",
+			AuthorID:          "u4",
+			MessageType:       0,
+			CreatedAt:         base.Add(4 * time.Minute).Format(time.RFC3339Nano),
+			Content:           "",
+			NormalizedContent: "",
+			RawJSON:           `{"author":{"username":"Empty"}}`,
+		},
+	}
+	for _, message := range semanticMessages {
+		require.NoError(t, s.UpsertMessage(ctx, message))
+	}
+	require.NoError(t, insertTestEmbedding(ctx, s, "m1", "ollama", "nomic-embed-text", []float32{1, 0}))
+	require.NoError(t, insertTestEmbedding(ctx, s, "m2", "ollama", "nomic-embed-text", []float32{0.9, 0.1}))
+	require.NoError(t, insertTestEmbedding(ctx, s, "m3", "ollama", "nomic-embed-text", []float32{0, 1}))
+	require.NoError(t, insertTestEmbedding(ctx, s, "m4", "ollama", "nomic-embed-text", []float32{1, 0}))
+	require.NoError(t, insertTestEmbedding(ctx, s, "m5", "ollama", "nomic-embed-text", []float32{1, 0}))
+
+	results, err := s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "nomic-embed-text",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		GuildIDs:     []string{"g1"},
+		Limit:        3,
+	})
+	require.NoError(t, err)
+	require.Equal(t, []string{"m1", "m2", "m3"}, searchResultIDs(results))
+
+	results, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "nomic-embed-text",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		GuildIDs:     []string{"g1"},
+		Channel:      "general",
+		Author:       "Alice",
+		Limit:        10,
+	})
+	require.NoError(t, err)
+	require.Equal(t, []string{"m1", "m3"}, searchResultIDs(results))
+	require.Equal(t, "Alice", results[0].AuthorName)
+	require.Equal(t, "general", results[0].ChannelName)
+
+	results, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "nomic-embed-text",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		GuildIDs:     []string{"g1"},
+		Channel:      "general",
+		Limit:        10,
+		IncludeEmpty: true,
+	})
+	require.NoError(t, err)
+	// m5 (empty content) ties with m1 on score and ranks first because it is newer.
+	require.Equal(t, []string{"m5", "m1", "m3"}, searchResultIDs(results))
+
+	results, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "nomic-embed-text",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		GuildIDs:     []string{"g1"},
+		Channel:      "missing-channel",
+		Limit:        10,
+	})
+	// Compatible embeddings exist, so a filtered-out search is empty, not an error.
+	require.NoError(t, err)
+	require.Empty(t, results)
+}
+
+// TestSearchMessagesSemanticErrors covers the failure modes: no compatible
+// embeddings, a zero query vector, a stored blob whose decoded length disagrees
+// with its dimensions column, and a zero stored vector.
+func TestSearchMessagesSemanticErrors(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
+	require.NoError(t, err)
+	defer func() { _ = s.Close() }()
+
+	require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
+		ID:                "m1",
+		GuildID:           "g1",
+		ChannelID:         "c1",
+		MessageType:       0,
+		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
+		Content:           "hello",
+		NormalizedContent: "hello",
+		RawJSON:           `{}`,
+	}))
+
+	_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "missing-model",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		Limit:        10,
+	})
+	require.ErrorIs(t, err, ErrNoCompatibleEmbeddings)
+
+	_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector: []float32{0, 0},
+		Provider:    "ollama",
+		Model:       "nomic-embed-text",
+		Dimensions:  2,
+		Limit:       10,
+	})
+	require.ErrorContains(t, err, "zero vector")
+
+	require.NoError(t, insertTestEmbeddingBlob(ctx, s, "m1", "ollama", "nomic-embed-text", 2, []byte{0, 0, 0, 0}))
+	_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "nomic-embed-text",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		Limit:        10,
+	})
+	require.ErrorContains(t, err, "vector length mismatch")
+
+	require.NoError(t, insertTestEmbedding(ctx, s, "m1", "ollama", "nomic-embed-text", []float32{0, 0}))
+	_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
+		QueryVector:  []float32{1, 0},
+		Provider:     "ollama",
+		Model:        "nomic-embed-text",
+		InputVersion: EmbeddingInputVersion,
+		Dimensions:   2,
+		Limit:        10,
+	})
+	require.ErrorContains(t, err, "stored embedding vector is zero")
+}
+
+// insertTestEmbedding encodes vector with EncodeEmbeddingVector and stores it
+// for messageID, recording len(vector) as the dimensions.
+func insertTestEmbedding(ctx context.Context, s *Store, messageID, provider, model string, vector []float32) error {
+	blob, err := EncodeEmbeddingVector(vector)
+	if err != nil {
+		return err
+	}
+	return insertTestEmbeddingBlob(ctx, s, messageID, provider, model, len(vector), blob)
+}
+
+// insertTestEmbeddingBlob upserts a raw embedding row, letting a test store a
+// blob whose contents disagree with the dimensions column. An existing row for
+// the same message, provider, model, and input version is overwritten.
+func insertTestEmbeddingBlob(ctx context.Context, s *Store, messageID, provider, model string, dimensions int, blob []byte) error {
+	_, err := s.DB().ExecContext(ctx, `
+		insert into message_embeddings(
+			message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
+		) values(?, ?, ?, ?, ?, ?, ?)
+		on conflict(message_id, provider, model, input_version) do update set
+			dimensions = excluded.dimensions,
+			embedding_blob = excluded.embedding_blob,
+			embedded_at = excluded.embedded_at
+	`, messageID, provider, model, EmbeddingInputVersion, dimensions, blob, time.Now().UTC().Format(timeLayout))
+	return err
+}
+
+// searchResultIDs returns the message IDs of results in their original order.
+func searchResultIDs(results []SearchResult) []string {
+	ids := make([]string, 0, len(results))
+	for _, result := range results {
+		ids = append(ids, result.MessageID)
+	}
+	return ids
+}
+
 func TestCheckMessageFTSProbe(t *testing.T) {
 	t.Parallel()
 