feat(search): add semantic message search
This commit is contained in:
parent
ed929a92eb
commit
3fc8defc35
@ -555,6 +555,139 @@ func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
|
||||
require.Contains(t, out.String(), "requeued=2")
|
||||
}
|
||||
|
||||
func TestSearchSemanticCommandUsesStoredEmbeddings(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
dbPath := filepath.Join(dir, "discrawl.db")
|
||||
|
||||
var requests int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requests++
|
||||
require.Equal(t, "/embeddings", r.URL.Path)
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
require.Equal(t, "local-model", req.Model)
|
||||
require.Equal(t, []string{"cats"}, req.Input)
|
||||
_, _ = w.Write([]byte(`{"model":"local-model","data":[{"index":0,"embedding":[1,0]}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = dbPath
|
||||
cfg.Search.DefaultMode = "semantic"
|
||||
cfg.Search.Embeddings.Enabled = true
|
||||
cfg.Search.Embeddings.Provider = "openai_compatible"
|
||||
cfg.Search.Embeddings.Model = "local-model"
|
||||
cfg.Search.Embeddings.BaseURL = server.URL
|
||||
cfg.Search.Embeddings.APIKeyEnv = ""
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
AuthorName: "Alice",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Format(time.RFC3339Nano),
|
||||
Content: "database migration discussion",
|
||||
NormalizedContent: "database migration discussion",
|
||||
RawJSON: `{"author":{"username":"Alice"}}`,
|
||||
}))
|
||||
require.NoError(t, s.UpsertMessage(ctx, store.MessageRecord{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u2",
|
||||
AuthorName: "Bob",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "cats in semantic search",
|
||||
NormalizedContent: "cats in semantic search",
|
||||
RawJSON: `{"author":{"username":"Bob"}}`,
|
||||
}))
|
||||
require.NoError(t, insertCLIEmbedding(ctx, s, "m1", "openai_compatible", "local-model", []float32{1, 0}))
|
||||
require.NoError(t, insertCLIEmbedding(ctx, s, "m2", "openai_compatible", "local-model", []float32{0.8, 0.2}))
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--limit", "1", "cats"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "database migration discussion")
|
||||
require.NotContains(t, out.String(), "cats in semantic search")
|
||||
require.Equal(t, 1, requests)
|
||||
|
||||
out.Reset()
|
||||
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "--channel", "general", "--author", "Alice", "cats"}, &out, &bytes.Buffer{}))
|
||||
require.Contains(t, out.String(), "database migration discussion")
|
||||
require.NotContains(t, out.String(), "cats in semantic search")
|
||||
require.Equal(t, 2, requests)
|
||||
}
|
||||
|
||||
func TestSearchSemanticCommandErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.toml")
|
||||
dbPath := filepath.Join(dir, "discrawl.db")
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.DBPath = dbPath
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
s, err := store.Open(ctx, dbPath)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.Close())
|
||||
|
||||
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "bogus", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
require.Equal(t, 2, ExitCode(err))
|
||||
require.ErrorContains(t, err, `unsupported search mode "bogus"`)
|
||||
|
||||
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "hybrid", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
require.Equal(t, 1, ExitCode(err))
|
||||
require.ErrorContains(t, err, "hybrid search is not implemented yet")
|
||||
|
||||
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
require.Equal(t, 1, ExitCode(err))
|
||||
require.ErrorContains(t, err, "embeddings are disabled")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "nope", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
cfg.Search.Embeddings.Enabled = true
|
||||
cfg.Search.Embeddings.Provider = "openai_compatible"
|
||||
cfg.Search.Embeddings.Model = "local-model"
|
||||
cfg.Search.Embeddings.BaseURL = server.URL
|
||||
cfg.Search.Embeddings.APIKeyEnv = ""
|
||||
require.NoError(t, config.Write(cfgPath, cfg))
|
||||
|
||||
err = Run(ctx, []string{"--config", cfgPath, "search", "--mode", "semantic", "cats"}, &bytes.Buffer{}, &bytes.Buffer{})
|
||||
require.Equal(t, 1, ExitCode(err))
|
||||
require.ErrorContains(t, err, "embedding query failed")
|
||||
}
|
||||
|
||||
func insertCLIEmbedding(ctx context.Context, s *store.Store, messageID, provider, model string, vector []float32) error {
|
||||
blob, err := store.EncodeEmbeddingVector(vector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.DB().ExecContext(ctx, `
|
||||
insert into message_embeddings(
|
||||
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
|
||||
) values(?, ?, ?, ?, ?, ?, ?)
|
||||
`, messageID, provider, model, store.EmbeddingInputVersion, len(vector), blob, time.Now().UTC().Format(time.RFC3339Nano))
|
||||
return err
|
||||
}
|
||||
|
||||
type fakeDiscordClient struct {
|
||||
guilds []*discordgo.UserGuild
|
||||
self *discordgo.User
|
||||
|
||||
@ -8,6 +8,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/steipete/discrawl/internal/config"
|
||||
"github.com/steipete/discrawl/internal/embed"
|
||||
"github.com/steipete/discrawl/internal/store"
|
||||
)
|
||||
|
||||
@ -27,19 +29,72 @@ func (r *runtime) runSearch(args []string) error {
|
||||
if fs.NArg() != 1 {
|
||||
return usageErr(fmt.Errorf("search requires a query"))
|
||||
}
|
||||
_ = mode
|
||||
results, err := r.store.SearchMessages(r.ctx, store.SearchOptions{
|
||||
opts := store.SearchOptions{
|
||||
Query: fs.Arg(0),
|
||||
GuildIDs: r.resolveSearchGuilds(*guildFlag, *guildsFlag),
|
||||
Channel: *channel,
|
||||
Author: *author,
|
||||
Limit: *limit,
|
||||
IncludeEmpty: *includeEmpty,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.print(results)
|
||||
switch strings.ToLower(strings.TrimSpace(*mode)) {
|
||||
case "", "fts":
|
||||
results, err := r.store.SearchMessages(r.ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.print(results)
|
||||
case "semantic":
|
||||
results, err := r.searchMessagesSemantic(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.print(results)
|
||||
case "hybrid":
|
||||
return fmt.Errorf("hybrid search is not implemented yet")
|
||||
default:
|
||||
return usageErr(fmt.Errorf("unsupported search mode %q", *mode))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *runtime) searchMessagesSemantic(opts store.SearchOptions) ([]store.SearchResult, error) {
|
||||
if !r.cfg.Search.Embeddings.Enabled {
|
||||
return nil, fmt.Errorf("embeddings are disabled; enable [search.embeddings] first")
|
||||
}
|
||||
providerFactory := r.newEmbed
|
||||
if providerFactory == nil {
|
||||
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
|
||||
return embed.NewProvider(cfg)
|
||||
}
|
||||
}
|
||||
provider, err := providerFactory(r.cfg.Search.Embeddings)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create embedding provider: %w", err)
|
||||
}
|
||||
batch, err := provider.Embed(r.ctx, []string{opts.Query})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedding query failed: %w", err)
|
||||
}
|
||||
if len(batch.Vectors) != 1 {
|
||||
return nil, fmt.Errorf("embedding query returned %d vectors for 1 input", len(batch.Vectors))
|
||||
}
|
||||
queryVector := batch.Vectors[0]
|
||||
dimensions := batch.Dimensions
|
||||
if dimensions == 0 {
|
||||
dimensions = len(queryVector)
|
||||
}
|
||||
return r.store.SearchMessagesSemantic(r.ctx, store.SemanticSearchOptions{
|
||||
QueryVector: queryVector,
|
||||
Provider: r.cfg.Search.Embeddings.Provider,
|
||||
Model: r.cfg.Search.Embeddings.Model,
|
||||
InputVersion: store.EmbeddingInputVersion,
|
||||
Dimensions: dimensions,
|
||||
GuildIDs: opts.GuildIDs,
|
||||
Channel: opts.Channel,
|
||||
Author: opts.Author,
|
||||
Limit: opts.Limit,
|
||||
IncludeEmpty: opts.IncludeEmpty,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *runtime) runSQL(args []string) error {
|
||||
|
||||
@ -5,7 +5,9 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@ -19,6 +21,21 @@ const (
|
||||
messageFTSHealthProbe = "__discrawl_probe__"
|
||||
)
|
||||
|
||||
var ErrNoCompatibleEmbeddings = errors.New("no compatible message embeddings for provider/model/input version; run discrawl embed --rebuild")
|
||||
|
||||
type SemanticSearchOptions struct {
|
||||
QueryVector []float32
|
||||
Provider string
|
||||
Model string
|
||||
InputVersion string
|
||||
Dimensions int
|
||||
GuildIDs []string
|
||||
Channel string
|
||||
Author string
|
||||
Limit int
|
||||
IncludeEmpty bool
|
||||
}
|
||||
|
||||
func (s *Store) GetSyncState(ctx context.Context, scope string) (string, error) {
|
||||
var cursor sql.NullString
|
||||
err := s.db.QueryRowContext(ctx, `select cursor from sync_state where scope = ?`, scope).Scan(&cursor)
|
||||
@ -120,6 +137,172 @@ func (s *Store) SearchMessages(ctx context.Context, opts SearchOptions) ([]Searc
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) SearchMessagesSemantic(ctx context.Context, opts SemanticSearchOptions) ([]SearchResult, error) {
|
||||
opts.Provider = strings.ToLower(strings.TrimSpace(opts.Provider))
|
||||
opts.Model = strings.TrimSpace(opts.Model)
|
||||
opts.InputVersion = strings.TrimSpace(opts.InputVersion)
|
||||
if opts.InputVersion == "" {
|
||||
opts.InputVersion = EmbeddingInputVersion
|
||||
}
|
||||
if opts.Limit <= 0 {
|
||||
opts.Limit = 20
|
||||
}
|
||||
if len(opts.QueryVector) == 0 {
|
||||
return nil, errors.New("semantic query embedding returned an empty vector")
|
||||
}
|
||||
if opts.Dimensions <= 0 {
|
||||
opts.Dimensions = len(opts.QueryVector)
|
||||
}
|
||||
if len(opts.QueryVector) != opts.Dimensions {
|
||||
return nil, fmt.Errorf("semantic query embedding dimensions mismatch: got %d want %d", len(opts.QueryVector), opts.Dimensions)
|
||||
}
|
||||
queryNorm := vectorNorm(opts.QueryVector)
|
||||
if queryNorm == 0 {
|
||||
return nil, errors.New("semantic query embedding returned a zero vector")
|
||||
}
|
||||
|
||||
clauses := []string{
|
||||
"e.provider = ?",
|
||||
"e.model = ?",
|
||||
"e.input_version = ?",
|
||||
"e.dimensions = ?",
|
||||
}
|
||||
args := []any{opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions}
|
||||
if len(opts.GuildIDs) > 0 {
|
||||
clauses = append(clauses, "m.guild_id in ("+placeholders(len(opts.GuildIDs))+")")
|
||||
for _, guildID := range opts.GuildIDs {
|
||||
args = append(args, guildID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(opts.Channel) != "" {
|
||||
clauses = append(clauses, "(m.channel_id = ? or c.name like ?)")
|
||||
args = append(args, opts.Channel, "%"+opts.Channel+"%")
|
||||
}
|
||||
authorExpr := `coalesce(
|
||||
json_extract(m.raw_json, '$.member.nick'),
|
||||
json_extract(m.raw_json, '$.author.global_name'),
|
||||
json_extract(m.raw_json, '$.author.username'),
|
||||
''
|
||||
)`
|
||||
if strings.TrimSpace(opts.Author) != "" {
|
||||
clauses = append(clauses, "(m.author_id = ? or "+authorExpr+" like ?)")
|
||||
args = append(args, opts.Author, "%"+opts.Author+"%")
|
||||
}
|
||||
if !opts.IncludeEmpty {
|
||||
clauses = append(clauses, "trim(coalesce(m.normalized_content, '')) <> ''")
|
||||
}
|
||||
args = append(args, searchCandidateLimit(opts.Limit))
|
||||
|
||||
queryCtx, cancel := withQueryTimeout(ctx)
|
||||
defer cancel()
|
||||
rows, err := s.db.QueryContext(queryCtx, `
|
||||
select
|
||||
m.id,
|
||||
m.guild_id,
|
||||
m.channel_id,
|
||||
coalesce(c.name, ''),
|
||||
coalesce(m.author_id, ''),
|
||||
`+authorExpr+`,
|
||||
case
|
||||
when trim(coalesce(m.content, '')) <> '' then m.content
|
||||
else m.normalized_content
|
||||
end,
|
||||
m.created_at,
|
||||
e.dimensions,
|
||||
e.embedding_blob
|
||||
from message_embeddings e
|
||||
join messages m on m.id = e.message_id
|
||||
left join channels c on c.id = m.channel_id
|
||||
where `+strings.Join(clauses, " and ")+`
|
||||
order by m.created_at desc, m.id desc
|
||||
limit ?
|
||||
`, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
type scoredResult struct {
|
||||
result SearchResult
|
||||
score float64
|
||||
}
|
||||
var scored []scoredResult
|
||||
for rows.Next() {
|
||||
var (
|
||||
row SearchResult
|
||||
created string
|
||||
dimensions int
|
||||
blob []byte
|
||||
)
|
||||
if err := rows.Scan(&row.MessageID, &row.GuildID, &row.ChannelID, &row.ChannelName, &row.AuthorID, &row.AuthorName, &row.Content, &created, &dimensions, &blob); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dimensions != opts.Dimensions {
|
||||
return nil, fmt.Errorf("stored embedding dimensions mismatch for message %s: got %d want %d", row.MessageID, dimensions, opts.Dimensions)
|
||||
}
|
||||
vector, err := DecodeEmbeddingVector(blob)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode embedding for message %s: %w", row.MessageID, err)
|
||||
}
|
||||
if len(vector) != dimensions {
|
||||
return nil, fmt.Errorf("stored embedding vector length mismatch for message %s: got %d want %d", row.MessageID, len(vector), dimensions)
|
||||
}
|
||||
score, err := cosineSimilarity(opts.QueryVector, queryNorm, vector)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("score embedding for message %s: %w", row.MessageID, err)
|
||||
}
|
||||
row.CreatedAt = parseTime(created)
|
||||
scored = append(scored, scoredResult{result: row, score: score})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(scored) == 0 {
|
||||
compatible, err := s.hasCompatibleMessageEmbeddings(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !compatible {
|
||||
return nil, ErrNoCompatibleEmbeddings
|
||||
}
|
||||
return []SearchResult{}, nil
|
||||
}
|
||||
sort.SliceStable(scored, func(i, j int) bool {
|
||||
if scored[i].score != scored[j].score {
|
||||
return scored[i].score > scored[j].score
|
||||
}
|
||||
if !scored[i].result.CreatedAt.Equal(scored[j].result.CreatedAt) {
|
||||
return scored[i].result.CreatedAt.After(scored[j].result.CreatedAt)
|
||||
}
|
||||
return scored[i].result.MessageID > scored[j].result.MessageID
|
||||
})
|
||||
if len(scored) > opts.Limit {
|
||||
scored = scored[:opts.Limit]
|
||||
}
|
||||
out := make([]SearchResult, 0, len(scored))
|
||||
for _, item := range scored {
|
||||
out = append(out, item.result)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) hasCompatibleMessageEmbeddings(ctx context.Context, opts SemanticSearchOptions) (bool, error) {
|
||||
queryCtx, cancel := withQueryTimeout(ctx)
|
||||
defer cancel()
|
||||
var exists int
|
||||
err := s.db.QueryRowContext(queryCtx, `
|
||||
select exists(
|
||||
select 1
|
||||
from message_embeddings
|
||||
where provider = ?
|
||||
and model = ?
|
||||
and input_version = ?
|
||||
and dimensions = ?
|
||||
)
|
||||
`, opts.Provider, opts.Model, opts.InputVersion, opts.Dimensions).Scan(&exists)
|
||||
return exists == 1, err
|
||||
}
|
||||
|
||||
func (s *Store) CheckMessageFTS(ctx context.Context) error {
|
||||
db, cleanup, err := s.openReadOnlyDB()
|
||||
if err != nil {
|
||||
@ -200,6 +383,29 @@ func (s *Store) searchFallback(ctx context.Context, opts SearchOptions) ([]Searc
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func cosineSimilarity(query []float32, queryNorm float64, vector []float32) (float64, error) {
|
||||
if len(vector) != len(query) {
|
||||
return 0, fmt.Errorf("dimensions mismatch: got %d want %d", len(vector), len(query))
|
||||
}
|
||||
vectorNorm := vectorNorm(vector)
|
||||
if vectorNorm == 0 {
|
||||
return 0, errors.New("stored embedding vector is zero")
|
||||
}
|
||||
var dot float64
|
||||
for i := range query {
|
||||
dot += float64(query[i]) * float64(vector[i])
|
||||
}
|
||||
return dot / (queryNorm * vectorNorm), nil
|
||||
}
|
||||
|
||||
func vectorNorm(vector []float32) float64 {
|
||||
var sum float64
|
||||
for _, value := range vector {
|
||||
sum += float64(value) * float64(value)
|
||||
}
|
||||
return math.Sqrt(sum)
|
||||
}
|
||||
|
||||
func (s *Store) Members(ctx context.Context, guildID, query string, limit int) ([]MemberRow, error) {
|
||||
if strings.TrimSpace(query) != "" {
|
||||
return s.searchMembers(ctx, guildID, query, limit)
|
||||
|
||||
@ -185,6 +185,238 @@ func TestSearchMessagesPrefersRecentMessageIDs(t *testing.T) {
|
||||
require.Contains(t, results[0].Content, "newest")
|
||||
}
|
||||
|
||||
func TestSearchMessagesSemanticRanksAndFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
base := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g2", Name: "Other", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c2", GuildID: "g1", Kind: "text", Name: "random", RawJSON: `{}`}))
|
||||
require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "c3", GuildID: "g2", Kind: "text", Name: "other", RawJSON: `{}`}))
|
||||
|
||||
semanticMessages := []MessageRecord{
|
||||
{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Format(time.RFC3339Nano),
|
||||
Content: "cats and databases",
|
||||
NormalizedContent: "cats and databases",
|
||||
RawJSON: `{"author":{"username":"Alice"}}`,
|
||||
},
|
||||
{
|
||||
ID: "m2",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c2",
|
||||
ChannelName: "random",
|
||||
AuthorID: "u2",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "cats but weaker",
|
||||
NormalizedContent: "cats but weaker",
|
||||
RawJSON: `{"author":{"username":"Bob"}}`,
|
||||
},
|
||||
{
|
||||
ID: "m3",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u1",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(2 * time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "dogs",
|
||||
NormalizedContent: "dogs",
|
||||
RawJSON: `{"author":{"username":"Alice"}}`,
|
||||
},
|
||||
{
|
||||
ID: "m4",
|
||||
GuildID: "g2",
|
||||
ChannelID: "c3",
|
||||
ChannelName: "other",
|
||||
AuthorID: "u3",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(3 * time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "other guild cats",
|
||||
NormalizedContent: "other guild cats",
|
||||
RawJSON: `{"author":{"username":"Carol"}}`,
|
||||
},
|
||||
{
|
||||
ID: "m5",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
ChannelName: "general",
|
||||
AuthorID: "u4",
|
||||
MessageType: 0,
|
||||
CreatedAt: base.Add(4 * time.Minute).Format(time.RFC3339Nano),
|
||||
Content: "",
|
||||
NormalizedContent: "",
|
||||
RawJSON: `{"author":{"username":"Empty"}}`,
|
||||
},
|
||||
}
|
||||
for _, message := range semanticMessages {
|
||||
require.NoError(t, s.UpsertMessage(ctx, message))
|
||||
}
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m1", "ollama", "nomic-embed-text", []float32{1, 0}))
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m2", "ollama", "nomic-embed-text", []float32{0.9, 0.1}))
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m3", "ollama", "nomic-embed-text", []float32{0, 1}))
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m4", "ollama", "nomic-embed-text", []float32{1, 0}))
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m5", "ollama", "nomic-embed-text", []float32{1, 0}))
|
||||
|
||||
results, err := s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
GuildIDs: []string{"g1"},
|
||||
Limit: 3,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"m1", "m2", "m3"}, searchResultIDs(results))
|
||||
|
||||
results, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
GuildIDs: []string{"g1"},
|
||||
Channel: "general",
|
||||
Author: "Alice",
|
||||
Limit: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"m1", "m3"}, searchResultIDs(results))
|
||||
require.Equal(t, "Alice", results[0].AuthorName)
|
||||
require.Equal(t, "general", results[0].ChannelName)
|
||||
|
||||
results, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
GuildIDs: []string{"g1"},
|
||||
Channel: "general",
|
||||
Limit: 10,
|
||||
IncludeEmpty: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"m5", "m1", "m3"}, searchResultIDs(results))
|
||||
|
||||
results, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
GuildIDs: []string{"g1"},
|
||||
Channel: "missing-channel",
|
||||
Limit: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, results)
|
||||
}
|
||||
|
||||
func TestSearchMessagesSemanticErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db"))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
require.NoError(t, s.UpsertMessage(ctx, MessageRecord{
|
||||
ID: "m1",
|
||||
GuildID: "g1",
|
||||
ChannelID: "c1",
|
||||
MessageType: 0,
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Content: "hello",
|
||||
NormalizedContent: "hello",
|
||||
RawJSON: `{}`,
|
||||
}))
|
||||
|
||||
_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "missing-model",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
Limit: 10,
|
||||
})
|
||||
require.ErrorIs(t, err, ErrNoCompatibleEmbeddings)
|
||||
|
||||
_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{0, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
Dimensions: 2,
|
||||
Limit: 10,
|
||||
})
|
||||
require.ErrorContains(t, err, "zero vector")
|
||||
|
||||
require.NoError(t, insertTestEmbeddingBlob(ctx, s, "m1", "ollama", "nomic-embed-text", 2, []byte{0, 0, 0, 0}))
|
||||
_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
Limit: 10,
|
||||
})
|
||||
require.ErrorContains(t, err, "vector length mismatch")
|
||||
|
||||
require.NoError(t, insertTestEmbedding(ctx, s, "m1", "ollama", "nomic-embed-text", []float32{0, 0}))
|
||||
_, err = s.SearchMessagesSemantic(ctx, SemanticSearchOptions{
|
||||
QueryVector: []float32{1, 0},
|
||||
Provider: "ollama",
|
||||
Model: "nomic-embed-text",
|
||||
InputVersion: EmbeddingInputVersion,
|
||||
Dimensions: 2,
|
||||
Limit: 10,
|
||||
})
|
||||
require.ErrorContains(t, err, "stored embedding vector is zero")
|
||||
}
|
||||
|
||||
func insertTestEmbedding(ctx context.Context, s *Store, messageID, provider, model string, vector []float32) error {
|
||||
blob, err := EncodeEmbeddingVector(vector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return insertTestEmbeddingBlob(ctx, s, messageID, provider, model, len(vector), blob)
|
||||
}
|
||||
|
||||
func insertTestEmbeddingBlob(ctx context.Context, s *Store, messageID, provider, model string, dimensions int, blob []byte) error {
|
||||
_, err := s.DB().ExecContext(ctx, `
|
||||
insert into message_embeddings(
|
||||
message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at
|
||||
) values(?, ?, ?, ?, ?, ?, ?)
|
||||
on conflict(message_id, provider, model, input_version) do update set
|
||||
dimensions = excluded.dimensions,
|
||||
embedding_blob = excluded.embedding_blob,
|
||||
embedded_at = excluded.embedded_at
|
||||
`, messageID, provider, model, EmbeddingInputVersion, dimensions, blob, time.Now().UTC().Format(timeLayout))
|
||||
return err
|
||||
}
|
||||
|
||||
func searchResultIDs(results []SearchResult) []string {
|
||||
ids := make([]string, 0, len(results))
|
||||
for _, result := range results {
|
||||
ids = append(ids, result.MessageID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func TestCheckMessageFTSProbe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user