feat(search): add semantic message search

This commit is contained in:
MrBrain 2026-04-22 14:23:35 +08:00 committed by Peter Steinberger
parent ed929a92eb
commit 3fc8defc35
4 changed files with 632 additions and 6 deletions

View File

@ -555,6 +555,139 @@ func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
require.Contains(t, out.String(), "requeued=2")
}
func TestSearchSemanticCommandUsesStoredEmbeddings(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.toml")
dbPath := filepath.Join(dir, "discrawl.db")
var requests int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests++
require.Equal(t, "/embeddings", r.URL.Path)
var req struct {
Model string `json:"model"`
Input []string `json:"input"`
}
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
require.Equal(t, "local-model", req.Model)
require.Equal(t, []string{"cats"}, req.Input)
_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
}))
defer server.Close()
cfg := config.Default()
cfg.DBPath = dbPath
cfg.Search.DefaultMode = "semantic"
cfg.Search.Embeddings.Enabled = true
cfg.Search.Embeddings.Provider = "openai_compatible"
cfg.Search.Embeddings.Model = "local-model"
cfg.Search.Embeddings.BaseURL = server.URL
cfg.Search.Embeddings.APIKeyEnv = ""
require.NoError(t, config.Write(cfgPath, cfg))
s, err := store.Open(ctx, dbPath)
require.NoError(t, err)
base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
ID: "m1",
GuildID: "g1",
ChannelID: "c1",
ChannelName: "general",
AuthorID: "u1",
AuthorName: "Alice",
MessageType: 0,
CreatedAt: base.Format(time.RFC3339Nano),
Content: "database migration discussion",
NormalizedContent: "database migration discussion",
RawJSON: `{"author":{"username":"Alice"}}`,
}))
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
ID: "m2",
GuildID: "g1",
ChannelID: "c1",
ChannelName: "general",
AuthorID: "u2",
AuthorName: "Bob",
MessageType: 0,
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
Content: "cats in semantic search",
NormalizedContent: "cats in semantic search",
RawJSON: `{"author":{"username":"Bob"}}`,
}))
require.NoError(t, insertCLIEmbedding(ctx, s, "m1", "openai_compatible", "local-model", []float32{1, 0}))
require.NoError(t, insertCLIEmbedding(ctx, s, "m2", "openai_compatible", "local-model", []float32{0.8, 0.2}))
require.NoError(t, s.Close())
var out bytes.Buffer
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--limit", "1", "cats"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "database migration discussion")
require.NotContains(t, out.String(), "cats in semantic search")
require.Equal(t, 1, requests)
out.Reset()
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "--channel", "general", "--author", "Alice", "cats"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "database migration discussion")
require.NotContains(t, out.String(), "cats in semantic search")
require.Equal(t, 2, requests)
}
// TestSearchSemanticCommandErrors checks the exit code and error message the
// CLI "search" command produces for an unknown mode, the unimplemented hybrid
// mode, semantic mode with embeddings disabled, and a failing embeddings
// endpoint.
func TestSearchSemanticCommandErrors(t *testing.T) {
	ctx := context.Background()
	dir := t.TempDir()
	cfgPath := filepath.Join(dir, "config.toml")
	dbPath := filepath.Join(dir, "discrawl.db")

	cfg := config.Default()
	cfg.DBPath = dbPath
	require.NoError(t, config.Write(cfgPath, cfg))

	// Open and close the store once so the database file exists before the
	// command runs.
	s, err := store.Open(ctx, dbPath)
	require.NoError(t, err)
	require.NoError(t, s.Close())

	// searchWithMode runs `search --mode <mode> cats`, discarding all output.
	searchWithMode := func(mode string) error {
		return Run(ctx, []string{"--config", cfgPath, "search", "--mode", mode, "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
	}

	// Embeddings are still disabled in the config for these cases.
	for _, tc := range []struct {
		mode     string
		exitCode int
		message  string
	}{
		{mode: "bogus", exitCode: 2, message: `unsupported search mode "bogus"`},
		{mode: "hybrid", exitCode: 1, message: "hybrid search is not implemented yet"},
		{mode: "semantic", exitCode: 1, message: "embeddings are disabled"},
	} {
		runErr := searchWithMode(tc.mode)
		require.Equal(t, tc.exitCode, ExitCode(runErr))
		require.ErrorContains(t, runErr, tc.message)
	}

	// Enable embeddings against an endpoint that always answers 500.
	failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "nope", http.StatusInternalServerError)
	}))
	defer failing.Close()

	cfg.Search.Embeddings.Enabled = true
	cfg.Search.Embeddings.Provider = "openai_compatible"
	cfg.Search.Embeddings.Model = "local-model"
	cfg.Search.Embeddings.BaseURL = failing.URL
	cfg.Search.Embeddings.APIKeyEnv = ""
	require.NoError(t, config.Write(cfgPath, cfg))

	err = searchWithMode("semantic")
	require.Equal(t, 1, ExitCode(err))
	require.ErrorContains(t, err, "embedding query failed")
}
// insertCLIEmbedding stores an embedding row for messageID directly through
// the store's database handle. The vector is encoded with
// store.EncodeEmbeddingVector, the row is tagged with
// store.EmbeddingInputVersion, and embedded_at is set to the current UTC time.
func insertCLIEmbedding(ctx context.Context, s *store.Store, messageID, provider, model string, vector []float32) error {
	const insertSQL = `
insert into message_embeddings(
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
) values(?, ?, ?, ?, ?, ?, ?)
`
	encoded, err := store.EncodeEmbeddingVector(vector)
	if err != nil {
		return err
	}
	embeddedAt := time.Now().UTC().Format(time.RFC3339Nano)
	if _, err := s.DB().ExecContext(ctx, insertSQL,
		messageID, provider, model, store.EmbeddingInputVersion, len(vector), encoded, embeddedAt,
	); err != nil {
		return err
	}
	return nil
}
type fakeDiscordClient struct {
guilds []*discordgo.UserGuild
self *discordgo.User

View File

@ -8,6 +8,8 @@ import (
"os"
"strings"
"github.com/steipete/discrawl/internal/config"
"github.com/steipete/discrawl/internal/embed"
"github.com/steipete/discrawl/internal/store"
)
@ -27,19 +29,72 @@ func (r *runtime) runSearch(args []string) error {
if fs.NArg() != 1 {
return usageErr(fmt.Errorf("search requires a query"))
}
_ = mode
results, err := r.store.SearchMessages(r.ctx, store.SearchOptions{
opts := store.SearchOptions{
Query: fs.Arg(0),
GuildIDs: r.resolveSearchGuilds(*guildFlag, *guildsFlag),
Channel: *channel,
Author: *author,
Limit: *limit,
IncludeEmpty: *includeEmpty,
})
if err != nil {
return err
}
return r.print(results)
switch strings.ToLower(strings.TrimSpace(*mode)) {
case "", "fts":
results, err := r.store.SearchMessages(r.ctx, opts)
if err != nil {
return err
}
return r.print(results)
case "semantic":
results, err := r.searchMessagesSemantic(opts)
if err != nil {
return err
}
return r.print(results)
case "hybrid":
return fmt.Errorf("hybrid search is not implemented yet")
default:
return usageErr(fmt.Errorf("unsupported search mode %q", *mode))
}
}
// searchMessagesSemantic embeds opts.Query with the configured embedding
// provider and asks the store to rank stored message embeddings against the
// resulting vector. The guild, channel, author, limit and include-empty
// filters from opts are forwarded unchanged.
//
// It returns an error when embeddings are disabled in the config, when the
// provider cannot be created, when the embedding call fails, or when the
// provider does not return exactly one vector for the single query input.
func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.SearchResult, error) {
	embeddings := r.cfg.Search.Embeddings
	if !embeddings.Enabled {
		return nil, fmt.Errorf("embeddings are disabled; enable [search.embeddings] first")
	}

	// Use the injected factory when one is set on the runtime; otherwise
	// build the provider from the config.
	var (
		provider embed.Provider
		err      error
	)
	if r.newEmbed != nil {
		provider, err = r.newEmbed(embeddings)
	} else {
		provider, err = embed.NewProvider(embeddings)
	}
	if err != nil {
		return nil, fmt.Errorf("create embedding provider: %w", err)
	}

	embedded, err := provider.Embed(r.ctx, []string{opts.Query})
	if err != nil {
		return nil, fmt.Errorf("embedding query failed: %w", err)
	}
	if got := len(embedded.Vectors); got != 1 {
		return nil, fmt.Errorf("embedding query returned %d vectors for 1 input", got)
	}

	vector := embedded.Vectors[0]
	semantic := store.SemanticSearchOptions{
		QueryVector:  vector,
		Provider:     embeddings.Provider,
		Model:        embeddings.Model,
		InputVersion: store.EmbeddingInputVersion,
		Dimensions:   embedded.Dimensions,
		GuildIDs:     opts.GuildIDs,
		Channel:      opts.Channel,
		Author:       opts.Author,
		Limit:        opts.Limit,
		IncludeEmpty: opts.IncludeEmpty,
	}
	// Fall back to the vector's own length when the provider did not report
	// a dimension count.
	if semantic.Dimensions == 0 {
		semantic.Dimensions = len(vector)
	}
	return r.store.SearchMessagesSemantic(r.ctx, semantic)
}
func (r *runtime) runSQL(args []string) error {

View File

@ -5,7 +5,9 @@ import (
"database/sql"
"errors"
"fmt"
"math"
"os"
"sort"
"strings"
"time"
)
@ -19,6 +21,21 @@ const (
messageFTSHealthProbe = "__discrawl_probe__"
)
// ErrNoCompatibleEmbeddings is returned by SearchMessagesSemantic when a
// search scores no rows and no stored embedding matches the requested
// provider, model, input version and dimensions at all.
var ErrNoCompatibleEmbeddings = errors.New("no compatible message embeddings for provider/model/input version; run discrawl embed --rebuild")
// SemanticSearchOptions describes one semantic search for
// SearchMessagesSemantic: the query embedding, the identity of the stored
// embeddings it may be compared against, and the message filters.
type SemanticSearchOptions struct {
// QueryVector is the embedding of the search query. It must be non-empty
// and must not have a zero norm.
QueryVector []float32
// Provider selects stored embeddings by provider; it is trimmed and
// lowercased before matching.
Provider string
// Model selects stored embeddings by model; it is trimmed before matching.
Model string
// InputVersion selects stored embeddings by input version; an empty value
// defaults to EmbeddingInputVersion.
InputVersion string
// Dimensions is the expected vector length; a value <= 0 defaults to
// len(QueryVector), and any other value must equal len(QueryVector).
Dimensions int
// GuildIDs restricts results to these guilds when non-empty.
GuildIDs []string
// Channel, when non-blank, matches a channel ID exactly or a channel name
// by substring (SQL LIKE).
Channel string
// Author, when non-blank, matches an author ID exactly or the author
// display name by substring (SQL LIKE).
Author string
// Limit caps the number of returned results; a value <= 0 defaults to 20.
Limit int
// IncludeEmpty keeps messages whose normalized content is blank.
IncludeEmpty bool
}
func (s *Store) GetSyncState(ctx context.Context, scope string) (string, error) {
var cursor sql.NullString
err := s.db.QueryRowContext(ctx, `select cursor from sync_state where scope = ?`, scope).Scan(&cursor)
@ -120,6 +137,172 @@ func (s *Store) SearchMessages(ctx context.Context, opts SearchOptions) ([]Searc
return out, rows.Err()
}
// SearchMessagesSemantic ranks stored message embeddings by cosine similarity
// to opts.QueryVector and returns up to opts.Limit matching messages, best
// match first. Ties are broken by newer CreatedAt, then by larger MessageID.
//
// Only embeddings whose provider, model, input version and dimensions match
// opts are considered. Candidate rows are read newest-first and capped at
// searchCandidateLimit(opts.Limit) before scoring, so ranking covers only
// that window of matching messages.
//
// It returns an error for an empty or zero-norm query vector, for a
// query/dimension mismatch, and for any stored embedding that fails to decode,
// has an unexpected length, or cannot be scored. When nothing was scored it
// returns ErrNoCompatibleEmbeddings if no compatible embedding exists at all,
// and an empty non-nil slice otherwise.
func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchOptions) ([]SearchResult, error) {
// Normalize the embedding identity and apply defaults.
opts.Provider = strings.ToLower(strings.TrimSpace(opts.Provider))
opts.Model = strings.TrimSpace(opts.Model)
opts.InputVersion = strings.TrimSpace(opts.InputVersion)
if opts.InputVersion == "" {
opts.InputVersion = EmbeddingInputVersion
}
if opts.Limit <= 0 {
opts.Limit = 20
}
// Validate the query vector before touching the database.
if len(opts.QueryVector) == 0 {
return nil, errors.New("semantic query embedding returned an empty vector")
}
if opts.Dimensions <= 0 {
opts.Dimensions = len(opts.QueryVector)
}
if len(opts.QueryVector) != opts.Dimensions {
return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions)
}
// The query norm is computed once and reused for every candidate.
queryNorm := vectorNorm(opts.QueryVector)
if queryNorm == 0 {
return nil, errors.New("semantic query embedding returned a zero vector")
}
// Build the WHERE clause. The order of args must follow the order of the
// "?" placeholders in clauses, with the LIMIT argument appended last.
clauses := []string{
"e.provider = ?",
"e.model = ?",
"e.input_version = ?",
"e.dimensions = ?",
}
args := []any{opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions}
if len(opts.GuildIDs) > 0 {
clauses = append(clauses, "m.guild_id in ("+placeholders(len(opts.GuildIDs))+")")
for _, guildID := range opts.GuildIDs {
args = append(args, guildID)
}
}
// Channel matches either the exact channel ID or a substring of the name.
if strings.TrimSpace(opts.Channel) != "" {
clauses = append(clauses, "(m.channel_id = ? or c.name like ?)")
args = append(args, opts.Channel, "%"+opts.Channel+"%")
}
// Display name resolution: member nick, then global name, then username,
// all read from the message's raw JSON. Shared by the filter and the select.
authorExpr := `coalesce(
json_extract(m.raw_json, '$.member.nick'),
json_extract(m.raw_json, '$.author.global_name'),
json_extract(m.raw_json, '$.author.username'),
''
)`
if strings.TrimSpace(opts.Author) != "" {
clauses = append(clauses, "(m.author_id = ? or "+authorExpr+" like ?)")
args = append(args, opts.Author, "%"+opts.Author+"%")
}
if !opts.IncludeEmpty {
clauses = append(clauses, "trim(coalesce(m.normalized_content, '')) <> ''")
}
// Final placeholder: the candidate row cap for the LIMIT clause.
args = append(args, searchCandidateLimit(opts.Limit))
queryCtx, cancel := withQueryTimeout(ctx)
defer cancel()
rows, err := s.db.QueryContext(queryCtx, `
select
m.id,
m.guild_id,
m.channel_id,
coalesce(c.name, ''),
coalesce(m.author_id, ''),
`+authorExpr+`,
case
when trim(coalesce(m.content, '')) <> '' then m.content
else m.normalized_content
end,
m.created_at,
e.dimensions,
e.embedding_blob
from message_embeddings e
join messages m on m.id = e.message_id
left join channels c on c.id = m.channel_id
where `+strings.Join(clauses, " and ")+`
order by m.created_at desc, m.id desc
limit ?
`, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
// scoredResult pairs a candidate row with its similarity to the query.
type scoredResult struct {
result SearchResult
score float64
}
var scored []scoredResult
for rows.Next() {
var (
row SearchResult
created string
dimensions int
blob []byte
)
if err := rows.Scan(&row.MessageID, &row.GuildID, &row.ChannelID, &row.ChannelName, &row.AuthorID, &row.AuthorName, &row.Content, &created, &dimensions, &blob); err != nil {
return nil, err
}
// Any malformed stored embedding aborts the whole search with an error
// naming the offending message.
if dimensions != opts.Dimensions {
return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions)
}
vector, err := DecodeEmbeddingVector(blob)
if err != nil {
return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err)
}
if len(vector) != dimensions {
return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(vector), dimensions)
}
score, err := cosineSimilarity(opts.QueryVector, queryNorm, vector)
if err != nil {
return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err)
}
row.CreatedAt = parseTime(created)
scored = append(scored, scoredResult{result: row, score: score})
}
if err := rows.Err(); err != nil {
return nil, err
}
// No candidates: distinguish "filters matched nothing" (empty result) from
// "no embeddings exist for this provider/model/version/dimensions" (error).
if len(scored) == 0 {
compatible, err := s.hasCompatibleMessageEmbeddings(ctx, opts)
if err != nil {
return nil, err
}
if !compatible {
return nil, ErrNoCompatibleEmbeddings
}
return []SearchResult{}, nil
}
// Highest score first; ties go to the newer message, then the larger ID.
sort.SliceStable(scored, func(i, j int) bool {
if scored[i].score != scored[j].score {
return scored[i].score > scored[j].score
}
if !scored[i].result.CreatedAt.Equal(scored[j].result.CreatedAt) {
return scored[i].result.CreatedAt.After(scored[j].result.CreatedAt)
}
return scored[i].result.MessageID > scored[j].result.MessageID
})
// Trim the scored candidate window down to the requested limit.
if len(scored) > opts.Limit {
scored = scored[:opts.Limit]
}
out := make([]SearchResult, 0, len(scored))
for _, item := range scored {
out = append(out, item.result)
}
return out, nil
}
// hasCompatibleMessageEmbeddings reports whether at least one row in
// message_embeddings matches the provider, model, input version and
// dimensions in opts. Message filters in opts are not applied. The query runs
// under its own timeout derived from ctx.
func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts SemanticSearchOptions) (bool, error) {
	const existsSQL = `
select exists(
select 1
from message_embeddings
where provider = ?
and model = ?
and input_version = ?
and dimensions = ?
)
`
	queryCtx, cancel := withQueryTimeout(ctx)
	defer cancel()

	row := s.db.QueryRowContext(queryCtx, existsSQL, opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions)
	var found int
	scanErr := row.Scan(&found)
	return found == 1, scanErr
}
func (s *Store) CheckMessageFTS(ctx context.Context) error {
db, cleanup, err := s.openReadOnlyDB()
if err != nil {
@ -200,6 +383,29 @@ func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]Searc
return out, rows.Err()
}
// cosineSimilarity returns the cosine similarity between query and vector.
//
// queryNorm must be the Euclidean norm of query; callers compute it once with
// vectorNorm and reuse it for every candidate. An error is returned when the
// two vectors differ in length, when vector has a zero norm, or when the
// resulting score is not a finite number (for example because a stored
// component is NaN or Inf, or queryNorm is zero). Rejecting non-finite scores
// keeps NaN out of score-based sorting, where it would make comparisons
// inconsistent.
func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) {
	if len(vector) != len(query) {
		return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(vector), len(query))
	}
	// Named norm rather than vectorNorm so the helper function is not shadowed.
	norm := vectorNorm(vector)
	if norm == 0 {
		return 0, errors.New("stored embedding vector is zero")
	}
	var dot float64
	for i, q := range query {
		dot += float64(q) * float64(vector[i])
	}
	score := dot / (queryNorm * norm)
	if math.IsNaN(score) || math.IsInf(score, 0) {
		return 0, errors.New("embedding similarity is not finite")
	}
	return score, nil
}

// vectorNorm returns the Euclidean (L2) norm of vector, accumulating in
// float64. It returns 0 for an empty or all-zero vector.
func vectorNorm(vector []float32) float64 {
	var sum float64
	for _, value := range vector {
		sum += float64(value) * float64(value)
	}
	return math.Sqrt(sum)
}
func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
if strings.TrimSpace(query) != "" {
return s.searchMembers(ctx, guildID, query, limit)

View File

@ -185,6 +185,238 @@ func TestSearchMessagesPrefersRecentMessageIDs(t *testing.T) {
require.Contains(t, results[0].Content, "newest")
}
// TestSearchMessagesSemanticRanksAndFilters seeds five messages with
// two-dimensional embeddings across two guilds and three channels, then checks
// that SearchMessagesSemantic orders results by similarity to the query vector
// [1,0] and honors the guild, channel, author, limit and include-empty options.
func TestSearchMessagesSemanticRanksAndFilters(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = s.Close() }()

	for _, guild := range []GuildRecord{
		{ID: "g1", Name: "Guild", RawJSON: `{}`},
		{ID: "g2", Name: "Other", RawJSON: `{}`},
	} {
		require.NoError(t, s.UpsertGuild(ctx, guild))
	}
	for _, channel := range []ChannelRecord{
		{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`},
		{ID: "c2", GuildID: "g1", Kind: "text", Name: "random", RawJSON: `{}`},
		{ID: "c3", GuildID: "g2", Kind: "text", Name: "other", RawJSON: `{}`},
	} {
		require.NoError(t, s.UpsertChannel(ctx, channel))
	}

	// Each fixture is one message plus its stored embedding. Messages are
	// created one minute apart in slice order, starting at base.
	base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
	fixtures := []struct {
		id, guildID, channelID, channelName, authorID string
		content, rawJSON                              string
		vector                                        []float32
	}{
		{"m1", "g1", "c1", "general", "u1", "cats and databases", `{"author":{"username":"Alice"}}`, []float32{1, 0}},
		{"m2", "g1", "c2", "random", "u2", "cats but weaker", `{"author":{"username":"Bob"}}`, []float32{0.9, 0.1}},
		{"m3", "g1", "c1", "general", "u1", "dogs", `{"author":{"username":"Alice"}}`, []float32{0, 1}},
		{"m4", "g2", "c3", "other", "u3", "other guild cats", `{"author":{"username":"Carol"}}`, []float32{1, 0}},
		{"m5", "g1", "c1", "general", "u4", "", `{"author":{"username":"Empty"}}`, []float32{1, 0}},
	}
	for i, f := range fixtures {
		createdAt := base.Add(time.Duration(i) * time.Minute)
		require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
			ID:                f.id,
			GuildID:           f.guildID,
			ChannelID:         f.channelID,
			ChannelName:       f.channelName,
			AuthorID:          f.authorID,
			MessageType:       0,
			CreatedAt:         createdAt.Format(time.RFC3339Nano),
			Content:           f.content,
			NormalizedContent: f.content,
			RawJSON:           f.rawJSON,
		}))
	}
	for _, f := range fixtures {
		require.NoError(t, insertTestEmbedding(ctx, s, f.id, "ollama", "nomic-embed-text", f.vector))
	}

	// Every search is scoped to guild g1 and uses the same query vector.
	cases := []struct {
		name             string
		channel, author  string
		limit            int
		includeEmpty     bool
		wantIDs          []string // nil means the result must be empty
		wantFirstAuthor  string
		wantFirstChannel string
	}{
		{
			name:    "ranks by similarity within the guild",
			limit:   3,
			wantIDs: []string{"m1", "m2", "m3"},
		},
		{
			name:             "channel and author filters",
			channel:          "general",
			author:           "Alice",
			limit:            10,
			wantIDs:          []string{"m1", "m3"},
			wantFirstAuthor:  "Alice",
			wantFirstChannel: "general",
		},
		{
			name:         "include empty messages",
			channel:      "general",
			limit:        10,
			includeEmpty: true,
			wantIDs:      []string{"m5", "m1", "m3"},
		},
		{
			name:    "unknown channel yields no results",
			channel: "missing-channel",
			limit:   10,
		},
	}
	for _, tc := range cases {
		results, err := s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
			QueryVector:  []float32{1, 0},
			Provider:     "ollama",
			Model:        "nomic-embed-text",
			InputVersion: EmbeddingInputVersion,
			Dimensions:   2,
			GuildIDs:     []string{"g1"},
			Channel:      tc.channel,
			Author:       tc.author,
			Limit:        tc.limit,
			IncludeEmpty: tc.includeEmpty,
		})
		require.NoError(t, err, tc.name)
		if tc.wantIDs == nil {
			require.Empty(t, results, tc.name)
			continue
		}
		require.Equal(t, tc.wantIDs, searchResultIDs(results), tc.name)
		if tc.wantFirstAuthor != "" {
			require.Equal(t, tc.wantFirstAuthor, results[0].AuthorName, tc.name)
		}
		if tc.wantFirstChannel != "" {
			require.Equal(t, tc.wantFirstChannel, results[0].ChannelName, tc.name)
		}
	}
}
// TestSearchMessagesSemanticErrors walks SearchMessagesSemantic through its
// failure modes in order: no compatible embeddings, a zero query vector, a
// stored blob whose decoded length disagrees with its dimensions column, and a
// stored all-zero vector. The steps are order-dependent because each insert
// overwrites the embedding row for message m1.
func TestSearchMessagesSemanticErrors(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
	require.NoError(t, err)
	defer func() { _ = s.Close() }()

	require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
		ID:                "m1",
		GuildID:           "g1",
		ChannelID:         "c1",
		MessageType:       0,
		CreatedAt:         time.Now().UTC().Format(time.RFC3339Nano),
		Content:           "hello",
		NormalizedContent: "hello",
		RawJSON:           `{}`,
	}))

	// searchErr fills in the options shared by every case and returns only
	// the error from the search.
	searchErr := func(opts SemanticSearchOptions) error {
		opts.Provider = "ollama"
		opts.Dimensions = 2
		opts.Limit = 10
		_, callErr := s.SearchMessagesSemantic(ctx, opts)
		return callErr
	}
	const model = "nomic-embed-text"

	// No embeddings exist for this model at all.
	require.ErrorIs(t, searchErr(SemanticSearchOptions{
		QueryVector:  []float32{1, 0},
		Model:        "missing-model",
		InputVersion: EmbeddingInputVersion,
	}), ErrNoCompatibleEmbeddings)

	// A zero query vector is rejected; InputVersion is left empty here.
	require.ErrorContains(t, searchErr(SemanticSearchOptions{
		QueryVector: []float32{0, 0},
		Model:       model,
	}), "zero vector")

	// A 4-byte blob stored with dimensions=2.
	require.NoError(t, insertTestEmbeddingBlob(ctx, s, "m1", "ollama", model, 2, []byte{0, 0, 0, 0}))
	require.ErrorContains(t, searchErr(SemanticSearchOptions{
		QueryVector:  []float32{1, 0},
		Model:        model,
		InputVersion: EmbeddingInputVersion,
	}), "vector length mismatch")

	// A well-formed but all-zero stored vector.
	require.NoError(t, insertTestEmbedding(ctx, s, "m1", "ollama", model, []float32{0, 0}))
	require.ErrorContains(t, searchErr(SemanticSearchOptions{
		QueryVector:  []float32{1, 0},
		Model:        model,
		InputVersion: EmbeddingInputVersion,
	}), "stored embedding vector is zero")
}
// insertTestEmbedding encodes vector with EncodeEmbeddingVector and stores it
// for messageID via insertTestEmbeddingBlob, recording len(vector) as the
// row's dimensions.
func insertTestEmbedding(ctx context.Context, s *Store, messageID, provider, model string, vector []float32) error {
	encoded, encodeErr := EncodeEmbeddingVector(vector)
	if encodeErr != nil {
		return encodeErr
	}
	dims := len(vector)
	return insertTestEmbeddingBlob(ctx, s, messageID, provider, model, dims, encoded)
}
// insertTestEmbeddingBlob upserts a raw embedding row for messageID under
// EmbeddingInputVersion. dimensions and blob are stored as given, with no
// consistency check, so tests can create malformed rows. On a conflict for the
// same message/provider/model/input version, the dimensions, blob and
// embedded_at columns are overwritten.
func insertTestEmbeddingBlob(ctx context.Context, s *Store, messageID, provider, model string, dimensions int, blob []byte) error {
	const upsertSQL = `
insert into message_embeddings(
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
) values(?, ?, ?, ?, ?, ?, ?)
on conflict(message_id, provider, model, input_version) do update set
dimensions = excluded.dimensions,
embedding_blob = excluded.embedding_blob,
embedded_at = excluded.embedded_at
`
	embeddedAt := time.Now().UTC().Format(timeLayout)
	if _, err := s.DB().ExecContext(ctx, upsertSQL,
		messageID, provider, model, EmbeddingInputVersion, dimensions, blob, embeddedAt,
	); err != nil {
		return err
	}
	return nil
}
// searchResultIDs returns the MessageID of each result, in the same order.
// The returned slice is always non-nil.
func searchResultIDs(results []SearchResult) []string {
	ids := make([]string, len(results))
	for i := range results {
		ids[i] = results[i].MessageID
	}
	return ids
}
func TestCheckMessageFTSProbe(t *testing.T) {
t.Parallel()